Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| fca2b1449f | |||
| 01f6003d35 | |||
| d6d6bc3bc8 | |||
| 2e928310cf | |||
| 40a39a0f1a | |||
| dd4e0c24a8 | |||
| 52d1bc5cf4 |
8
.gitignore
vendored
8
.gitignore
vendored
@@ -205,3 +205,11 @@ cython_debug/
|
|||||||
marimo/_static/
|
marimo/_static/
|
||||||
marimo/_lsp/
|
marimo/_lsp/
|
||||||
__marimo__/
|
__marimo__/
|
||||||
|
|
||||||
|
# Other
|
||||||
|
configuration/
|
||||||
|
logs/
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# Test
|
||||||
|
localstore/
|
||||||
|
|||||||
18
pyproject.toml
Normal file
18
pyproject.toml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
[project]
|
||||||
|
name = "service"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "This project is the web service sub-system for the if.u project"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
# Runtime dependencies. sqlalchemy (src/storage/people_store.py) and uvicorn
# (src/main.py) are imported directly by the code, so they are declared here
# rather than relying on whatever happens to be installed in the environment.
dependencies = [
    "alibabacloud-ocr-api20210707>=3.1.3",
    "chromadb>=1.1.1",
    "fastapi>=0.118.2",
    "langchain>=0.3.27",
    "langchain-openai>=0.3.35",
    "numpy>=2.3.3",
    "pymysql>=1.1.2",
    "python-multipart>=0.0.20",
    "qiniu>=7.17.0",
    "requests>=2.32.5",
    "sqlalchemy>=1.4",
    "uvicorn",
]
|
||||||
0
src/ai/__init__.py
Normal file
0
src/ai/__init__.py
Normal file
54
src/ai/agent.py
Normal file
54
src/ai/agent.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain.prompts import ChatPromptTemplate
|
||||||
|
|
||||||
|
from models.people import People
|
||||||
|
|
||||||
|
class BaseAgent:
    """Base class that owns the chat-model client shared by all agents.

    The connection settings are resolved in this order: explicit constructor
    argument, environment variable (``ARK_API_KEY``, ``ARK_API_BASE``,
    ``ARK_MODEL_NAME``), then the legacy built-in value.  Calling the
    constructor without arguments and without those variables set behaves
    exactly as before.
    """

    # NOTE(security): this key has been committed to version control and must
    # be treated as leaked.  Rotate it, provide the new one via ARK_API_KEY,
    # and then delete this fallback.
    _LEGACY_API_KEY = "56d82040-85c7-4701-8f87-734985e27909"
    _DEFAULT_API_BASE = "https://ark.cn-beijing.volces.com/api/v3"
    _DEFAULT_MODEL_NAME = "ep-20250722161445-n9lfq"

    def __init__(self, api_key=None, api_base=None, model_name=None):
        """Create the ChatOpenAI client stored on ``self.llm``.

        Args:
            api_key: API key; falls back to ``ARK_API_KEY`` and then the legacy key.
            api_base: Base URL of the OpenAI-compatible endpoint.
            model_name: Model / endpoint identifier passed to the client.
        """
        import os  # local import: keeps this block self-contained

        self.llm = ChatOpenAI(
            openai_api_key=api_key or os.environ.get("ARK_API_KEY") or self._LEGACY_API_KEY,
            openai_api_base=api_base or os.environ.get("ARK_API_BASE") or self._DEFAULT_API_BASE,
            model_name=model_name or os.environ.get("ARK_MODEL_NAME") or self._DEFAULT_MODEL_NAME,
        )
|
||||||
|
|
||||||
|
class ExtractPeopleAgent(BaseAgent):
    """Agent that asks the LLM to extract a People record from free text."""

    def __init__(self):
        """Build the chat prompt: a fixed system instruction plus the user text."""
        super().__init__()
        self.prompt = ChatPromptTemplate.from_messages([
            (
                "system",
                "你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n"
                "姓名 name\n"
                "性别 gender\n"
                "年龄 age\n"
                "身高(cm) height\n"
                "婚姻状况 marital_status\n"
                "择偶要求 match_requirement\n"
                "以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n"
                "除了上述基本信息,还有一个字段\n"
                "个人介绍 introduction\n"
                "其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n"
                "个人介绍的字典的 key 需要使用提炼好的中文。\n"
            ),
            ("human", "{input}")
        ])

    @staticmethod
    def _strip_code_fence(content):
        """Return *content* without a surrounding Markdown code fence, if any.

        Non-string content is returned unchanged so the JSON parser can
        reject it.
        """
        if not isinstance(content, str):
            return content
        stripped = content.strip()
        if stripped.startswith("```"):
            # Drop the opening fence line (``` or ```json) ...
            newline_at = stripped.find("\n")
            stripped = stripped[newline_at + 1:] if newline_at != -1 else stripped[3:]
            # ... and the closing fence when present.
            stripped = stripped.rstrip()
            if stripped.endswith("```"):
                stripped = stripped[:-3]
        return stripped.strip()

    def extract_people_info(self, text: str) -> People | None:
        """Extract personal information from *text*.

        Args:
            text: Free-form description of a person.

        Returns:
            The parsed People object, or None when the model reply is not a
            JSON object (the raw reply is logged in that case).
        """
        prompt = self.prompt.format_prompt(input=text)
        response = self.llm.invoke(prompt)
        logging.info(f"llm response: {response.content}")
        try:
            payload = json.loads(self._strip_code_fence(response.content))
        except (json.JSONDecodeError, TypeError):
            logging.error(f"Failed to parse JSON from LLM response: {response.content}")
            return None
        if not isinstance(payload, dict):
            # People.from_dict needs a mapping; a JSON list or scalar is unusable.
            logging.error(f"Failed to parse JSON from LLM response: {response.content}")
            return None
        return People.from_dict(payload)
|
||||||
|
|
||||||
|
class SummaryPeopleAgent(BaseAgent):
    """Agent stub: adds nothing beyond the LLM client set up by BaseAgent."""

    def __init__(self):
        """Initialise the shared LLM client via the base class."""
        super().__init__()
|
||||||
0
src/app/__init__.py
Normal file
0
src/app/__init__.py
Normal file
130
src/app/api.py
Normal file
130
src/app/api.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from typing import Any, Optional
|
||||||
|
from fastapi import FastAPI, File, UploadFile, Query
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ai.agent import ExtractPeopleAgent
|
||||||
|
from models.people import People
|
||||||
|
from utils import obs, ocr, vsdb
|
||||||
|
from storage.people_store import get_instance as get_people_store
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
# FastAPI application object; the route handlers below are registered on it.
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
# CORS middleware configured to accept every origin, method and header.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended outside local development.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||||
|
|
||||||
|
class BaseResponse(BaseModel):
    """Response envelope returned by every endpoint in this module."""

    # Status code; the handlers in this file use 0 for success.
    error_code: int
    # Human-readable status text ("success" on the success paths in this file).
    error_info: str
    # Optional payload; its shape depends on the endpoint.
    data: Optional[Any] = None
|
||||||
|
|
||||||
|
class PostInputRequest(BaseModel):
    """Request body of POST /input."""

    # Free-form text that is handed to the extraction agent.
    text: str
|
||||||
|
|
||||||
|
@api.post("/input")
async def post_input(request: PostInputRequest):
    """Extract a people record from the submitted text.

    Args:
        request: Body carrying the free-form text.

    Returns:
        BaseResponse whose ``data`` is the extracted record as a dict, or an
        error response when the agent could not produce a record.
    """
    extract_agent = ExtractPeopleAgent()
    people = extract_agent.extract_people_info(request.text)
    logging.info(f"people: {people}")
    if people is None:
        # The agent returns None when the model reply is not parseable;
        # calling to_dict() on it would crash with AttributeError.
        # NOTE(review): no error-code table is visible here; 1 is used as a
        # generic failure code — align with the project convention.
        return BaseResponse(error_code=1, error_info="failed to extract people info from input")
    return BaseResponse(error_code=0, error_info="success", data=people.to_dict())
|
||||||
|
|
||||||
|
@api.post("/input_image")
async def post_input_image(image: UploadFile = File(...)):
    """Upload an image, OCR it and run the text through ``post_input``.

    Args:
        image: Uploaded image file.

    Returns:
        The response of ``post_input`` for the recognised text, or an error
        response when the upload or the OCR step yields nothing.
    """
    # Store under a unique name; filename may be missing on the upload.
    file_extension = os.path.splitext(image.filename or "")[1]
    unique_filename = f"{uuid.uuid4()}{file_extension}"

    # The file goes straight to object storage, so no local "uploads"
    # directory is created any more (it was never written to).
    file_path = f"uploads/{unique_filename}"
    obs_util = obs.get_instance()
    # NOTE(review): error code 1 is a generic failure code — align with the
    # project convention if one exists.
    if not obs_util.Put(file_path, await image.read()):
        # Put() reports failure with an empty result.
        return BaseResponse(error_code=1, error_info="failed to upload image")

    # Public link of the stored object, handed to the OCR service.
    obs_url = obs_util.Link(file_path)
    logging.info(f"obs_url: {obs_url}")

    ocr_util = ocr.get_instance()
    ocr_result = ocr_util.recognize_image_text(obs_url)
    logging.info(f"ocr_result: {ocr_result}")
    if not ocr_result:
        # Nothing to extract from; also avoids a validation error on text=None.
        return BaseResponse(error_code=1, error_info="no text recognized in image")

    return await post_input(PostInputRequest(text=ocr_result))
|
||||||
|
|
||||||
|
class PostPeopleRequest(BaseModel):
    """Request body of POST /peoples."""

    # Field dictionary that is passed to People.from_dict by the handler.
    people: dict
|
||||||
|
|
||||||
|
@api.post("/peoples")
async def post_people(post_people_request: PostPeopleRequest):
    """Persist the submitted people record and return its id in ``data``."""
    logging.debug(f"post_people_request: {post_people_request}")
    record = People.from_dict(post_people_request.people)
    record.id = get_people_store().save(record)
    return BaseResponse(error_code=0, error_info="success", data=record.id)
|
||||||
|
|
||||||
|
class GetPeopleRequest(BaseModel):
    """Search request model.

    NOTE(review): not referenced by any handler in the visible part of this
    file (GET /peoples takes query parameters) — confirm it is still needed.
    """

    # Free-text search string.
    query: Optional[str] = None
    # Field filters.
    conds: Optional[dict] = None
    # Number of results to return.
    top_k: int = 5
|
||||||
|
|
||||||
|
@api.get("/peoples")
async def get_peoples(
    name: Optional[str] = Query(None, description="姓名"),
    gender: Optional[str] = Query(None, description="性别"),
    age: Optional[int] = Query(None, description="年龄"),
    height: Optional[int] = Query(None, description="身高"),
    marital_status: Optional[str] = Query(None, description="婚姻状态"),
    limit: int = Query(10, description="分页大小"),
    offset: int = Query(0, description="分页偏移量"),
    search: Optional[str] = Query(None, description="搜索内容"),
    top_k: int = Query(5, description="搜索结果数量"),
):
    """List or search people.

    With ``search`` set, the store's vector search is used with the field
    filters and ``top_k``; otherwise the store is queried with the filters
    and ``limit``/``offset``.

    Returns:
        BaseResponse whose ``data`` is a list of people dicts.
    """
    candidates = {
        "name": name,
        "gender": gender,
        "age": age,
        "height": height,
        "marital_status": marital_status,
    }
    # Falsy values (None, "", 0) mean "no filter", exactly as before.
    conds = {field: value for field, value in candidates.items() if value}

    # The original message dropped the conds placeholder and logged "conds: ,".
    logging.info(f"conds: {conds}, limit: {limit}, offset: {offset}, search: {search}, top_k: {top_k}")

    store = get_people_store()
    if search:
        results = store.search(search, conds, ids=None, top_k=top_k)
        logging.info(f"search results: {results}")
    else:
        results = store.query(conds, limit=limit, offset=offset)
        logging.info(f"query results: {results}")
    peoples = [people.to_dict() for people in results]
    return BaseResponse(error_code=0, error_info="success", data=peoples)
|
||||||
|
|
||||||
|
@api.delete("/peoples/{people_id}")
async def delete_people(people_id: str):
    """Delete the record identified by *people_id* through the people store."""
    get_people_store().delete(people_id)
    return BaseResponse(error_code=0, error_info="success")
|
||||||
52
src/app/app.py
Normal file
52
src/app/app.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# created by mmmy on 2025-09-27
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from ai.agent import ExtractPeopleAgent
|
||||||
|
from utils import ocr, vsdb, obs
|
||||||
|
from models.people import People
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
# NOTE(review): src/app/api.py creates its own FastAPI instance with the same
# title, and src/main.py serves that one — confirm whether this object is used.
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
|
||||||
|
|
||||||
|
class App:
    """Facade wiring the extraction agent, OCR, vector DB and object storage."""

    def __init__(self):
        """Create the agent and fetch the shared utility instances."""
        self.extract_people_agent = ExtractPeopleAgent()
        self.ocr = ocr.get_instance()
        self.vedb = vsdb.get_instance(db_type='chromadb')
        self.obs = obs.get_instance()

    def run(self):
        """Placeholder entry point; does nothing."""
        pass

    class InputForPeople:
        """Input for ``input_to_people``: an image or a text.

        Class-level defaults make a bare instance usable; previously reading
        an unset field raised AttributeError.
        """

        image: bytes = b""
        text: str = ""

    def input_to_people(self, input: InputForPeople) -> People | None:
        """Turn an image or text input into a People object.

        The image takes precedence and is passed to the OCR utility; without
        an image the text is used directly.

        Returns:
            The extracted People, or None when the input is empty or nothing
            could be extracted.
        """
        if not input.image and not input.text:
            return None
        content = self.ocr.recognize_image_text(input.image) if input.image else input.text
        logging.debug(f"input content: {content}")
        return self._trans_text_to_people(content)

    def _trans_text_to_people(self, text: str) -> People | None:
        """Run the extraction agent on *text*; empty text yields None."""
        if not text:
            return None
        person = self.extract_people_agent.extract_people_info(text)
        logging.debug(f"extracted people: {person}")
        return person

    def create_people(self, people: People) -> bool:
        """Persist *people* and report whether saving succeeded.

        The People model has no ``save`` method, so the previous
        ``people.save()`` call always failed and this method always returned
        False.  Saving now goes through the people store.
        """
        if not people:
            return False
        # Local import: this module does not import the storage package.
        from storage.people_store import get_instance as get_people_store
        try:
            get_people_store().save(people)
        except Exception as e:
            logging.error(f"保存人物到数据库失败: {e}")
            return False
        return True
|
||||||
30
src/main.py
Normal file
30
src/main.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# created by mmmy on 2025-09-27
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
from venv import logger
|
||||||
|
import uvicorn
|
||||||
|
from app.api import api
|
||||||
|
from utils import obs, ocr, vsdb, logger, config
|
||||||
|
from storage import people_store
|
||||||
|
|
||||||
|
# 主函数
|
||||||
|
def main():
    """Parse ``--config``, initialise every component and start the web server."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    default_config = os.path.join(base_dir, '../configuration/test_conf.ini')

    arg_parser = argparse.ArgumentParser(description='IF.u 服务')
    arg_parser.add_argument('--config', type=str, default=default_config, help='配置文件路径')
    options = arg_parser.parse_args()

    # Configuration first: every other component reads it during init().
    config.init(options.config)
    for component in (logger, obs, ocr, vsdb, people_store):
        component.init()

    settings = config.get_instance()
    uvicorn.run(
        api,
        host=settings.get('web_service', 'server_host', fallback='127.0.0.1'),
        port=settings.getint('web_service', 'server_port', fallback=8099),
    )
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
src/models/__init__.py
Normal file
0
src/models/__init__.py
Normal file
121
src/models/people.py
Normal file
121
src/models/people.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# created by mmmy on 2025-09-30
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
class People:
    """Plain in-memory model of one person record.

    Every field has a default, and a field passed as None is replaced by that
    default, so instances never carry None after construction.
    """

    # Database ID
    id: str
    # Name
    name: str
    # Contact
    contact: str
    # Gender
    gender: str
    # Age
    age: int
    # Height (cm)
    height: int
    # Weight (kg) — currently disabled
    # weight: int
    # Marital status, one of:
    #   "未婚(single)", "已婚(married)", "离异(divorced)", "丧偶(widowed)"
    marital_status: str
    # Requirements for a partner
    match_requirement: str
    # Self introduction, keyed by topic
    introduction: Dict[str, str]
    # Summary comments, keyed by topic
    comments: Dict[str, str]

    # Storage timestamps that from_dict drops because __init__ does not model them.
    _TIMESTAMP_FIELDS = ('created_at', 'updated_at', 'deleted_at')

    def __init__(self, **kwargs):
        """Initialise all fields from keyword arguments.

        Missing or None values fall back to '' (strings), 0 (numbers) or a
        fresh empty dict (introduction, comments).  Unknown keys are ignored.
        """
        def pick(key, default):
            value = kwargs.get(key)
            return default if value is None else value

        self.id = pick('id', '')
        self.name = pick('name', '')
        self.contact = pick('contact', '')
        self.gender = pick('gender', '')
        self.age = pick('age', 0)
        self.height = pick('height', 0)
        self.marital_status = pick('marital_status', '')
        self.match_requirement = pick('match_requirement', '')
        self.introduction = pick('introduction', {})
        self.comments = pick('comments', {})

    def __str__(self) -> str:
        """Return the natural-language rendering produced by tonl()."""
        return self.tonl()

    @classmethod
    def from_dict(cls, data: dict):
        """Build a People from *data*, ignoring storage timestamp fields.

        The caller's dict is left untouched; the previous implementation
        deleted the timestamp keys from it in place.
        """
        cleaned = {
            key: value for key, value in data.items()
            if key not in cls._TIMESTAMP_FIELDS
        }
        return cls(**cleaned)

    def to_dict(self) -> dict:
        """Return all fields as a dictionary."""
        return {
            'id': self.id,
            'name': self.name,
            'contact': self.contact,
            'gender': self.gender,
            'age': self.age,
            'height': self.height,
            'marital_status': self.marital_status,
            'match_requirement': self.match_requirement,
            'introduction': self.introduction,
            'comments': self.comments,
        }

    def meta(self) -> dict:
        """Return the scalar identity fields as a metadata dict and log it."""
        meta = {
            'id': self.id,
            'name': self.name,
            'gender': self.gender,
            'age': self.age,
            'height': self.height,
            'marital_status': self.marital_status,
        }
        logging.info(f"people meta: {meta}")
        return meta

    def tonl(self) -> str:
        """Render the record as a multi-line document string.

        Name and gender are always present; the other fields are included
        only when they are truthy.
        """
        doc = [f"姓名: {self.name}", f"性别: {self.gender}"]
        if self.age:
            doc.append(f"年龄: {self.age}")
        if self.height:
            doc.append(f"身高: {self.height}cm")
        if self.marital_status:
            doc.append(f"婚姻状况: {self.marital_status}")
        if self.match_requirement:
            doc.append(f"择偶要求: {self.match_requirement}")
        if self.introduction:
            doc.append("个人介绍:")
            doc.extend(f"  - {key}: {value}" for key, value in self.introduction.items())
        if self.comments:
            doc.append("总结评价:")
            doc.extend(f"  - {key}: {value}" for key, value in self.comments.items())
        return '\n'.join(doc)

    def comment(self, comment: Dict[str, str]):
        """Merge *comment* into the summary comments."""
        self.comments.update(comment)
|
||||||
0
src/storage/__init__.py
Normal file
0
src/storage/__init__.py
Normal file
216
src/storage/people_store.py
Normal file
216
src/storage/people_store.py
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, func
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from utils.config import get_instance as get_config
|
||||||
|
from utils.vsdb import VectorDB, get_instance as get_vsdb
|
||||||
|
from utils.obs import OBS, get_instance as get_obs
|
||||||
|
|
||||||
|
from models.people import People
|
||||||
|
|
||||||
|
# Module-level PeopleStore singleton; stays None until init() is called.
people_store = None
# Declarative base class that PeopleORM inherits from.
Base = declarative_base()
|
||||||
|
class PeopleORM(Base):
    """SQLAlchemy mapping of the ``peoples`` table.

    The dict-valued People fields (introduction, comments) are stored as JSON
    text; ``deleted_at`` is the soft-delete marker.
    """

    __tablename__ = 'peoples'
    id = Column(String(36), primary_key=True)
    name = Column(String(255), index=True)
    contact = Column(String(255), index=True)
    gender = Column(String(10))
    age = Column(Integer)
    height = Column(Integer)
    marital_status = Column(String(20))
    match_requirement = Column(Text)
    introduction = Column(Text)
    comments = Column(Text)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)

    # Fields copied one-to-one between the ORM row and the People model.
    _SCALAR_FIELDS = (
        'id', 'name', 'contact', 'gender', 'age', 'height',
        'marital_status', 'match_requirement',
    )

    @staticmethod
    def _decode_json(raw):
        """Decode a JSON text column; empty or undecodable values give {}."""
        try:
            return json.loads(raw) if raw else {}
        except (json.JSONDecodeError, TypeError):
            return {}

    def parse_from_people(self, people: People):
        """Copy every field of *people* onto this row."""
        for field in self._SCALAR_FIELDS:
            setattr(self, field, getattr(people, field))
        # Dict fields are serialised to JSON strings for storage.
        self.introduction = json.dumps(people.introduction, ensure_ascii=False)
        self.comments = json.dumps(people.comments, ensure_ascii=False)

    def to_people(self) -> People:
        """Build a People object from this row."""
        people = People()
        for field in self._SCALAR_FIELDS:
            setattr(people, field, getattr(self, field))
        # JSON text columns are decoded back into dicts.
        people.introduction = self._decode_json(self.introduction)
        people.comments = self._decode_json(self.comments)
        return people
|
||||||
|
|
||||||
|
class PeopleStore:
    """Persistence facade for People across SQL, the vector DB and OBS."""

    def __init__(self):
        """Create the SQL engine from config, ensure tables exist, fetch helpers."""
        config = get_config()
        self.sqldb_engine = create_engine(config.get("sqlalchemy", "database_dsn"))
        Base.metadata.create_all(self.sqldb_engine)
        self.session_maker = sessionmaker(bind=self.sqldb_engine)
        self.vsdb: VectorDB = get_vsdb()
        self.obs: OBS = get_obs()

    def save(self, people: People) -> str:
        """Save a person to SQL, the vector database and OBS.

        NOTE(review): the three writes are not atomic — a failure in a later
        step leaves the earlier ones in place.

        :param people: person object; an id is generated when it has none
        :return: person id
        :raises Exception: when the vector database reports no inserted rows
        """
        # 0. Generate the id when missing.
        people.id = people.id if people.id else uuid.uuid4().hex

        # 1. Convert to the ORM model and store in SQL.
        people_orm = PeopleORM()
        people_orm.parse_from_people(people)
        with self.session_maker() as session:
            session.add(people_orm)
            session.commit()

        # 2. Store metadata and document in the vector database.
        people_metadata = people.meta()
        people_document = people.tonl()
        logging.info(f"people: {people}")
        logging.info(f"people_metadata: {people_metadata}")
        logging.info(f"people_document: {people_document}")
        results = self.vsdb.insert(metadatas=[people_metadata], documents=[people_document], ids=[people.id])
        logging.info(f"results: {results}")
        if not results:
            raise Exception("insert failed")

        # 3. Store the full record as JSON in OBS.
        people_json = json.dumps(people.to_dict(), ensure_ascii=False)
        obs_url = self.obs.Put(f"peoples/{people.id}/detail.json", people_json.encode('utf-8'))
        logging.info(f"obs_url: {obs_url}")

        return people.id

    def update(self, people: People) -> None:
        """Not implemented.

        :raises NotImplementedError: always (a subclass of Exception, so
            existing ``except Exception`` handlers keep working)
        """
        raise NotImplementedError("update not implemented")

    def find(self, people_id: str) -> People:
        """Look up a non-deleted person by id.

        :param people_id: person id
        :return: person object
        :raises Exception: when no matching row exists
        """
        with self.session_maker() as session:
            people_orm = session.query(PeopleORM).filter(
                PeopleORM.id == people_id,
                PeopleORM.deleted_at.is_(None)
            ).first()
            if not people_orm:
                raise Exception(f"people not found, people_id: {people_id}")
            return people_orm.to_people()

    def query(self, conds: dict = None, limit: int = 10, offset: int = 0) -> list[People]:
        """Query non-deleted people by exact field matches.

        :param conds: column -> value filters; None or empty means no filter
        :param limit: page size
        :param offset: page offset
        :return: list of person objects, newest first
        """
        if conds is None:
            conds = {}
        with self.session_maker() as session:
            # LIMIT/OFFSET without ORDER BY has no defined row order, so pages
            # could overlap or skip rows; order explicitly (id breaks ties).
            people_orms = session.query(PeopleORM).filter_by(**conds).filter(
                PeopleORM.deleted_at.is_(None)
            ).order_by(
                PeopleORM.created_at.desc(), PeopleORM.id
            ).limit(limit).offset(offset).all()
            return [people_orm.to_people() for people_orm in people_orms]

    def search(self, search: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[People]:
        """Vector-search people and load the matching rows from SQL.

        :param search: search text
        :param metadatas: metadata filters passed to the vector database
        :param ids: optional list of person ids to restrict the search to
        :param top_k: number of results requested from the vector database
        :return: person objects in the order returned by the vector database
        """
        results = self.vsdb.search(document=search, metadatas=metadatas, ids=ids, top_k=top_k)
        logging.info(f"results: {results}")
        people_id_list = [result.get('id', '') for result in results if result.get('id', '')]
        logging.info(f"people_id_list: {people_id_list}")
        if not people_id_list:
            # Nothing matched: skip the SQL round trip with an empty IN ().
            return []
        with self.session_maker() as session:
            people_orms = session.query(PeopleORM).filter(PeopleORM.id.in_(people_id_list)).filter(
                PeopleORM.deleted_at.is_(None)
            ).all()
            # Restore the ranking order of the vector search.
            order_map = {pid: idx for idx, pid in enumerate(people_id_list)}
            people_orms.sort(key=lambda orm: order_map.get(orm.id, len(order_map)))
            return [people_orm.to_people() for people_orm in people_orms]

    def delete(self, people_id: str) -> None:
        """Soft-delete a person in SQL and remove it from the vector DB and OBS.

        :param people_id: person id
        """
        # 1. Soft-delete in SQL.
        with self.session_maker() as session:
            session.query(PeopleORM).filter(
                PeopleORM.id == people_id,
                PeopleORM.deleted_at.is_(None)
            ).update({PeopleORM.deleted_at: func.now()}, synchronize_session=False)
            session.commit()
            logging.info(f"人物 {people_id} 标记删除 SQL 成功")

        # 2. Remove the vector database entry.
        self.vsdb.delete(ids=[people_id])
        logging.info(f"人物 {people_id} 删除向量数据库成功")

        # 3. Remove the files stored in OBS.
        obs_prefix = f"peoples/{people_id}/"
        for key in self.obs.List(obs_prefix):
            # List() may return keys relative to the listed prefix; Del() needs
            # the path including it, otherwise a different object is targeted.
            obs_path = key if key.startswith(obs_prefix) else f"{obs_prefix}{key}"
            self.obs.Del(obs_path)
            logging.info(f"文件 {key} 删除 OBS 成功")
|
||||||
|
|
||||||
|
def init():
    """Create the PeopleStore and publish it as the module-level singleton."""
    global people_store
    people_store = PeopleStore()
|
||||||
|
|
||||||
|
|
||||||
|
def get_instance() -> PeopleStore:
    """Return the module-level PeopleStore (None until init() has run)."""
    return people_store
|
||||||
3
src/utils/__init__.py
Normal file
3
src/utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# 导出utils模块中的子模块
|
||||||
|
from . import config, obs, ocr, vsdb, logger
|
||||||
|
__all__ = ['config', 'obs', 'ocr', 'vsdb', 'logger']
|
||||||
25
src/utils/config.py
Normal file
25
src/utils/config.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import configparser
|
||||||
|
|
||||||
|
# Process-wide parser instance; None until init() has been called.
config = None


def init(config_file: str):
    """Load *config_file* into the module-level ConfigParser singleton.

    The file is decoded as UTF-8 regardless of the platform locale.  A path
    that cannot be read does not raise (as before, the parser is simply left
    empty) but is now reported with a warning instead of being ignored.

    Args:
        config_file: Path of the INI file to load.
    """
    global config
    parser = configparser.ConfigParser()
    # read() skips unreadable paths and returns the list it actually parsed.
    loaded = parser.read(config_file, encoding="utf-8")
    if not loaded:
        import logging  # local import: this module does not import logging
        logging.getLogger(__name__).warning(
            "config file %s could not be read; using an empty configuration", config_file
        )
    # Publish only after loading so readers never see a half-filled parser.
    config = parser


def get_instance() -> configparser.ConfigParser:
    """Return the parser created by init() (None if init() was never called)."""
    return config
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual check: load the test configuration that lives two directories
    # above this file and print every section, option and value.
    import os
    config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
    init(config_file)
    conf = get_instance()
    print(conf.sections())
    for section in conf.sections():
        print(conf.options(section))
        for option in conf.options(section):
            print(f"{section}.{option}={conf.get(section, option)}")
|
||||||
91
src/utils/logger.py
Normal file
91
src/utils/logger.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from .config import get_instance as get_config
|
||||||
|
|
||||||
|
# 定义颜色代码
|
||||||
|
class Colors:
    """ANSI escape sequences used to colour console output."""

    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'
    RESET = '\033[0m'  # reset colour


class ColoredConsoleHandler(logging.StreamHandler):
    """Stream handler that wraps each formatted record in a per-level colour."""

    # Level -> colour; built once instead of on every emit() call.
    _LEVEL_COLORS = {
        logging.DEBUG: Colors.CYAN,
        logging.INFO: Colors.GREEN,
        logging.WARNING: Colors.YELLOW,
        logging.ERROR: Colors.RED,
        logging.CRITICAL: Colors.MAGENTA,
    }

    def emit(self, record):
        """Write *record* to the stream in its level colour (white if unknown).

        Formatting or stream errors are routed to ``handleError`` as the
        standard handlers do, so a bad log call cannot crash the caller.
        """
        try:
            color = self._LEVEL_COLORS.get(record.levelno, Colors.WHITE)
            message = self.format(record)
            self.stream.write(f"{color}{message}{Colors.RESET}\n")
            self.flush()
        except RecursionError:
            # Same policy as logging.StreamHandler: never swallow this one.
            raise
        except Exception:
            self.handleError(record)
|
||||||
|
|
||||||
|
def init(log_dir=None, log_file=None, log_level=None, console_log_level=None):
    """Configure the root logger with a coloured console and a daily log file.

    Each setting is resolved as: explicit argument, then the ``[log]`` section
    of the loaded configuration, then the built-in default.  Calling
    ``init()`` without arguments behaves as before.

    Args:
        log_dir: Directory for log files (default "logs").
        log_file: Log file name prefix (default "if.u.service").
        log_level: Level of the file handler (default INFO).
        console_log_level: Level of the console handler (default DEBUG).
    """
    config = get_config()

    def setting(option, override, default):
        # Explicit argument wins; the config may be absent (e.g. manual runs).
        if override is not None:
            return override
        if config is None:
            return default
        return config.get("log", option, fallback=default)

    def as_level(value):
        # Config values are strings: accept "info", "INFO" or "20".
        if isinstance(value, str):
            text = value.strip()
            return int(text) if text.isdigit() else text.upper()
        return value

    log_dir = setting("log_dir", log_dir, "logs")
    log_file = setting("log_file", log_file, "if.u.service")
    log_level = as_level(setting("log_level", log_level, logging.INFO))
    console_log_level = as_level(setting("console_log_level", console_log_level, logging.DEBUG))

    # Create the log directory; exist_ok avoids the check-then-create race.
    os.makedirs(log_dir, exist_ok=True)

    log_format = "[%(asctime)s.%(msecs)03d][%(filename)s:%(lineno)d][%(levelname)s] %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    formatter = logging.Formatter(log_format, datefmt=date_format)

    # The root logger passes everything; the handlers do the filtering.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.NOTSET)

    # Remove existing handlers and close them so file handles are released
    # when init() is called more than once.
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
        handler.close()

    console_handler = ColoredConsoleHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(console_log_level)
    root_logger.addHandler(console_handler)

    # File name carries the date on which init() ran.
    log_filename = os.path.join(log_dir, f"{log_file}_{datetime.now().strftime('%Y%m%d')}.log")
    file_handler = logging.FileHandler(log_filename, encoding='utf-8')
    file_handler.setFormatter(formatter)
    file_handler.setLevel(log_level)
    root_logger.addHandler(file_handler)

    # Single-letter level names keep the log lines short.
    logging.addLevelName(logging.DEBUG, "D")
    logging.addLevelName(logging.INFO, "I")
    logging.addLevelName(logging.WARNING, "W")
    logging.addLevelName(logging.ERROR, "E")
    logging.addLevelName(logging.CRITICAL, "C")


if __name__ == "__main__":
    # Manual smoke test: one message per level.
    init(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)
    logging.debug("debug log")
    logging.info("info log")
    logging.warning("warning log")
    logging.error("error log")
    logging.critical("critical log")
|
||||||
220
src/utils/obs.py
Normal file
220
src/utils/obs.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Protocol
|
||||||
|
import qiniu
|
||||||
|
import requests
|
||||||
|
from .config import get_instance as get_config
|
||||||
|
|
||||||
|
|
||||||
|
class OBS(Protocol):
    """Structural interface of an object-storage backend."""

    def Put(self, obs_path: str, content: bytes) -> str:
        """
        Upload a file to OBS.

        Args:
            obs_path (str): Target path in OBS.
            content (bytes): File content.

        Returns:
            str: OBS file path.
        """
        ...

    def Get(self, obs_path: str) -> bytes:
        """
        Download a file from OBS.

        Args:
            obs_path (str): OBS file path.

        Returns:
            bytes: File content.
        """
        ...

    def List(self, obs_path: str) -> list:
        """
        List all files under an OBS directory.

        Args:
            obs_path (str): OBS directory path.

        Returns:
            list: List of all file paths.
        """
        ...

    def Del(self, obs_path: str) -> bool:
        """
        Delete an OBS file.

        Args:
            obs_path (str): OBS file path.

        Returns:
            bool: Whether the deletion succeeded.
        """
        ...

    def Link(self, obs_path: str) -> str:
        """
        Get the link of an OBS file.

        Args:
            obs_path (str): OBS file path.

        Returns:
            str: OBS file link.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class Koodo:
|
||||||
|
def __init__(self):
    """Read the ``koodo_obs`` settings and create the Qiniu auth and bucket manager."""
    settings = get_config()
    section = 'koodo_obs'
    self.bucket_name = settings.get(section, 'bucket_name')
    self.prefix_path = settings.get(section, 'prefix_path')
    self.access_key = settings.get(section, 'access_key')
    self.secret_key = settings.get(section, 'secret_key')
    self.outer_domain = settings.get(section, 'outer_domain')
    self.auth = qiniu.Auth(self.access_key, self.secret_key)
    self.bucket = qiniu.BucketManager(self.auth)
|
||||||
|
|
||||||
|
def Put(self, obs_path: str, content: bytes) -> str:
    """
    Upload a file to OBS.

    Args:
        obs_path (str): Target path, relative to the configured prefix.
        content (bytes): File content.

    Returns:
        str: Outer-domain URL of the stored object, or "" when the upload failed.
    """
    full_path = f"{self.prefix_path}{obs_path}"
    upload_token = self.auth.upload_token(self.bucket_name, full_path)
    ret, info = qiniu.put_data(upload_token, full_path, content)
    logging.debug(f"文件 {obs_path} 上传到 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
    succeeded = ret is not None and info.status_code == 200
    if succeeded:
        logging.info(f"文件 {obs_path} 上传成功, OBS路径: {full_path}")
        return f"{self.outer_domain}/{full_path}"
    logging.error(f"文件 {obs_path} 上传失败, 错误信息: {info.text_body}")
    return ""
|
||||||
|
|
||||||
|
def Get(self, obs_path: str) -> bytes:
|
||||||
|
"""
|
||||||
|
从OBS下载文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs_path (str): OBS文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: 文件内容
|
||||||
|
"""
|
||||||
|
link = f"{self.outer_domain}/{self.prefix_path}{obs_path}"
|
||||||
|
resp = requests.get(link)
|
||||||
|
data = json.loads(resp.text)
|
||||||
|
if 'error' in data and data['error']:
|
||||||
|
logging.error(f"从 OBS {obs_path} 下载文件失败, 错误信息: {data['error']}")
|
||||||
|
return None
|
||||||
|
return resp.content
|
||||||
|
|
||||||
|
def List(self, prefix: str = "") -> list[str]:
|
||||||
|
"""
|
||||||
|
列出OBS目录下的所有文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix (str, optional): OBS目录路径前缀. Defaults to "".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: 文件路径列表
|
||||||
|
"""
|
||||||
|
prefix = f"{self.prefix_path}{prefix}"
|
||||||
|
ret, eof, info = self.bucket.list(self.bucket_name, prefix)
|
||||||
|
keys = []
|
||||||
|
for item in ret['items']:
|
||||||
|
item['key'] = item['key'].replace(prefix, "")
|
||||||
|
keys.append(item['key'])
|
||||||
|
# logging.debug(f"文件 {item['key']} 路径: {item['key']}")
|
||||||
|
# logging.debug(f"ret: {ret}")
|
||||||
|
# logging.debug(f"eof: {eof}")
|
||||||
|
# logging.debug(f"info: {info}")
|
||||||
|
return keys
|
||||||
|
|
||||||
|
def Del(self, obs_path: str) -> bool:
|
||||||
|
"""
|
||||||
|
删除OBS文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs_path (str): OBS文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否删除成功
|
||||||
|
"""
|
||||||
|
ret, info = self.bucket.delete(self.bucket_name, f"{self.prefix_path}{obs_path}")
|
||||||
|
logging.debug(f"文件 {obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
|
||||||
|
if ret is None or info.status_code != 200:
|
||||||
|
logging.error(f"文件 {obs_path} 删除 OBS 失败, 错误信息: {info.text_body}")
|
||||||
|
return False
|
||||||
|
logging.info(f"文件 {obs_path} 删除 OBS 成功")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def Link(self, obs_path: str) -> str:
|
||||||
|
"""
|
||||||
|
获取OBS文件链接
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obs_path (str): OBS文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: OBS文件链接
|
||||||
|
"""
|
||||||
|
return f"{self.outer_domain}/{self.prefix_path}{obs_path}"
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level storage backend; stays None until init() has been called.
_obs_instance: "OBS | None" = None


def init():
    """Create the module-level storage backend (a Koodo instance)."""
    global _obs_instance
    _obs_instance = Koodo()


def get_instance() -> "OBS":
    """Return the storage backend created by init().

    Raises:
        RuntimeError: If init() has not been called yet.
    """
    # Fail fast instead of handing out None, which would only surface later
    # as an AttributeError at the first method call.
    if _obs_instance is None:
        raise RuntimeError("OBS模块未初始化,请先调用init()函数")
    return _obs_instance
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test against the storage configured in test_conf.ini.
    # WARNING: it deletes every file returned by List("").
    import os

    from .logger import init as init_logger
    init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)

    from .config import init as init_config, get_instance as get_config
    config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
    init_config(config_file)

    init()
    obs = get_instance()

    # Disabled upload check: read ../../test/9e03ad5eb8b1a51e752fb79cd8f98169.PNG,
    # upload it with obs.Put("test111.PNG", content), then print
    # obs.Link("test111.PNG").

    # List every file below the configured prefix, then delete each one.
    keys = obs.List("")
    print(f"OBS 目录下的所有文件: {keys}")
    for key in keys:
        deleted = obs.Del(key)
        print(f"文件 {key} 删除 OBS 成功: {deleted}")
|
||||||
105
src/utils/ocr.py
Normal file
105
src/utils/ocr.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Protocol
|
||||||
|
from alibabacloud_ocr_api20210707.client import Client as OcrClient
|
||||||
|
from alibabacloud_tea_openapi import models as open_api_models
|
||||||
|
from alibabacloud_ocr_api20210707 import models as ocr_models
|
||||||
|
from alibabacloud_tea_util import models as util_models
|
||||||
|
from alibabacloud_tea_util.client import Client as UtilClient
|
||||||
|
from .config import get_instance as get_config
|
||||||
|
|
||||||
|
|
||||||
|
class OCR(Protocol):
    """Structural interface for services that extract text from an image."""

    def recognize_image_text(self, image_link: str) -> str:
        """Extract the text contained in an image.

        Args:
            image_link (str): Link to the image.

        Returns:
            str: The extracted text.
        """
        ...
|
||||||
|
|
||||||
|
class AliOCR:
    """OCR implementation backed by the Alibaba Cloud OCR client.

    Credentials and endpoint are read from the ``ali_ocr`` section of the
    application config.
    """

    def __init__(self):
        """Read the ``ali_ocr`` settings and create the OCR client."""
        config = get_config()
        self.access_key = config.get("ali_ocr", "access_key")
        self.secret_key = config.get("ali_ocr", "secret_key")
        self.endpoint = config.get("ali_ocr", "endpoint")
        self.client = self._create_client()

    def _create_client(self):
        """Build an OcrClient from the stored credentials and endpoint."""
        config = open_api_models.Config(
            access_key_id=self.access_key,
            access_key_secret=self.secret_key,
        )
        config.endpoint = self.endpoint
        return OcrClient(config)

    @staticmethod
    def _extract_content(result_data) -> str:
        """Return the ``content`` field of the JSON-encoded OCR result.

        Returns ``""`` when there is no data or no ``content`` field.
        """
        if not result_data:
            return ""
        result_body = json.loads(result_data)
        if result_body and 'content' in result_body:
            return result_body['content']
        return ""

    def recognize_image_text(self, image_link: str) -> str:
        """Extract text from an image link with the general OCR request.

        Args:
            image_link (str): Link to the image.

        Returns:
            str: The recognised text, or ``""`` when the response is not a
                200 with a body or carries no ``content`` field.

        Raises:
            Exception: Whatever the OCR client raises is logged and re-raised.
        """
        recognize_general_request = ocr_models.RecognizeGeneralRequest(url=image_link)
        runtime = util_models.RuntimeOptions()
        try:
            # Send the request exactly once; a second call would repeat the
            # same remote request for nothing.
            response = self.client.recognize_general_with_options(recognize_general_request, runtime)
        except Exception as error:
            # Not every exception carries the SDK's .message / .data
            # attributes, so read them defensively before logging.
            logging.error(getattr(error, "message", None) or str(error))
            data = getattr(error, "data", None)
            if isinstance(data, dict) and data.get("Recommend"):
                # Diagnostic link supplied with the error.
                logging.error(data.get("Recommend"))
            raise

        if response.status_code == 200 and response.body:
            logging.debug(response.body.data)
            return self._extract_content(response.body.data)
        return ""
|
||||||
|
|
||||||
|
# Module-level OCR instance; None until init() has been called.
_ocr_instance = None


def init():
    """Create the module-level OCR instance."""
    global _ocr_instance
    _ocr_instance = AliOCR()


def get_instance() -> OCR:
    """Return the module-level OCR instance.

    Raises:
        RuntimeError: If init() has not been called yet.
    """
    instance = _ocr_instance
    if instance is not None:
        return instance
    raise RuntimeError("OCR模块未初始化,请先调用init()函数")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test; sends a real OCR request using test_conf.ini.
    import os

    # Use relative imports, like the module's own `from .config import ...`.
    # A top-level `config` module would be a different module object, so
    # initialising it would not set up the config this module reads.
    from .logger import init as init_logger
    init_logger(console_log_level=logging.DEBUG)

    from .config import init as init_config
    config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
    init_config(config_file)

    init()
    ocr = get_instance()
    text = ocr.recognize_image_text(image_link="https://pic.mamamiyear.site/test.if.u/test111.PNG")
    print(text)
|
||||||
241
src/utils/vsdb.py
Normal file
241
src/utils/vsdb.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import uuid
|
||||||
|
import chromadb
|
||||||
|
import logging
|
||||||
|
from typing import Protocol
|
||||||
|
from chromadb.config import Settings
|
||||||
|
from chromadb.utils import embedding_functions
|
||||||
|
from .config import get_instance as get_config
|
||||||
|
|
||||||
|
class VectorDB(Protocol):
|
||||||
|
|
||||||
|
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
|
||||||
|
"""
|
||||||
|
插入向量到数据库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector (list[float]): 向量
|
||||||
|
metadata (dict): 元数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否插入成功
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete(self, ids: list[str]) -> bool:
|
||||||
|
"""
|
||||||
|
Delete documents from a collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ids: List of IDs to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether deletion was successful
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def query(self, metadatas: dict, ids: list[str], top_k: int = 5) -> list[dict]:
|
||||||
|
"""
|
||||||
|
查询向量数据库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_vector (list[float]): 查询向量
|
||||||
|
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict]: 查询结果列表
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
||||||
|
"""
|
||||||
|
搜索向量数据库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document: Document to search
|
||||||
|
metadatas: Metadata to filter by
|
||||||
|
ids: List of IDs to filter by
|
||||||
|
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict]: 查询结果列表
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
class ChromaDB:
    """Vector store implementation backed by a Chroma collection."""

    def __init__(self, **kwargs):
        """Initialize the ChromaDB instance.

        Embedding settings come from the ``voc-engine_embedding`` config
        section and storage settings from ``chroma_vsdb``. A persistent
        client is used when ``database_dir`` is configured; otherwise
        ``chromadb.Client`` is used.

        Args:
            **kwargs: ``collection_metadata`` (dict) may be given to set the
                collection metadata; ``hnsw:space`` defaults to ``cosine``.
        """
        config = get_config()
        self.embedding_functions = embedding_functions.OpenAIEmbeddingFunction(
            api_base=config.get("voc-engine_embedding", "api_url"),
            api_key=config.get("voc-engine_embedding", "api_key"),
            model_name=config.get("voc-engine_embedding", "endpoint"),
        )
        persist_directory = config.get("chroma_vsdb", "database_dir", fallback=None)
        logging.debug(f"persist_directory: {persist_directory}")
        if persist_directory:
            self.client = chromadb.PersistentClient(
                path=persist_directory,
                settings=Settings(anonymized_telemetry=False)
            )
        else:
            self.client = chromadb.Client(
                settings=Settings(anonymized_telemetry=False),
            )
        self.collection_name = config.get("chroma_vsdb", "collection_name", fallback="peoples")
        # Work on a copy so the caller's dict is never modified.
        metadata = dict(kwargs.get('collection_metadata') or {})
        metadata.setdefault('hnsw:space', 'cosine')
        self.collection = self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.embedding_functions,
            metadata=metadata,
        )

    @staticmethod
    def _format_results(ids, distances, metadatas, documents) -> list[dict]:
        """Zip parallel result lists into one dict per record."""
        formatted_results = []
        for index, item_id in enumerate(ids):
            formatted_results.append({
                'id': item_id,
                'distance': distances[index] if distances else None,
                'metadata': (metadatas[index] if metadatas else None) or {},
                'document': (documents[index] if documents else None) or '',
            })
        return formatted_results

    def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
        """Insert documents into the collection.

        Args:
            metadatas: List of metadata corresponding to each document
            documents: List of documents to insert
            ids: Optional list of unique IDs for each document. If None or
                empty, UUID4 IDs are generated.

        Returns:
            list[str]: List of inserted IDs
        """
        if not ids:
            # Generate unique IDs if not provided
            ids = [str(uuid.uuid4()) for _ in documents]
        self.collection.add(
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
        return ids

    def delete(self, ids: list[str]) -> bool:
        """Delete documents from the collection.

        Args:
            ids: List of IDs to delete

        Returns:
            bool: Whether deletion was successful
        """
        try:
            self.collection.delete(ids=ids)
            return True
        except Exception as e:
            # Best-effort contract: report the failure through the logger
            # and signal it with the return value.
            logging.error(f"Error deleting documents: {e}")
            return False

    def query(self, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
        """Look up stored documents by metadata filter and/or IDs.

        Args:
            metadatas: Metadata to filter by
            ids: List of IDs to query
            top_k (int, optional): Maximum number of results. Defaults to 5.

        Returns:
            list[dict]: Records with ``id``, ``distance`` (always None, no
                similarity is computed), ``metadata`` and ``document``.
        """
        # A lookup without query text or embeddings is a `get`;
        # `collection.query` requires one of those inputs.
        results = self.collection.get(
            ids=ids,
            where=metadatas if metadatas else None,
            limit=top_k,
            include=["documents", "metadatas"],
        )
        return self._format_results(
            results['ids'],
            None,
            results.get('metadatas'),
            results.get('documents'),
        )

    def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
        """Search the collection for documents similar to a text.

        Args:
            document: Document to search
            metadatas: Metadata to filter by
            ids: List of IDs to filter by
            top_k (int, optional): Number of results to return. Defaults to 5.

        Returns:
            list[dict]: Records with ``id``, ``distance``, ``metadata`` and
                ``document``.
        """
        results = self.collection.query(
            query_embeddings=None,
            query_texts=[document],
            n_results=top_k,
            where=metadatas if metadatas else None,
            ids=ids,
            include=["documents", "metadatas", "distances"],
        )
        # Only one query text is sent, so only the first result row is read.
        formatted_results = self._format_results(
            results['ids'][0],
            results['distances'][0],
            results['metadatas'][0],
            results['documents'][0],
        )
        for result in formatted_results:
            logging.info(f"result id: {result['id']}, distance: {result['distance']}")
        return formatted_results
|
||||||
|
|
||||||
|
# Module-level vector store; stays None until init() has been called.
_vsdb_instance: "VectorDB | None" = None


def init():
    """Create the module-level vector store (a ChromaDB instance)."""
    global _vsdb_instance
    _vsdb_instance = ChromaDB()


def get_instance() -> "VectorDB":
    """Return the vector store created by init().

    Raises:
        RuntimeError: If init() has not been called yet.
    """
    # Fail fast instead of handing out None, which would only surface later
    # as an AttributeError at the first method call.
    if _vsdb_instance is None:
        raise RuntimeError("VSDB模块未初始化,请先调用init()函数")
    return _vsdb_instance
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: inserts four sample profiles and runs one search.
    import os

    # Use relative imports, like the module's own `from .config import ...`.
    # A top-level `config` module would be a different module object, so
    # initialising it would not set up the config this module reads.
    from .logger import init as init_logger
    init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)

    from .config import init as init_config
    config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
    init_config(config_file)

    init()
    vsdb = get_instance()
    metadatas = [
        {'name': '丽丽'},
        {'name': '志刚'},
        {'name': '张三'},
        {'name': '李四'},
    ]
    documents = [
        '姓名: 丽丽, 性别: 女, 年龄: 23, 爱好: 爬山、骑行、攀岩、跳伞、蹦极',
        "姓名: 志刚, 性别: 男, 年龄: 25, 爱好: 读书、游戏",
        "姓名: 张三, 性别: 男, 年龄: 30, 爱好: 画画、写作、阅读、逛展、旅行",
        "姓名: 李四, 性别: 男, 年龄: 35, 爱好: 做饭、美食、旅游",
    ]
    search_text = '25岁以下的'
    ids = vsdb.insert(metadatas, documents)
    results = vsdb.search(search_text, None, None, top_k=4)
    for result in results:
        print(result['document'], ' ', result['distance'])
|
||||||
16
test/test_logger.py
Normal file
16
test/test_logger.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import logging
import os
import sys

# Put the project's src directory on the import path so `utils` resolves.
src_dir = os.path.join(os.path.dirname(__file__), '../src')
sys.path.append(src_dir)

from utils.logger import init

# Set up logging.
init()

# Emit one message at each log level.
for emit, text in (
    (logging.debug, "这是一条调试信息"),
    (logging.info, "这是一条普通信息"),
    (logging.warning, "这是一条警告信息"),
    (logging.error, "这是一条错误信息"),
    (logging.critical, "这是一条严重错误信息"),
):
    emit(text)
|
||||||
Reference in New Issue
Block a user