Compare commits
15 Commits
fca2b1449f
...
release_v0
| Author | SHA1 | Date | |
|---|---|---|---|
| 8174c4cfe5 | |||
| 4c48d11bfa | |||
| e74279ca5e | |||
| 13b70ba424 | |||
| fae93b5ab8 | |||
| 1a092248eb | |||
| 0a749d56e8 | |||
| 3d13aa18ae | |||
| d179418e7d | |||
| 496f35a386 | |||
| c99b324b81 | |||
| 7a189eb631 | |||
| 98cbc754f6 | |||
| 736d8ed193 | |||
| 268eb8be2b |
@@ -1,18 +1,18 @@
|
||||
[project]
|
||||
name = "service"
|
||||
version = "0.1.0"
|
||||
version = "0.1"
|
||||
description = "This project is the web servcie sub-system for if.u projuect"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"alibabacloud-ocr-api20210707>=3.1.3",
|
||||
"chromadb>=1.1.1",
|
||||
"fastapi>=0.118.2",
|
||||
"langchain>=0.3.27",
|
||||
"langchain-openai>=0.3.35",
|
||||
"numpy>=2.3.3",
|
||||
"alibabacloud-tea-openapi>=0.4.1",
|
||||
"fastapi>=0.118.3",
|
||||
"langchain==0.3.27",
|
||||
"langchain-openai==0.3.35",
|
||||
"pymysql>=1.1.2",
|
||||
"python-multipart>=0.0.20",
|
||||
"qiniu>=7.17.0",
|
||||
"requests>=2.32.5",
|
||||
"sqlalchemy>=2.0.44",
|
||||
"uvicorn>=0.38.0",
|
||||
]
|
||||
|
||||
22
src/agents/base_agent.py
Normal file
22
src/agents/base_agent.py
Normal file
@@ -0,0 +1,22 @@
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from utils.config import get_instance as get_config
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
|
||||
config = get_config()
|
||||
llm_api_url = api_url or config.get("ai", "llm_api_url")
|
||||
llm_api_key = api_key or config.get("ai", "llm_api_key")
|
||||
llm_model_name = model_name or config.get("ai", "llm_model_name")
|
||||
self.llm = ChatOpenAI(
|
||||
openai_api_key=llm_api_key,
|
||||
openai_api_base=llm_api_url,
|
||||
model_name=llm_model_name,
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class SummaryPeopleAgent(BaseAgent):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
@@ -1,25 +1,19 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
from models.people import People
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(self):
|
||||
self.llm = ChatOpenAI(
|
||||
openai_api_key="56d82040-85c7-4701-8f87-734985e27909",
|
||||
openai_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||
model_name="ep-20250722161445-n9lfq"
|
||||
)
|
||||
pass
|
||||
|
||||
class ExtractPeopleAgent(BaseAgent):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
|
||||
super().__init__(api_url, api_key, model_name)
|
||||
self.prompt = ChatPromptTemplate.from_messages([
|
||||
(
|
||||
"system",
|
||||
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')},"
|
||||
"你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n"
|
||||
"姓名 name\n"
|
||||
"性别 gender\n"
|
||||
@@ -28,6 +22,8 @@ class ExtractPeopleAgent(BaseAgent):
|
||||
"婚姻状况 marital_status\n"
|
||||
"择偶要求 match_requirement\n"
|
||||
"以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n"
|
||||
"其中,'年龄 age' 和 '身高(cm) height' 必须是一个整数,不能是一个字符串;\n"
|
||||
"并且,'性别 gender' 根据识别结果,必须从 男,女,未知 三选一填写。\n"
|
||||
"除了上述基本信息,还有一个字段\n"
|
||||
"个人介绍 introduction\n"
|
||||
"其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n"
|
||||
@@ -42,13 +38,15 @@ class ExtractPeopleAgent(BaseAgent):
|
||||
response = self.llm.invoke(prompt)
|
||||
logging.info(f"llm response: {response.content}")
|
||||
try:
|
||||
return People.from_dict(json.loads(response.content))
|
||||
people = People.from_dict(json.loads(response.content))
|
||||
err = people.validate()
|
||||
if not err.success:
|
||||
raise ValueError(f"Failed to validate people info: {err.info}")
|
||||
return people
|
||||
except json.JSONDecodeError:
|
||||
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
|
||||
return None
|
||||
except ValueError as e:
|
||||
logging.error(f"Failed to validate people info: {e}")
|
||||
return None
|
||||
pass
|
||||
|
||||
class SummaryPeopleAgent(BaseAgent):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
@@ -1,52 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# created by mmmy on 2025-09-27
|
||||
|
||||
import logging
|
||||
from ai.agent import ExtractPeopleAgent
|
||||
from utils import ocr, vsdb, obs
|
||||
from models.people import People
|
||||
from fastapi import FastAPI
|
||||
|
||||
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
|
||||
|
||||
class App:
|
||||
def __init__(self):
|
||||
self.extract_people_agent = ExtractPeopleAgent()
|
||||
self.ocr = ocr.get_instance()
|
||||
self.vedb = vsdb.get_instance(db_type='chromadb')
|
||||
self.obs = obs.get_instance()
|
||||
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
class InputForPeople:
|
||||
image: bytes
|
||||
text: str
|
||||
|
||||
def input_to_people(self, input: InputForPeople) -> People:
|
||||
if not input.image and not input.text:
|
||||
return None
|
||||
if input.image:
|
||||
content = self.ocr.recognize_image_text(input.image)
|
||||
else:
|
||||
content = input.text
|
||||
print(content)
|
||||
people = self._trans_text_to_people(content)
|
||||
return people
|
||||
|
||||
def _trans_text_to_people(self, text: str) -> People:
|
||||
if not text:
|
||||
return None
|
||||
person = self.extract_people_agent.extract_people_info(text)
|
||||
print(person)
|
||||
return person
|
||||
|
||||
def create_people(self, people: People) -> bool:
|
||||
if not people:
|
||||
return False
|
||||
try:
|
||||
people.save()
|
||||
except Exception as e:
|
||||
logging.error(f"保存人物到数据库失败: {e}")
|
||||
return False
|
||||
return True
|
||||
23
src/main.py
23
src/main.py
@@ -1,13 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# created by mmmy on 2025-09-27
|
||||
import logging
|
||||
import os
|
||||
import argparse
|
||||
from venv import logger
|
||||
import uvicorn
|
||||
from app.api import api
|
||||
from utils import obs, ocr, vsdb, logger, config
|
||||
from storage import people_store
|
||||
from services import people as people_service
|
||||
from utils import config, logger, obs, ocr, rldb
|
||||
|
||||
from web.api import api
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
@@ -15,14 +14,20 @@ def main():
|
||||
parser = argparse.ArgumentParser(description='IF.u 服务')
|
||||
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
|
||||
args = parser.parse_args()
|
||||
|
||||
config.init(args.config)
|
||||
logger.init()
|
||||
obs.init()
|
||||
|
||||
rldb.init()
|
||||
|
||||
ocr.init()
|
||||
vsdb.init()
|
||||
people_store.init()
|
||||
obs.init()
|
||||
|
||||
people_service.init()
|
||||
|
||||
conf = config.get_instance()
|
||||
host = conf.get('web_service', 'server_host', fallback='127.0.0.1')
|
||||
|
||||
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
|
||||
port = conf.getint('web_service', 'server_port', fallback=8099)
|
||||
uvicorn.run(api, host=host, port=port)
|
||||
|
||||
|
||||
@@ -1,8 +1,29 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# created by mmmy on 2025-09-30
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, func
|
||||
from utils.rldb import RLDBBaseModel
|
||||
from utils.error import ErrorCode, error
|
||||
|
||||
class PeopleRLDBModel(RLDBBaseModel):
|
||||
__tablename__ = 'peoples'
|
||||
id = Column(String(36), primary_key=True)
|
||||
name = Column(String(255), index=True)
|
||||
contact = Column(String(255), index=True)
|
||||
gender = Column(String(10))
|
||||
age = Column(Integer)
|
||||
height = Column(Integer)
|
||||
marital_status = Column(String(20))
|
||||
match_requirement = Column(Text)
|
||||
introduction = Column(Text)
|
||||
comments = Column(Text)
|
||||
cover = Column(String(255), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
|
||||
class People:
|
||||
@@ -18,15 +39,7 @@ class People:
|
||||
age: int
|
||||
# 身高(cm)
|
||||
height: int
|
||||
# 体重(kg)
|
||||
# weight: int
|
||||
# 婚姻状况
|
||||
# [
|
||||
# "未婚(single)",
|
||||
# "已婚(married)",
|
||||
# "离异(divorced)",
|
||||
# "丧偶(widowed)"
|
||||
# ]
|
||||
marital_status: str
|
||||
# 择偶要求
|
||||
match_requirement: str
|
||||
@@ -34,8 +47,9 @@ class People:
|
||||
introduction: Dict[str, str]
|
||||
# 总结评价
|
||||
comments: Dict[str, str]
|
||||
# 封面
|
||||
cover: str = None
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# 初始化所有属性,从kwargs中获取值,如果不存在则设置默认值
|
||||
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
|
||||
@@ -48,9 +62,14 @@ class People:
|
||||
self.match_requirement = kwargs.get('match_requirement', '') if kwargs.get('match_requirement', '') is not None else ''
|
||||
self.introduction = kwargs.get('introduction', {}) if kwargs.get('introduction', {}) is not None else {}
|
||||
self.comments = kwargs.get('comments', {}) if kwargs.get('comments', {}) is not None else {}
|
||||
self.cover = kwargs.get('cover', None) if kwargs.get('cover', None) is not None else None
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.tonl()
|
||||
# 返回对象的字符串表示,包含所有属性
|
||||
return (f"People(id={self.id}, name={self.name}, contact={self.contact}, gender={self.gender}, "
|
||||
f"age={self.age}, height={self.height}, marital_status={self.marital_status}, "
|
||||
f"match_requirement={self.match_requirement}, introduction={self.introduction}, "
|
||||
f"comments={self.comments}, cover={self.cover})")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
@@ -65,6 +84,23 @@ class People:
|
||||
del data['deleted_at']
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_rldb_model(cls, data: PeopleRLDBModel):
|
||||
# 将关系数据库模型转换为对象
|
||||
return cls(
|
||||
id=data.id,
|
||||
name=data.name,
|
||||
contact=data.contact,
|
||||
gender=data.gender,
|
||||
age=data.age,
|
||||
height=data.height,
|
||||
marital_status=data.marital_status,
|
||||
match_requirement=data.match_requirement,
|
||||
introduction=json.loads(data.introduction) if data.introduction else {},
|
||||
comments=json.loads(data.comments) if data.comments else {},
|
||||
cover=data.cover,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# 将对象转换为字典格式
|
||||
return {
|
||||
@@ -78,44 +114,37 @@ class People:
|
||||
'match_requirement': self.match_requirement,
|
||||
'introduction': self.introduction,
|
||||
'comments': self.comments,
|
||||
}
|
||||
'cover': self.cover,
|
||||
}
|
||||
|
||||
def to_rldb_model(self) -> PeopleRLDBModel:
|
||||
# 将对象转换为关系数据库模型
|
||||
return PeopleRLDBModel(
|
||||
id=self.id,
|
||||
name=self.name,
|
||||
contact=self.contact,
|
||||
gender=self.gender,
|
||||
age=self.age,
|
||||
height=self.height,
|
||||
marital_status=self.marital_status,
|
||||
match_requirement=self.match_requirement,
|
||||
introduction=json.dumps(self.introduction, ensure_ascii=False),
|
||||
comments=json.dumps(self.comments, ensure_ascii=False),
|
||||
cover=self.cover,
|
||||
)
|
||||
|
||||
def meta(self) -> Dict[str, str]:
|
||||
# 返回对象的元数据信息
|
||||
meta = {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'gender': self.gender,
|
||||
'age': self.age,
|
||||
'height': self.height,
|
||||
'marital_status': self.marital_status,
|
||||
}
|
||||
logging.info(f"people meta: {meta}")
|
||||
return meta
|
||||
|
||||
def tonl(self) -> str:
|
||||
# 将对象转换为文档格式的字符串
|
||||
doc = []
|
||||
doc.append(f"姓名: {self.name}")
|
||||
doc.append(f"性别: {self.gender}")
|
||||
if self.age:
|
||||
doc.append(f"年龄: {self.age}")
|
||||
if self.height:
|
||||
doc.append(f"身高: {self.height}cm")
|
||||
if self.marital_status:
|
||||
doc.append(f"婚姻状况: {self.marital_status}")
|
||||
if self.match_requirement:
|
||||
doc.append(f"择偶要求: {self.match_requirement}")
|
||||
if self.introduction:
|
||||
doc.append("个人介绍:")
|
||||
for key, value in self.introduction.items():
|
||||
doc.append(f" - {key}: {value}")
|
||||
if self.comments:
|
||||
doc.append("总结评价:")
|
||||
for key, value in self.comments.items():
|
||||
doc.append(f" - {key}: {value}")
|
||||
return '\n'.join(doc)
|
||||
|
||||
def comment(self, comment: Dict[str, str]):
|
||||
# 添加总结评价
|
||||
self.comments.update(comment)
|
||||
def validate(self) -> error:
|
||||
err = error(ErrorCode.SUCCESS, "")
|
||||
if not self.name:
|
||||
logging.error("Name is required")
|
||||
err = error(ErrorCode.MODEL_ERROR, "Name is required")
|
||||
if not self.gender in ['男', '女', '未知']:
|
||||
logging.error("Gender must be '男', '女', or '未知'")
|
||||
err = error(ErrorCode.MODEL_ERROR, "Gender must be '男', '女', or '未知'")
|
||||
if not isinstance(self.age, int) or self.age <= 0:
|
||||
logging.error("Age must be an integer and greater than 0")
|
||||
err = error(ErrorCode.MODEL_ERROR, "Age must be an integer and greater than 0")
|
||||
if not isinstance(self.height, int) or self.height <= 0:
|
||||
logging.error("Height must be an integer and greater than 0")
|
||||
err = error(ErrorCode.MODEL_ERROR, "Height must be an integer and greater than 0")
|
||||
return err
|
||||
|
||||
78
src/services/people.py
Normal file
78
src/services/people.py
Normal file
@@ -0,0 +1,78 @@
|
||||
|
||||
|
||||
|
||||
import uuid
|
||||
from models.people import People, PeopleRLDBModel
|
||||
from utils.error import ErrorCode, error
|
||||
from utils import rldb
|
||||
|
||||
|
||||
class PeopleService:
|
||||
def __init__(self):
|
||||
self.rldb = rldb.get_instance()
|
||||
|
||||
def save(self, people: People) -> (str, error):
|
||||
"""
|
||||
保存人物到数据库和向量数据库
|
||||
|
||||
:param people: 人物对象
|
||||
:return: 人物ID
|
||||
"""
|
||||
# 0. 生成 people id
|
||||
people.id = people.id if people.id else uuid.uuid4().hex
|
||||
|
||||
# 1. 转换模型,并保存到 SQL 数据库
|
||||
people_orm = people.to_rldb_model()
|
||||
self.rldb.upsert(people_orm)
|
||||
|
||||
return people.id, error(ErrorCode.SUCCESS, "")
|
||||
|
||||
def delete(self, people_id: str) -> error:
|
||||
"""
|
||||
删除人物从数据库和向量数据库
|
||||
|
||||
:param people_id: 人物ID
|
||||
:return: 错误对象
|
||||
"""
|
||||
people_orm = self.rldb.get(PeopleRLDBModel, people_id)
|
||||
if not people_orm:
|
||||
return error(ErrorCode.RLDB_ERROR, f"people {people_id} not found")
|
||||
self.rldb.delete(people_orm)
|
||||
return error(ErrorCode.SUCCESS, "")
|
||||
|
||||
def get(self, people_id: str) -> (People, error):
|
||||
"""
|
||||
从数据库获取人物
|
||||
|
||||
:param people_id: 人物ID
|
||||
:return: 人物对象
|
||||
"""
|
||||
people_orm = self.rldb.get(PeopleRLDBModel, people_id)
|
||||
if not people_orm:
|
||||
return None, error(ErrorCode.MODEL_ERROR, f"people {people_id} not found")
|
||||
return People.from_rldb_model(people_orm), error(ErrorCode.SUCCESS, "")
|
||||
|
||||
def list(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> (list[People], error):
|
||||
"""
|
||||
从数据库列出人物
|
||||
|
||||
:param conds: 查询条件字典
|
||||
:param limit: 分页大小
|
||||
:param offset: 分页偏移量
|
||||
:return: 人物对象列表
|
||||
"""
|
||||
people_orms = self.rldb.query(PeopleRLDBModel, **conds)
|
||||
peoples = [People.from_rldb_model(people_orm) for people_orm in people_orms]
|
||||
|
||||
return peoples, error(ErrorCode.SUCCESS, "")
|
||||
|
||||
|
||||
|
||||
people_service = None
|
||||
|
||||
def init():
|
||||
global people_service
|
||||
people_service = PeopleService()
|
||||
|
||||
def get_instance() -> PeopleService:
|
||||
return people_service
|
||||
@@ -1,216 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, func
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from utils.config import get_instance as get_config
|
||||
from utils.vsdb import VectorDB, get_instance as get_vsdb
|
||||
from utils.obs import OBS, get_instance as get_obs
|
||||
|
||||
from models.people import People
|
||||
|
||||
people_store = None
|
||||
Base = declarative_base()
|
||||
class PeopleORM(Base):
|
||||
__tablename__ = 'peoples'
|
||||
id = Column(String(36), primary_key=True)
|
||||
name = Column(String(255), index=True)
|
||||
contact = Column(String(255), index=True)
|
||||
gender = Column(String(10))
|
||||
age = Column(Integer)
|
||||
height = Column(Integer)
|
||||
marital_status = Column(String(20))
|
||||
match_requirement = Column(Text)
|
||||
introduction = Column(Text)
|
||||
comments = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
def parse_from_people(self, people: People):
|
||||
import json
|
||||
self.id = people.id
|
||||
self.name = people.name
|
||||
self.contact = people.contact
|
||||
self.gender = people.gender
|
||||
self.age = people.age
|
||||
self.height = people.height
|
||||
self.marital_status = people.marital_status
|
||||
self.match_requirement = people.match_requirement
|
||||
# 将字典类型字段序列化为JSON字符串存储
|
||||
self.introduction = json.dumps(people.introduction, ensure_ascii=False)
|
||||
self.comments = json.dumps(people.comments, ensure_ascii=False)
|
||||
|
||||
def to_people(self) -> People:
|
||||
import json
|
||||
people = People()
|
||||
people.id = self.id
|
||||
people.name = self.name
|
||||
people.contact = self.contact
|
||||
people.gender = self.gender
|
||||
people.age = self.age
|
||||
people.height = self.height
|
||||
people.marital_status = self.marital_status
|
||||
people.match_requirement = self.match_requirement
|
||||
# 将JSON字符串反序列化为字典类型字段
|
||||
try:
|
||||
people.introduction = json.loads(self.introduction) if self.introduction else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
people.introduction = {}
|
||||
|
||||
try:
|
||||
people.comments = json.loads(self.comments) if self.comments else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
people.comments = {}
|
||||
|
||||
return people
|
||||
|
||||
class PeopleStore:
|
||||
def __init__(self):
|
||||
config = get_config()
|
||||
self.sqldb_engine = create_engine(config.get("sqlalchemy", "database_dsn"))
|
||||
Base.metadata.create_all(self.sqldb_engine)
|
||||
self.session_maker = sessionmaker(bind=self.sqldb_engine)
|
||||
self.vsdb: VectorDB = get_vsdb()
|
||||
self.obs: OBS = get_obs()
|
||||
|
||||
def save(self, people: People) -> str:
|
||||
"""
|
||||
保存人物到数据库和向量数据库
|
||||
|
||||
:param people: 人物对象
|
||||
:return: 人物ID
|
||||
"""
|
||||
# 0. 生成 people id
|
||||
people.id = people.id if people.id else uuid.uuid4().hex
|
||||
|
||||
# 1. 转换模型,并保存到 SQL 数据库
|
||||
people_orm = PeopleORM()
|
||||
people_orm.parse_from_people(people)
|
||||
with self.session_maker() as session:
|
||||
session.add(people_orm)
|
||||
session.commit()
|
||||
|
||||
# 2. 保存到向量数据库
|
||||
people_metadata = people.meta()
|
||||
people_document = people.tonl()
|
||||
logging.info(f"people: {people}")
|
||||
logging.info(f"people_metadata: {people_metadata}")
|
||||
logging.info(f"people_document: {people_document}")
|
||||
results = self.vsdb.insert(metadatas=[people_metadata], documents=[people_document], ids=[people.id])
|
||||
logging.info(f"results: {results}")
|
||||
if len(results) == 0:
|
||||
raise Exception("insert failed")
|
||||
|
||||
# 3. 保存到 OBS 存储
|
||||
people_dict = people.to_dict()
|
||||
people_json = json.dumps(people_dict, ensure_ascii=False)
|
||||
obs_url = self.obs.Put(f"peoples/{people.id}/detail.json", people_json.encode('utf-8'))
|
||||
logging.info(f"obs_url: {obs_url}")
|
||||
|
||||
return people.id
|
||||
|
||||
def update(self, people: People) -> None:
|
||||
raise Exception("update not implemented")
|
||||
return None
|
||||
|
||||
def find(self, people_id: str) -> People:
|
||||
"""
|
||||
根据人物ID查询人物
|
||||
|
||||
:param people_id: 人物ID
|
||||
:return: 人物对象
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
people_orm = session.query(PeopleORM).filter(
|
||||
PeopleORM.id == people_id,
|
||||
PeopleORM.deleted_at.is_(None)
|
||||
).first()
|
||||
if not people_orm:
|
||||
raise Exception(f"people not found, people_id: {people_id}")
|
||||
return people_orm.to_people()
|
||||
|
||||
def query(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> list[People]:
|
||||
"""
|
||||
根据查询条件查询人物
|
||||
|
||||
:param conds: 查询条件字典
|
||||
:param limit: 分页大小
|
||||
:param offset: 分页偏移量
|
||||
:return: 人物对象列表
|
||||
"""
|
||||
if conds is None:
|
||||
conds = {}
|
||||
with self.session_maker() as session:
|
||||
people_orms = session.query(PeopleORM).filter_by(**conds).filter(
|
||||
PeopleORM.deleted_at.is_(None)
|
||||
).limit(limit).offset(offset).all()
|
||||
return [people_orm.to_people() for people_orm in people_orms]
|
||||
|
||||
def search(self, search: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[People]:
|
||||
"""
|
||||
根据搜索内容和查询条件查询人物
|
||||
|
||||
:param search: 搜索内容
|
||||
:param metadatas: 查询条件字典
|
||||
:param ids: 可选的人物ID列表,用于过滤结果
|
||||
:param top_k: 返回结果数量
|
||||
:return: 人物对象列表
|
||||
"""
|
||||
peoples = []
|
||||
results = self.vsdb.search(document=search, metadatas=metadatas, ids=ids, top_k=top_k)
|
||||
logging.info(f"results: {results}")
|
||||
people_id_list = []
|
||||
for result in results:
|
||||
people_id = result.get('id', '')
|
||||
if not people_id:
|
||||
continue
|
||||
people_id_list.append(people_id)
|
||||
logging.info(f"people_id_list: {people_id_list}")
|
||||
with self.session_maker() as session:
|
||||
people_orms = session.query(PeopleORM).filter(PeopleORM.id.in_(people_id_list)).filter(
|
||||
PeopleORM.deleted_at.is_(None)
|
||||
).all()
|
||||
# 根据 people_id_list 的顺序对查询结果进行排序
|
||||
order_map = {pid: idx for idx, pid in enumerate(people_id_list)}
|
||||
people_orms.sort(key=lambda orm: order_map.get(orm.id, len(order_map)))
|
||||
for people_orm in people_orms:
|
||||
people = people_orm.to_people()
|
||||
peoples.append(people)
|
||||
return peoples
|
||||
|
||||
def delete(self, people_id: str) -> None:
|
||||
"""
|
||||
删除人物从数据库和向量数据库
|
||||
|
||||
:param people_id: 人物ID
|
||||
"""
|
||||
# 1. 从 SQL 数据库软删除人物
|
||||
with self.session_maker() as session:
|
||||
session.query(PeopleORM).filter(
|
||||
PeopleORM.id == people_id,
|
||||
PeopleORM.deleted_at.is_(None)
|
||||
).update({PeopleORM.deleted_at: func.now()}, synchronize_session=False)
|
||||
session.commit()
|
||||
logging.info(f"人物 {people_id} 标记删除 SQL 成功")
|
||||
|
||||
# 2. 删除向量数据库中的记录
|
||||
self.vsdb.delete(ids=[people_id])
|
||||
logging.info(f"人物 {people_id} 删除向量数据库成功")
|
||||
|
||||
# 3. 删除 OBS 存储中的文件
|
||||
keys = self.obs.List(f"peoples/{people_id}/")
|
||||
for key in keys:
|
||||
self.obs.Del(key)
|
||||
logging.info(f"文件 {key} 删除 OBS 成功")
|
||||
|
||||
def init():
|
||||
global people_store
|
||||
people_store = PeopleStore()
|
||||
|
||||
|
||||
def get_instance() -> PeopleStore:
|
||||
return people_store
|
||||
@@ -1,3 +0,0 @@
|
||||
# 导出utils模块中的子模块
|
||||
from . import config, obs, ocr, vsdb, logger
|
||||
__all__ = ['config', 'obs', 'ocr', 'vsdb', 'logger']
|
||||
|
||||
@@ -2,7 +2,6 @@ import configparser
|
||||
|
||||
config = None
|
||||
|
||||
|
||||
def init(config_file: str):
|
||||
global config
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
30
src/utils/error.py
Normal file
30
src/utils/error.py
Normal file
@@ -0,0 +1,30 @@
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import Protocol
|
||||
|
||||
class ErrorCode(Enum):
|
||||
SUCCESS = 0
|
||||
MODEL_ERROR = 1000
|
||||
RLDB_ERROR = 2100
|
||||
|
||||
class error(Protocol):
|
||||
_error_code: int = 0
|
||||
_error_info: str = ""
|
||||
|
||||
def __init__(self, error_code: ErrorCode, error_info: str):
|
||||
self._error_code = int(error_code.value)
|
||||
self._error_info = error_info
|
||||
logging.info(f"errorcode: {type(self._error_code)}")
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"
|
||||
|
||||
@property
|
||||
def code(self) -> int:
|
||||
return self._error_code
|
||||
@property
|
||||
def info(self) -> str:
|
||||
return self._error_info
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return self._error_code == 0
|
||||
167
src/utils/rldb.py
Normal file
167
src/utils/rldb.py
Normal file
@@ -0,0 +1,167 @@
|
||||
|
||||
from typing import Protocol
|
||||
import uuid
|
||||
from sqlalchemy import Column, DateTime, String, create_engine, func
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from .config import get_instance as get_config
|
||||
|
||||
SQLAlchemyBase = declarative_base()
|
||||
|
||||
|
||||
class RLDBBaseModel(SQLAlchemyBase):
|
||||
__abstract__ = True
|
||||
id = Column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
def __str__(self) -> str:
|
||||
# 遍历所有的field,打印出所有的field和value, id 永远排在第一, 三个时间戳排在最后, 其余字段按定义顺序排序
|
||||
fields = [field for field in self.__dict__ if not field.startswith('_')]
|
||||
fields.remove("id") if "id" in fields else None
|
||||
fields.remove("created_at") if "created_at" in fields else None
|
||||
fields.remove("updated_at") if "updated_at" in fields else None
|
||||
fields.remove("deleted_at") if "deleted_at" in fields else None
|
||||
fields = ["id"] + fields + ["created_at", "updated_at", "deleted_at"]
|
||||
field_values = [f"{field}={getattr(self, field)}" for field in fields]
|
||||
return f"{self.__class__.__name__}({', '.join(field_values)})"
|
||||
|
||||
|
||||
class RelationalDB(Protocol):
|
||||
def insert(self, data: RLDBBaseModel) -> str:
|
||||
...
|
||||
|
||||
def update(self, data: RLDBBaseModel) -> str:
|
||||
...
|
||||
|
||||
def upsert(self, data: RLDBBaseModel) -> str:
|
||||
...
|
||||
|
||||
def delete(self, data: RLDBBaseModel) -> str:
|
||||
...
|
||||
|
||||
def get(self,
|
||||
model: type[RLDBBaseModel],
|
||||
id: str,
|
||||
include_deleted: bool = False
|
||||
) -> RLDBBaseModel:
|
||||
...
|
||||
|
||||
def query(self,
|
||||
model: type[RLDBBaseModel],
|
||||
include_deleted: bool = False,
|
||||
limit: int = None,
|
||||
offset: int = None,
|
||||
**filters
|
||||
) -> list[RLDBBaseModel]:
|
||||
...
|
||||
|
||||
|
||||
class SqlAlchemyDB():
|
||||
def __init__(self, dsn: str = None) -> None:
|
||||
config = get_config()
|
||||
dsn = dsn if dsn else config.get("sqlalchemy", "database_dsn")
|
||||
self.sqldb_engine = create_engine(dsn)
|
||||
SQLAlchemyBase.metadata.create_all(self.sqldb_engine)
|
||||
self.session_maker = sessionmaker(bind=self.sqldb_engine)
|
||||
|
||||
def insert(self, data: RLDBBaseModel) -> str:
|
||||
with self.session_maker() as session:
|
||||
session.add(data)
|
||||
session.commit()
|
||||
return data.id
|
||||
|
||||
def update(self, data: RLDBBaseModel) -> str:
|
||||
with self.session_maker() as session:
|
||||
session.merge(data)
|
||||
session.commit()
|
||||
return data.id
|
||||
|
||||
def upsert(self, data: RLDBBaseModel) -> str:
|
||||
existed = data.id and data.id != "" and self.get(data.__class__, data.id) is not None
|
||||
with self.session_maker() as session:
|
||||
session.merge(data) if existed else session.add(data)
|
||||
session.commit()
|
||||
return data.id
|
||||
|
||||
def delete(self, data: RLDBBaseModel) -> str:
|
||||
with self.session_maker() as session:
|
||||
session.delete(data)
|
||||
session.commit()
|
||||
return data.id
|
||||
|
||||
def get(self,
|
||||
model: type[RLDBBaseModel],
|
||||
id: str,
|
||||
) -> RLDBBaseModel:
|
||||
with self.session_maker() as session:
|
||||
sel = session.query(model)
|
||||
sel = sel.filter(model.id == id)
|
||||
sel = sel.filter(model.deleted_at.is_(None))
|
||||
result = sel.first()
|
||||
return result
|
||||
|
||||
def query(self,
|
||||
model: type[RLDBBaseModel],
|
||||
limit: int = None,
|
||||
offset: int = None,
|
||||
**filters
|
||||
) -> list[RLDBBaseModel]:
|
||||
results: list[RLDBBaseModel] = []
|
||||
with self.session_maker() as session:
|
||||
sel = session.query(model)
|
||||
sel = sel.filter(model.deleted_at.is_(None))
|
||||
if filters:
|
||||
sel = sel.filter_by(**filters)
|
||||
if limit:
|
||||
sel = sel.limit(limit)
|
||||
if offset:
|
||||
sel = sel.offset(offset)
|
||||
results = sel.all()
|
||||
results.sort(key=lambda x: x.created_at, reverse=True)
|
||||
return results
|
||||
|
||||
_rldb_instance: RelationalDB = None
|
||||
|
||||
def init(type: str = "sqlalchemy", dsn: str = None):
|
||||
global _rldb_instance
|
||||
if type == "sqlalchemy":
|
||||
_rldb_instance = SqlAlchemyDB(dsn)
|
||||
else:
|
||||
raise ValueError(f"RelationalDB type {type} not supported")
|
||||
|
||||
def get_instance() -> RelationalDB:
|
||||
global _rldb_instance
|
||||
return _rldb_instance
|
||||
|
||||
if __name__ == "__main__":
|
||||
class TestModel(RLDBBaseModel):
|
||||
__tablename__ = "test_model"
|
||||
name = Column(String(36), nullable=True)
|
||||
conf = Column(String(96), nullable=True)
|
||||
init("sqlalchemy", dsn="sqlite:///./demo_storage/rldb.db")
|
||||
db = get_instance()
|
||||
|
||||
test_data = TestModel(name="test", conf="test.config")
|
||||
print(f"before insert: {test_data}")
|
||||
ret = db.insert(test_data)
|
||||
print(f"after insert: {test_data}")
|
||||
|
||||
print(f"before update: {test_data}")
|
||||
test_data.conf = "test.config.new"
|
||||
ret = db.update(test_data)
|
||||
print(f"after update: {test_data}")
|
||||
|
||||
test2_data = TestModel(name="test", conf="test2.config")
|
||||
print(f"before upsert: {test2_data}")
|
||||
ret = db.upsert(test2_data)
|
||||
print(f"after upsert: {test2_data}")
|
||||
|
||||
get_data = db.get(TestModel, test_data.id)
|
||||
print(f"get data: {get_data}")
|
||||
|
||||
query_data = db.query(TestModel, name="test")
|
||||
for data in query_data:
|
||||
print(data.id, data.name, data.conf)
|
||||
print(f"query data: {data}")
|
||||
ret = db.delete(data)
|
||||
print(f"delete data.id: {ret}")
|
||||
@@ -1,241 +0,0 @@
|
||||
import uuid
|
||||
import chromadb
|
||||
import logging
|
||||
from typing import Protocol
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils import embedding_functions
|
||||
from .config import get_instance as get_config
|
||||
|
||||
class VectorDB(Protocol):
|
||||
|
||||
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
|
||||
"""
|
||||
插入向量到数据库
|
||||
|
||||
Args:
|
||||
vector (list[float]): 向量
|
||||
metadata (dict): 元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否插入成功
|
||||
"""
|
||||
...
|
||||
|
||||
def delete(self, ids: list[str]) -> bool:
|
||||
"""
|
||||
Delete documents from a collection.
|
||||
|
||||
Args:
|
||||
ids: List of IDs to delete
|
||||
|
||||
Returns:
|
||||
bool: Whether deletion was successful
|
||||
"""
|
||||
...
|
||||
|
||||
def query(self, metadatas: dict, ids: list[str], top_k: int = 5) -> list[dict]:
|
||||
"""
|
||||
查询向量数据库
|
||||
|
||||
Args:
|
||||
query_vector (list[float]): 查询向量
|
||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
list[dict]: 查询结果列表
|
||||
"""
|
||||
...
|
||||
|
||||
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
||||
"""
|
||||
搜索向量数据库
|
||||
|
||||
Args:
|
||||
document: Document to search
|
||||
metadatas: Metadata to filter by
|
||||
ids: List of IDs to filter by
|
||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
list[dict]: 查询结果列表
|
||||
"""
|
||||
...
|
||||
|
||||
class ChromaDB:
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initialize the ChromaDB instance.
|
||||
"""
|
||||
config = get_config()
|
||||
self.embedding_functions = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_base=config.get("voc-engine_embedding", "api_url"),
|
||||
api_key=config.get("voc-engine_embedding", "api_key"),
|
||||
model_name=config.get("voc-engine_embedding", "endpoint"),
|
||||
)
|
||||
persist_directory = config.get("chroma_vsdb", "database_dir", fallback=None)
|
||||
logging.debug(f"persist_directory: {persist_directory}")
|
||||
if persist_directory:
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=persist_directory,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
else:
|
||||
self.client = chromadb.Client(
|
||||
settings=Settings(anonymized_telemetry=False),
|
||||
)
|
||||
self.collection_name = config.get("chroma_vsdb", "collection_name", fallback="peoples")
|
||||
metadata: dict = kwargs.get('collection_metadata', {'hnsw:space': 'cosine'})
|
||||
metadata['hnsw:space'] = metadata.get('hnsw:space', 'cosine')
|
||||
self.collection = self.client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
embedding_function=self.embedding_functions,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
|
||||
"""
|
||||
Insert documents into a collection.
|
||||
|
||||
Args:
|
||||
metadatas: List of metadata corresponding to each document
|
||||
documents: List of documents to insert
|
||||
ids: Optional list of unique IDs for each document. If None, IDs will be generated.
|
||||
|
||||
Returns:
|
||||
list[str]: List of inserted IDs
|
||||
"""
|
||||
|
||||
if not ids:
|
||||
# Generate unique IDs if not provided
|
||||
ids = [str(uuid.uuid4()) for _ in range(len(documents))]
|
||||
|
||||
self.collection.add(
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
ids=ids
|
||||
)
|
||||
|
||||
return ids
|
||||
|
||||
def delete(self, ids: list[str]) -> bool:
|
||||
"""
|
||||
Delete documents from a collection.
|
||||
|
||||
Args:
|
||||
ids: List of IDs to delete
|
||||
|
||||
Returns:
|
||||
bool: Whether deletion was successful
|
||||
"""
|
||||
try:
|
||||
self.collection.delete(ids)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error deleting documents: {e}")
|
||||
return False
|
||||
|
||||
def query(self, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
||||
"""
|
||||
查询向量数据库
|
||||
|
||||
Args:
|
||||
metadatas: Metadata to filter by
|
||||
ids: List of IDs to query
|
||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
list[dict]: 查询结果列表
|
||||
"""
|
||||
results = self.collection.query(
|
||||
query_embeddings=None,
|
||||
query_texts=None,
|
||||
n_results=top_k,
|
||||
where=metadatas,
|
||||
ids=ids,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
formatted_results = []
|
||||
for i in range(len(results['ids'][0])):
|
||||
result = {
|
||||
'id': results['ids'][0][i],
|
||||
'distance': results['distances'][0][i],
|
||||
'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
|
||||
'document': results['documents'][0][i] if results['documents'][0] else ''
|
||||
}
|
||||
formatted_results.append(result)
|
||||
return formatted_results
|
||||
|
||||
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
||||
"""
|
||||
搜索向量数据库
|
||||
|
||||
Args:
|
||||
document: Document to search
|
||||
metadatas: Metadata to filter by
|
||||
ids: List of IDs to filter by
|
||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
list[dict]: 查询结果列表
|
||||
"""
|
||||
results = self.collection.query(
|
||||
query_embeddings=None,
|
||||
query_texts=[document],
|
||||
n_results=top_k,
|
||||
where=metadatas if metadatas else None,
|
||||
ids=ids,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
formatted_results = []
|
||||
for i in range(len(results['ids'][0])):
|
||||
logging.info(f"result id: {results['ids'][0][i]}, distance: {results['distances'][0][i]}")
|
||||
result = {
|
||||
'id': results['ids'][0][i],
|
||||
'distance': results['distances'][0][i],
|
||||
'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
|
||||
'document': results['documents'][0][i] if results['documents'][0] else ''
|
||||
}
|
||||
formatted_results.append(result)
|
||||
return formatted_results
|
||||
pass
|
||||
|
||||
_vsdb_instance: VectorDB = None
|
||||
|
||||
def init():
|
||||
global _vsdb_instance
|
||||
_vsdb_instance = ChromaDB()
|
||||
|
||||
def get_instance() -> VectorDB:
|
||||
global _vsdb_instance
|
||||
return _vsdb_instance
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
from logger import init as init_logger
|
||||
init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)
|
||||
|
||||
from config import init as init_config
|
||||
config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
|
||||
init_config(config_file)
|
||||
|
||||
init()
|
||||
vsdb = get_instance()
|
||||
metadatas = [
|
||||
{'name': '丽丽'},
|
||||
{'name': '志刚'},
|
||||
{'name': '张三'},
|
||||
{'name': '李四'},
|
||||
]
|
||||
documents = [
|
||||
'姓名: 丽丽, 性别: 女, 年龄: 23, 爱好: 爬山、骑行、攀岩、跳伞、蹦极',
|
||||
"姓名: 志刚, 性别: 男, 年龄: 25, 爱好: 读书、游戏",
|
||||
"姓名: 张三, 性别: 男, 年龄: 30, 爱好: 画画、写作、阅读、逛展、旅行",
|
||||
"姓名: 李四, 性别: 男, 年龄: 35, 爱好: 做饭、美食、旅游"
|
||||
]
|
||||
search_text = '25岁以下的'
|
||||
ids = vsdb.insert(metadatas, documents)
|
||||
results = vsdb.search(search_text, None, None, top_k=4)
|
||||
for result in results:
|
||||
print(result['document'], ' ', result['distance'])
|
||||
@@ -1,20 +1,16 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
from fastapi import FastAPI, File, UploadFile, Query
|
||||
from fastapi import FastAPI, UploadFile, File, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ai.agent import ExtractPeopleAgent
|
||||
from models.people import People
|
||||
from utils import obs, ocr, vsdb
|
||||
from storage.people_store import get_instance as get_people_store
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from services.people import get_instance as get_people_service
|
||||
from models.people import People
|
||||
from agents.extract_people_agent import ExtractPeopleAgent
|
||||
from utils import obs, ocr
|
||||
|
||||
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
|
||||
api = FastAPI(title="Single People Management and Searching", version="0.1")
|
||||
api.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -28,19 +24,21 @@ class BaseResponse(BaseModel):
|
||||
error_info: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
@api.post("/ping")
|
||||
async def ping():
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
class PostInputRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
@api.post("/input")
|
||||
@api.post("/recognition/input")
|
||||
async def post_input(request: PostInputRequest):
|
||||
extra_agent = ExtractPeopleAgent()
|
||||
people = extra_agent.extract_people_info(request.text)
|
||||
logging.info(f"people: {people}")
|
||||
people = extract_people(request.text)
|
||||
resp = BaseResponse(error_code=0, error_info="success")
|
||||
resp.data = people.to_dict()
|
||||
return resp
|
||||
|
||||
@api.post("/input_image")
|
||||
@api.post("/recognition/image")
|
||||
async def post_input_image(image: UploadFile = File(...)):
|
||||
# 实现上传图片的处理
|
||||
# 保存上传的图片文件
|
||||
@@ -65,20 +63,53 @@ async def post_input_image(image: UploadFile = File(...)):
|
||||
ocr_result = ocr_util.recognize_image_text(obs_url)
|
||||
logging.info(f"ocr_result: {ocr_result}")
|
||||
|
||||
post_input_request = PostInputRequest(text=ocr_result)
|
||||
return await post_input(post_input_request)
|
||||
people = extract_people(ocr_result, obs_url)
|
||||
resp = BaseResponse(error_code=0, error_info="success")
|
||||
resp.data = people.to_dict()
|
||||
return resp
|
||||
|
||||
def extract_people(text: str, cover_link: str = None) -> People:
|
||||
extra_agent = ExtractPeopleAgent()
|
||||
people = extra_agent.extract_people_info(text)
|
||||
people.cover = cover_link
|
||||
logging.info(f"people: {people}")
|
||||
return people
|
||||
|
||||
class PostPeopleRequest(BaseModel):
|
||||
people: dict
|
||||
|
||||
@api.post("/peoples")
|
||||
@api.post("/people")
|
||||
async def post_people(post_people_request: PostPeopleRequest):
|
||||
logging.debug(f"post_people_request: {post_people_request}")
|
||||
people = People.from_dict(post_people_request.people)
|
||||
store = get_people_store()
|
||||
people.id = store.save(people)
|
||||
service = get_people_service()
|
||||
people.id, error = service.save(people)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success", data=people.id)
|
||||
|
||||
@api.put("/people/{people_id}")
|
||||
async def update_people(people_id: str, post_people_request: PostPeopleRequest):
|
||||
logging.debug(f"post_people_request: {post_people_request}")
|
||||
people = People.from_dict(post_people_request.people)
|
||||
people.id = people_id
|
||||
service = get_people_service()
|
||||
res, error = service.get(people_id)
|
||||
if not error.success or not res:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
_, error = service.save(people)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
@api.delete("/people/{people_id}")
|
||||
async def delete_people(people_id: str):
|
||||
service = get_people_service()
|
||||
error = service.delete(people_id)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
class GetPeopleRequest(BaseModel):
|
||||
query: Optional[str] = None
|
||||
conds: Optional[dict] = None
|
||||
@@ -93,8 +124,6 @@ async def get_peoples(
|
||||
marital_status: Optional[str] = Query(None, description="婚姻状态"),
|
||||
limit: int = Query(10, description="分页大小"),
|
||||
offset: int = Query(0, description="分页偏移量"),
|
||||
search: Optional[str] = Query(None, description="搜索内容"),
|
||||
top_k: int = Query(5, description="搜索结果数量"),
|
||||
):
|
||||
|
||||
# 解析查询参数为字典
|
||||
@@ -110,21 +139,14 @@ async def get_peoples(
|
||||
if marital_status:
|
||||
conds["marital_status"] = marital_status
|
||||
|
||||
logging.info(f"conds: , limit: {limit}, offset: {offset}, search: {search}, top_k: {top_k}")
|
||||
logging.info(f"conds: , limit: {limit}, offset: {offset}")
|
||||
|
||||
results = []
|
||||
store = get_people_store()
|
||||
if search:
|
||||
results = store.search(search, conds, ids=None, top_k=top_k)
|
||||
logging.info(f"search results: {results}")
|
||||
else:
|
||||
results = store.query(conds, limit=limit, offset=offset)
|
||||
logging.info(f"query results: {results}")
|
||||
service = get_people_service()
|
||||
results, error = service.list(conds, limit=limit, offset=offset)
|
||||
logging.info(f"query results: {results}")
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
peoples = [people.to_dict() for people in results]
|
||||
return BaseResponse(error_code=0, error_info="success", data=peoples)
|
||||
|
||||
@api.delete("/peoples/{people_id}")
|
||||
async def delete_people(people_id: str):
|
||||
store = get_people_store()
|
||||
store.delete(people_id)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
@@ -1,16 +0,0 @@
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../src'))
|
||||
|
||||
from utils.logger import init
|
||||
import logging
|
||||
|
||||
# 初始化日志
|
||||
init()
|
||||
|
||||
# 测试不同级别的日志
|
||||
logging.debug("这是一条调试信息")
|
||||
logging.info("这是一条普通信息")
|
||||
logging.warning("这是一条警告信息")
|
||||
logging.error("这是一条错误信息")
|
||||
logging.critical("这是一条严重错误信息")
|
||||
Reference in New Issue
Block a user