Compare commits
19 Commits
fca2b1449f
...
v0.2.0
| Author | SHA1 | Date | |
|---|---|---|---|
| 14b455c705 | |||
| 18f0083827 | |||
| b66a460dc1 | |||
| c69fc5bffa | |||
| 8174c4cfe5 | |||
| 4c48d11bfa | |||
| e74279ca5e | |||
| 13b70ba424 | |||
| fae93b5ab8 | |||
| 1a092248eb | |||
| 0a749d56e8 | |||
| 3d13aa18ae | |||
| d179418e7d | |||
| 496f35a386 | |||
| c99b324b81 | |||
| 7a189eb631 | |||
| 98cbc754f6 | |||
| 736d8ed193 | |||
| 268eb8be2b |
@@ -1,18 +1,18 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "service"
|
name = "service"
|
||||||
version = "0.1.0"
|
version = "0.1"
|
||||||
description = "This project is the web servcie sub-system for if.u projuect"
|
description = "This project is the web servcie sub-system for if.u projuect"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"alibabacloud-ocr-api20210707>=3.1.3",
|
"alibabacloud-ocr-api20210707>=3.1.3",
|
||||||
"chromadb>=1.1.1",
|
"alibabacloud-tea-openapi>=0.4.1",
|
||||||
"fastapi>=0.118.2",
|
"fastapi>=0.118.3",
|
||||||
"langchain>=0.3.27",
|
"langchain==0.3.27",
|
||||||
"langchain-openai>=0.3.35",
|
"langchain-openai==0.3.35",
|
||||||
"numpy>=2.3.3",
|
|
||||||
"pymysql>=1.1.2",
|
"pymysql>=1.1.2",
|
||||||
"python-multipart>=0.0.20",
|
"python-multipart>=0.0.20",
|
||||||
"qiniu>=7.17.0",
|
"qiniu>=7.17.0",
|
||||||
"requests>=2.32.5",
|
"sqlalchemy>=2.0.44",
|
||||||
|
"uvicorn>=0.38.0",
|
||||||
]
|
]
|
||||||
|
|||||||
22
src/agents/base_agent.py
Normal file
22
src/agents/base_agent.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from utils.config import get_instance as get_config
|
||||||
|
|
||||||
|
class BaseAgent:
|
||||||
|
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
|
||||||
|
config = get_config()
|
||||||
|
llm_api_url = api_url or config.get("ai", "llm_api_url")
|
||||||
|
llm_api_key = api_key or config.get("ai", "llm_api_key")
|
||||||
|
llm_model_name = model_name or config.get("ai", "llm_model_name")
|
||||||
|
self.llm = ChatOpenAI(
|
||||||
|
openai_api_key=llm_api_key,
|
||||||
|
openai_api_base=llm_api_url,
|
||||||
|
model_name=llm_model_name,
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryPeopleAgent(BaseAgent):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
pass
|
||||||
@@ -1,25 +1,19 @@
|
|||||||
|
|
||||||
|
import datetime
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
|
|
||||||
|
from .base_agent import BaseAgent
|
||||||
from models.people import People
|
from models.people import People
|
||||||
|
|
||||||
class BaseAgent:
|
|
||||||
def __init__(self):
|
|
||||||
self.llm = ChatOpenAI(
|
|
||||||
openai_api_key="56d82040-85c7-4701-8f87-734985e27909",
|
|
||||||
openai_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
|
||||||
model_name="ep-20250722161445-n9lfq"
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ExtractPeopleAgent(BaseAgent):
|
class ExtractPeopleAgent(BaseAgent):
|
||||||
def __init__(self):
|
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
|
||||||
super().__init__()
|
super().__init__(api_url, api_key, model_name)
|
||||||
self.prompt = ChatPromptTemplate.from_messages([
|
self.prompt = ChatPromptTemplate.from_messages([
|
||||||
(
|
(
|
||||||
"system",
|
"system",
|
||||||
|
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')},"
|
||||||
"你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n"
|
"你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n"
|
||||||
"姓名 name\n"
|
"姓名 name\n"
|
||||||
"性别 gender\n"
|
"性别 gender\n"
|
||||||
@@ -27,7 +21,9 @@ class ExtractPeopleAgent(BaseAgent):
|
|||||||
"身高(cm) height\n"
|
"身高(cm) height\n"
|
||||||
"婚姻状况 marital_status\n"
|
"婚姻状况 marital_status\n"
|
||||||
"择偶要求 match_requirement\n"
|
"择偶要求 match_requirement\n"
|
||||||
"以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n"
|
"以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致; 若未识别到以上的某项,则不返回该字段,不要自行填写“未知”,“未填写”等类似字眼。\n"
|
||||||
|
"其中,'年龄 age' 和 '身高(cm) height' 必须是一个整数,不能是一个字符串;\n"
|
||||||
|
"并且,'性别 gender' 根据识别结果,必须从 男,女,未知 三选一填写。\n"
|
||||||
"除了上述基本信息,还有一个字段\n"
|
"除了上述基本信息,还有一个字段\n"
|
||||||
"个人介绍 introduction\n"
|
"个人介绍 introduction\n"
|
||||||
"其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n"
|
"其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n"
|
||||||
@@ -42,13 +38,15 @@ class ExtractPeopleAgent(BaseAgent):
|
|||||||
response = self.llm.invoke(prompt)
|
response = self.llm.invoke(prompt)
|
||||||
logging.info(f"llm response: {response.content}")
|
logging.info(f"llm response: {response.content}")
|
||||||
try:
|
try:
|
||||||
return People.from_dict(json.loads(response.content))
|
people = People.from_dict(json.loads(response.content))
|
||||||
|
err = people.validate()
|
||||||
|
if not err.success:
|
||||||
|
raise ValueError(f"Failed to validate people info: {err.info}")
|
||||||
|
return people
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
|
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
|
||||||
return None
|
return None
|
||||||
|
except ValueError as e:
|
||||||
|
logging.error(f"Failed to validate people info: {e}")
|
||||||
|
return None
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class SummaryPeopleAgent(BaseAgent):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
pass
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# created by mmmy on 2025-09-27
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from ai.agent import ExtractPeopleAgent
|
|
||||||
from utils import ocr, vsdb, obs
|
|
||||||
from models.people import People
|
|
||||||
from fastapi import FastAPI
|
|
||||||
|
|
||||||
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
|
|
||||||
|
|
||||||
class App:
|
|
||||||
def __init__(self):
|
|
||||||
self.extract_people_agent = ExtractPeopleAgent()
|
|
||||||
self.ocr = ocr.get_instance()
|
|
||||||
self.vedb = vsdb.get_instance(db_type='chromadb')
|
|
||||||
self.obs = obs.get_instance()
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class InputForPeople:
|
|
||||||
image: bytes
|
|
||||||
text: str
|
|
||||||
|
|
||||||
def input_to_people(self, input: InputForPeople) -> People:
|
|
||||||
if not input.image and not input.text:
|
|
||||||
return None
|
|
||||||
if input.image:
|
|
||||||
content = self.ocr.recognize_image_text(input.image)
|
|
||||||
else:
|
|
||||||
content = input.text
|
|
||||||
print(content)
|
|
||||||
people = self._trans_text_to_people(content)
|
|
||||||
return people
|
|
||||||
|
|
||||||
def _trans_text_to_people(self, text: str) -> People:
|
|
||||||
if not text:
|
|
||||||
return None
|
|
||||||
person = self.extract_people_agent.extract_people_info(text)
|
|
||||||
print(person)
|
|
||||||
return person
|
|
||||||
|
|
||||||
def create_people(self, people: People) -> bool:
|
|
||||||
if not people:
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
people.save()
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"保存人物到数据库失败: {e}")
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
23
src/main.py
23
src/main.py
@@ -1,13 +1,12 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# created by mmmy on 2025-09-27
|
# created by mmmy on 2025-09-27
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
from venv import logger
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from app.api import api
|
from services import people as people_service
|
||||||
from utils import obs, ocr, vsdb, logger, config
|
from utils import config, logger, obs, ocr, rldb
|
||||||
from storage import people_store
|
|
||||||
|
from web.api import api
|
||||||
|
|
||||||
# 主函数
|
# 主函数
|
||||||
def main():
|
def main():
|
||||||
@@ -15,14 +14,20 @@ def main():
|
|||||||
parser = argparse.ArgumentParser(description='IF.u 服务')
|
parser = argparse.ArgumentParser(description='IF.u 服务')
|
||||||
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
|
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
config.init(args.config)
|
config.init(args.config)
|
||||||
logger.init()
|
logger.init()
|
||||||
obs.init()
|
|
||||||
|
rldb.init()
|
||||||
|
|
||||||
ocr.init()
|
ocr.init()
|
||||||
vsdb.init()
|
obs.init()
|
||||||
people_store.init()
|
|
||||||
|
people_service.init()
|
||||||
|
|
||||||
conf = config.get_instance()
|
conf = config.get_instance()
|
||||||
host = conf.get('web_service', 'server_host', fallback='127.0.0.1')
|
|
||||||
|
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
|
||||||
port = conf.getint('web_service', 'server_port', fallback=8099)
|
port = conf.getint('web_service', 'server_port', fallback=8099)
|
||||||
uvicorn.run(api, host=host, port=port)
|
uvicorn.run(api, host=host, port=port)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,61 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# created by mmmy on 2025-09-30
|
# created by mmmy on 2025-09-30
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy import Column, Integer, String, Text, DateTime, func
|
||||||
|
from utils.rldb import RLDBBaseModel
|
||||||
|
from utils.error import ErrorCode, error
|
||||||
|
|
||||||
|
class PeopleRLDBModel(RLDBBaseModel):
|
||||||
|
__tablename__ = 'peoples'
|
||||||
|
id = Column(String(36), primary_key=True)
|
||||||
|
name = Column(String(255), index=True)
|
||||||
|
contact = Column(String(255), index=True)
|
||||||
|
gender = Column(String(10))
|
||||||
|
age = Column(Integer)
|
||||||
|
height = Column(Integer)
|
||||||
|
marital_status = Column(String(20))
|
||||||
|
match_requirement = Column(Text)
|
||||||
|
introduction = Column(Text)
|
||||||
|
comments = Column(Text)
|
||||||
|
cover = Column(String(255), nullable=True)
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||||
|
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||||
|
|
||||||
|
|
||||||
|
class Comment:
|
||||||
|
# 评论内容
|
||||||
|
content: str
|
||||||
|
# 评论人
|
||||||
|
author: str
|
||||||
|
# 创建时间
|
||||||
|
created_at: datetime
|
||||||
|
# 更新时间
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.content = kwargs.get('content', '')
|
||||||
|
self.author = kwargs.get('author', '')
|
||||||
|
self.created_at = kwargs.get('created_at', datetime.now())
|
||||||
|
self.updated_at = kwargs.get('updated_at', datetime.now())
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
'content': self.content,
|
||||||
|
'author': self.author,
|
||||||
|
'created_at': int(self.created_at.timestamp()),
|
||||||
|
'updated_at': int(self.updated_at.timestamp()),
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict):
|
||||||
|
data['created_at'] = datetime.fromtimestamp(data['created_at'])
|
||||||
|
data['updated_at'] = datetime.fromtimestamp(data['updated_at'])
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
class People:
|
class People:
|
||||||
@@ -18,24 +71,19 @@ class People:
|
|||||||
age: int
|
age: int
|
||||||
# 身高(cm)
|
# 身高(cm)
|
||||||
height: int
|
height: int
|
||||||
# 体重(kg)
|
|
||||||
# weight: int
|
|
||||||
# 婚姻状况
|
# 婚姻状况
|
||||||
# [
|
|
||||||
# "未婚(single)",
|
|
||||||
# "已婚(married)",
|
|
||||||
# "离异(divorced)",
|
|
||||||
# "丧偶(widowed)"
|
|
||||||
# ]
|
|
||||||
marital_status: str
|
marital_status: str
|
||||||
# 择偶要求
|
# 择偶要求
|
||||||
match_requirement: str
|
match_requirement: str
|
||||||
# 个人介绍
|
# 个人介绍
|
||||||
introduction: Dict[str, str]
|
introduction: Dict[str, str]
|
||||||
# 总结评价
|
# 总结评价
|
||||||
comments: Dict[str, str]
|
comments: Dict[str, "Comment"]
|
||||||
|
# 封面
|
||||||
|
cover: str = None
|
||||||
|
# 创建时间
|
||||||
|
created_at: datetime = None
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
# 初始化所有属性,从kwargs中获取值,如果不存在则设置默认值
|
# 初始化所有属性,从kwargs中获取值,如果不存在则设置默认值
|
||||||
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
|
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
|
||||||
@@ -48,15 +96,18 @@ class People:
|
|||||||
self.match_requirement = kwargs.get('match_requirement', '') if kwargs.get('match_requirement', '') is not None else ''
|
self.match_requirement = kwargs.get('match_requirement', '') if kwargs.get('match_requirement', '') is not None else ''
|
||||||
self.introduction = kwargs.get('introduction', {}) if kwargs.get('introduction', {}) is not None else {}
|
self.introduction = kwargs.get('introduction', {}) if kwargs.get('introduction', {}) is not None else {}
|
||||||
self.comments = kwargs.get('comments', {}) if kwargs.get('comments', {}) is not None else {}
|
self.comments = kwargs.get('comments', {}) if kwargs.get('comments', {}) is not None else {}
|
||||||
|
self.cover = kwargs.get('cover', None) if kwargs.get('cover', None) is not None else None
|
||||||
|
self.created_at = kwargs.get('created_at', None)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.tonl()
|
# 返回对象的字符串表示,包含所有属性
|
||||||
|
return (f"People(id={self.id}, name={self.name}, contact={self.contact}, gender={self.gender}, "
|
||||||
|
f"age={self.age}, height={self.height}, marital_status={self.marital_status}, "
|
||||||
|
f"match_requirement={self.match_requirement}, introduction={self.introduction}, "
|
||||||
|
f"comments={self.comments}, cover={self.cover}, created_at={self.created_at})")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict):
|
def from_dict(cls, data: dict):
|
||||||
if 'created_at' in data:
|
|
||||||
# 移除 created_at 字段,避免类型错误
|
|
||||||
del data['created_at']
|
|
||||||
if 'updated_at' in data:
|
if 'updated_at' in data:
|
||||||
# 移除 updated_at 字段,避免类型错误
|
# 移除 updated_at 字段,避免类型错误
|
||||||
del data['updated_at']
|
del data['updated_at']
|
||||||
@@ -65,6 +116,24 @@ class People:
|
|||||||
del data['deleted_at']
|
del data['deleted_at']
|
||||||
return cls(**data)
|
return cls(**data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_rldb_model(cls, data: PeopleRLDBModel):
|
||||||
|
# 将关系数据库模型转换为对象
|
||||||
|
return cls(
|
||||||
|
id=data.id,
|
||||||
|
name=data.name,
|
||||||
|
contact=data.contact,
|
||||||
|
gender=data.gender,
|
||||||
|
age=data.age,
|
||||||
|
height=data.height,
|
||||||
|
marital_status=data.marital_status,
|
||||||
|
match_requirement=data.match_requirement,
|
||||||
|
introduction=json.loads(data.introduction) if data.introduction else {},
|
||||||
|
comments={k: Comment.from_dict(v) for k, v in json.loads(data.comments).items()} if data.comments else {},
|
||||||
|
cover=data.cover,
|
||||||
|
created_at=data.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
# 将对象转换为字典格式
|
# 将对象转换为字典格式
|
||||||
return {
|
return {
|
||||||
@@ -77,45 +146,39 @@ class People:
|
|||||||
'marital_status': self.marital_status,
|
'marital_status': self.marital_status,
|
||||||
'match_requirement': self.match_requirement,
|
'match_requirement': self.match_requirement,
|
||||||
'introduction': self.introduction,
|
'introduction': self.introduction,
|
||||||
'comments': self.comments,
|
'comments': {k: v.to_dict() for k, v in self.comments.items()},
|
||||||
}
|
'cover': self.cover,
|
||||||
|
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def to_rldb_model(self) -> PeopleRLDBModel:
|
||||||
|
# 将对象转换为关系数据库模型
|
||||||
|
return PeopleRLDBModel(
|
||||||
|
id=self.id,
|
||||||
|
name=self.name,
|
||||||
|
contact=self.contact,
|
||||||
|
gender=self.gender,
|
||||||
|
age=self.age,
|
||||||
|
height=self.height,
|
||||||
|
marital_status=self.marital_status,
|
||||||
|
match_requirement=self.match_requirement,
|
||||||
|
introduction=json.dumps(self.introduction, ensure_ascii=False),
|
||||||
|
comments=json.dumps({k: v.to_dict() for k, v in self.comments.items()}, ensure_ascii=False),
|
||||||
|
cover=self.cover,
|
||||||
|
)
|
||||||
|
|
||||||
def meta(self) -> Dict[str, str]:
|
def validate(self) -> error:
|
||||||
# 返回对象的元数据信息
|
err = error(ErrorCode.SUCCESS, "")
|
||||||
meta = {
|
if not self.name:
|
||||||
'id': self.id,
|
logging.error("Name is required, use default")
|
||||||
'name': self.name,
|
self.name = ""
|
||||||
'gender': self.gender,
|
if not self.gender in ['男', '女', '未知']:
|
||||||
'age': self.age,
|
logging.error("Gender must be '男', '女', or '未知', use default")
|
||||||
'height': self.height,
|
self.gender = "未知"
|
||||||
'marital_status': self.marital_status,
|
if not isinstance(self.age, int) or self.age < 0:
|
||||||
}
|
logging.error("Age must be an integer and greater than 0, use default")
|
||||||
logging.info(f"people meta: {meta}")
|
self.age = 0
|
||||||
return meta
|
if not isinstance(self.height, int) or self.height < 0:
|
||||||
|
logging.error("Height must be an integer and greater than 0, use default")
|
||||||
def tonl(self) -> str:
|
self.height = 0
|
||||||
# 将对象转换为文档格式的字符串
|
return err
|
||||||
doc = []
|
|
||||||
doc.append(f"姓名: {self.name}")
|
|
||||||
doc.append(f"性别: {self.gender}")
|
|
||||||
if self.age:
|
|
||||||
doc.append(f"年龄: {self.age}")
|
|
||||||
if self.height:
|
|
||||||
doc.append(f"身高: {self.height}cm")
|
|
||||||
if self.marital_status:
|
|
||||||
doc.append(f"婚姻状况: {self.marital_status}")
|
|
||||||
if self.match_requirement:
|
|
||||||
doc.append(f"择偶要求: {self.match_requirement}")
|
|
||||||
if self.introduction:
|
|
||||||
doc.append("个人介绍:")
|
|
||||||
for key, value in self.introduction.items():
|
|
||||||
doc.append(f" - {key}: {value}")
|
|
||||||
if self.comments:
|
|
||||||
doc.append("总结评价:")
|
|
||||||
for key, value in self.comments.items():
|
|
||||||
doc.append(f" - {key}: {value}")
|
|
||||||
return '\n'.join(doc)
|
|
||||||
|
|
||||||
def comment(self, comment: Dict[str, str]):
|
|
||||||
# 添加总结评价
|
|
||||||
self.comments.update(comment)
|
|
||||||
|
|||||||
124
src/services/people.py
Normal file
124
src/services/people.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from models.people import People, PeopleRLDBModel, Comment
|
||||||
|
from datetime import datetime
|
||||||
|
from utils.error import ErrorCode, error
|
||||||
|
from utils import rldb
|
||||||
|
|
||||||
|
|
||||||
|
class PeopleService:
|
||||||
|
def __init__(self):
|
||||||
|
self.rldb = rldb.get_instance()
|
||||||
|
|
||||||
|
def save(self, people: People) -> (str, error):
|
||||||
|
"""
|
||||||
|
保存人物到数据库和向量数据库
|
||||||
|
|
||||||
|
:param people: 人物对象
|
||||||
|
:return: 人物ID
|
||||||
|
"""
|
||||||
|
# 0. 生成 people id
|
||||||
|
people.id = people.id if people.id else uuid.uuid4().hex
|
||||||
|
|
||||||
|
# 1. 转换模型,并保存到 SQL 数据库
|
||||||
|
people_orm = people.to_rldb_model()
|
||||||
|
self.rldb.upsert(people_orm)
|
||||||
|
|
||||||
|
return people.id, error(ErrorCode.SUCCESS, "")
|
||||||
|
|
||||||
|
def delete(self, people_id: str) -> error:
|
||||||
|
"""
|
||||||
|
删除人物从数据库和向量数据库
|
||||||
|
|
||||||
|
:param people_id: 人物ID
|
||||||
|
:return: 错误对象
|
||||||
|
"""
|
||||||
|
people_orm = self.rldb.get(PeopleRLDBModel, people_id)
|
||||||
|
if not people_orm:
|
||||||
|
return error(ErrorCode.RLDB_ERROR, f"people {people_id} not found")
|
||||||
|
self.rldb.delete(people_orm)
|
||||||
|
return error(ErrorCode.SUCCESS, "")
|
||||||
|
|
||||||
|
def get(self, people_id: str) -> (People, error):
|
||||||
|
"""
|
||||||
|
从数据库获取人物
|
||||||
|
|
||||||
|
:param people_id: 人物ID
|
||||||
|
:return: 人物对象
|
||||||
|
"""
|
||||||
|
people_orm = self.rldb.get(PeopleRLDBModel, people_id)
|
||||||
|
if not people_orm:
|
||||||
|
return None, error(ErrorCode.MODEL_ERROR, f"people {people_id} not found")
|
||||||
|
return People.from_rldb_model(people_orm), error(ErrorCode.SUCCESS, "")
|
||||||
|
|
||||||
|
def list(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> (list[People], error):
|
||||||
|
"""
|
||||||
|
从数据库列出人物
|
||||||
|
|
||||||
|
:param conds: 查询条件字典
|
||||||
|
:param limit: 分页大小
|
||||||
|
:param offset: 分页偏移量
|
||||||
|
:return: 人物对象列表
|
||||||
|
"""
|
||||||
|
people_orms = self.rldb.query(PeopleRLDBModel, **conds)
|
||||||
|
peoples = [People.from_rldb_model(people_orm) for people_orm in people_orms]
|
||||||
|
|
||||||
|
return peoples, error(ErrorCode.SUCCESS, "")
|
||||||
|
|
||||||
|
def save_remark(self, people_id: str, content: str) -> error:
|
||||||
|
"""
|
||||||
|
为人物添加或更新备注
|
||||||
|
|
||||||
|
:param people_id: 人物ID
|
||||||
|
:param content: 备注内容
|
||||||
|
:return: 错误对象
|
||||||
|
"""
|
||||||
|
people: People
|
||||||
|
err: error
|
||||||
|
people, err = self.get(people_id)
|
||||||
|
logging.info(f"get people before save remark: {people}")
|
||||||
|
if not err.success:
|
||||||
|
return err
|
||||||
|
remark = people.comments.get("remark", None)
|
||||||
|
if remark is not None:
|
||||||
|
remark.content = content
|
||||||
|
remark.updated_at = datetime.now()
|
||||||
|
else:
|
||||||
|
people.comments["remark"] = Comment(content=content)
|
||||||
|
logging.info(f"save remark for people {people}")
|
||||||
|
_, err = self.save(people)
|
||||||
|
return err
|
||||||
|
|
||||||
|
def delete_remark(self, people_id: str) -> error:
|
||||||
|
"""
|
||||||
|
删除人物备注
|
||||||
|
|
||||||
|
:param people_id: 人物ID
|
||||||
|
:return: 错误对象
|
||||||
|
"""
|
||||||
|
people: People
|
||||||
|
err: error
|
||||||
|
people, err = self.get(people_id)
|
||||||
|
if not err.success:
|
||||||
|
return err
|
||||||
|
|
||||||
|
if "remark" in people.comments:
|
||||||
|
del people.comments["remark"]
|
||||||
|
_, err = self.save(people)
|
||||||
|
return err
|
||||||
|
|
||||||
|
return error(ErrorCode.SUCCESS, "")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
people_service = None
|
||||||
|
|
||||||
|
def init():
|
||||||
|
global people_service
|
||||||
|
people_service = PeopleService()
|
||||||
|
|
||||||
|
def get_instance() -> PeopleService:
|
||||||
|
return people_service
|
||||||
@@ -1,216 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, func
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from utils.config import get_instance as get_config
|
|
||||||
from utils.vsdb import VectorDB, get_instance as get_vsdb
|
|
||||||
from utils.obs import OBS, get_instance as get_obs
|
|
||||||
|
|
||||||
from models.people import People
|
|
||||||
|
|
||||||
people_store = None
|
|
||||||
Base = declarative_base()
|
|
||||||
class PeopleORM(Base):
|
|
||||||
__tablename__ = 'peoples'
|
|
||||||
id = Column(String(36), primary_key=True)
|
|
||||||
name = Column(String(255), index=True)
|
|
||||||
contact = Column(String(255), index=True)
|
|
||||||
gender = Column(String(10))
|
|
||||||
age = Column(Integer)
|
|
||||||
height = Column(Integer)
|
|
||||||
marital_status = Column(String(20))
|
|
||||||
match_requirement = Column(Text)
|
|
||||||
introduction = Column(Text)
|
|
||||||
comments = Column(Text)
|
|
||||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
|
||||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
|
||||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
|
||||||
def parse_from_people(self, people: People):
|
|
||||||
import json
|
|
||||||
self.id = people.id
|
|
||||||
self.name = people.name
|
|
||||||
self.contact = people.contact
|
|
||||||
self.gender = people.gender
|
|
||||||
self.age = people.age
|
|
||||||
self.height = people.height
|
|
||||||
self.marital_status = people.marital_status
|
|
||||||
self.match_requirement = people.match_requirement
|
|
||||||
# 将字典类型字段序列化为JSON字符串存储
|
|
||||||
self.introduction = json.dumps(people.introduction, ensure_ascii=False)
|
|
||||||
self.comments = json.dumps(people.comments, ensure_ascii=False)
|
|
||||||
|
|
||||||
def to_people(self) -> People:
|
|
||||||
import json
|
|
||||||
people = People()
|
|
||||||
people.id = self.id
|
|
||||||
people.name = self.name
|
|
||||||
people.contact = self.contact
|
|
||||||
people.gender = self.gender
|
|
||||||
people.age = self.age
|
|
||||||
people.height = self.height
|
|
||||||
people.marital_status = self.marital_status
|
|
||||||
people.match_requirement = self.match_requirement
|
|
||||||
# 将JSON字符串反序列化为字典类型字段
|
|
||||||
try:
|
|
||||||
people.introduction = json.loads(self.introduction) if self.introduction else {}
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
people.introduction = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
people.comments = json.loads(self.comments) if self.comments else {}
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
people.comments = {}
|
|
||||||
|
|
||||||
return people
|
|
||||||
|
|
||||||
class PeopleStore:
|
|
||||||
def __init__(self):
|
|
||||||
config = get_config()
|
|
||||||
self.sqldb_engine = create_engine(config.get("sqlalchemy", "database_dsn"))
|
|
||||||
Base.metadata.create_all(self.sqldb_engine)
|
|
||||||
self.session_maker = sessionmaker(bind=self.sqldb_engine)
|
|
||||||
self.vsdb: VectorDB = get_vsdb()
|
|
||||||
self.obs: OBS = get_obs()
|
|
||||||
|
|
||||||
def save(self, people: People) -> str:
|
|
||||||
"""
|
|
||||||
保存人物到数据库和向量数据库
|
|
||||||
|
|
||||||
:param people: 人物对象
|
|
||||||
:return: 人物ID
|
|
||||||
"""
|
|
||||||
# 0. 生成 people id
|
|
||||||
people.id = people.id if people.id else uuid.uuid4().hex
|
|
||||||
|
|
||||||
# 1. 转换模型,并保存到 SQL 数据库
|
|
||||||
people_orm = PeopleORM()
|
|
||||||
people_orm.parse_from_people(people)
|
|
||||||
with self.session_maker() as session:
|
|
||||||
session.add(people_orm)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
# 2. 保存到向量数据库
|
|
||||||
people_metadata = people.meta()
|
|
||||||
people_document = people.tonl()
|
|
||||||
logging.info(f"people: {people}")
|
|
||||||
logging.info(f"people_metadata: {people_metadata}")
|
|
||||||
logging.info(f"people_document: {people_document}")
|
|
||||||
results = self.vsdb.insert(metadatas=[people_metadata], documents=[people_document], ids=[people.id])
|
|
||||||
logging.info(f"results: {results}")
|
|
||||||
if len(results) == 0:
|
|
||||||
raise Exception("insert failed")
|
|
||||||
|
|
||||||
# 3. 保存到 OBS 存储
|
|
||||||
people_dict = people.to_dict()
|
|
||||||
people_json = json.dumps(people_dict, ensure_ascii=False)
|
|
||||||
obs_url = self.obs.Put(f"peoples/{people.id}/detail.json", people_json.encode('utf-8'))
|
|
||||||
logging.info(f"obs_url: {obs_url}")
|
|
||||||
|
|
||||||
return people.id
|
|
||||||
|
|
||||||
def update(self, people: People) -> None:
|
|
||||||
raise Exception("update not implemented")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def find(self, people_id: str) -> People:
|
|
||||||
"""
|
|
||||||
根据人物ID查询人物
|
|
||||||
|
|
||||||
:param people_id: 人物ID
|
|
||||||
:return: 人物对象
|
|
||||||
"""
|
|
||||||
with self.session_maker() as session:
|
|
||||||
people_orm = session.query(PeopleORM).filter(
|
|
||||||
PeopleORM.id == people_id,
|
|
||||||
PeopleORM.deleted_at.is_(None)
|
|
||||||
).first()
|
|
||||||
if not people_orm:
|
|
||||||
raise Exception(f"people not found, people_id: {people_id}")
|
|
||||||
return people_orm.to_people()
|
|
||||||
|
|
||||||
def query(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> list[People]:
|
|
||||||
"""
|
|
||||||
根据查询条件查询人物
|
|
||||||
|
|
||||||
:param conds: 查询条件字典
|
|
||||||
:param limit: 分页大小
|
|
||||||
:param offset: 分页偏移量
|
|
||||||
:return: 人物对象列表
|
|
||||||
"""
|
|
||||||
if conds is None:
|
|
||||||
conds = {}
|
|
||||||
with self.session_maker() as session:
|
|
||||||
people_orms = session.query(PeopleORM).filter_by(**conds).filter(
|
|
||||||
PeopleORM.deleted_at.is_(None)
|
|
||||||
).limit(limit).offset(offset).all()
|
|
||||||
return [people_orm.to_people() for people_orm in people_orms]
|
|
||||||
|
|
||||||
def search(self, search: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[People]:
|
|
||||||
"""
|
|
||||||
根据搜索内容和查询条件查询人物
|
|
||||||
|
|
||||||
:param search: 搜索内容
|
|
||||||
:param metadatas: 查询条件字典
|
|
||||||
:param ids: 可选的人物ID列表,用于过滤结果
|
|
||||||
:param top_k: 返回结果数量
|
|
||||||
:return: 人物对象列表
|
|
||||||
"""
|
|
||||||
peoples = []
|
|
||||||
results = self.vsdb.search(document=search, metadatas=metadatas, ids=ids, top_k=top_k)
|
|
||||||
logging.info(f"results: {results}")
|
|
||||||
people_id_list = []
|
|
||||||
for result in results:
|
|
||||||
people_id = result.get('id', '')
|
|
||||||
if not people_id:
|
|
||||||
continue
|
|
||||||
people_id_list.append(people_id)
|
|
||||||
logging.info(f"people_id_list: {people_id_list}")
|
|
||||||
with self.session_maker() as session:
|
|
||||||
people_orms = session.query(PeopleORM).filter(PeopleORM.id.in_(people_id_list)).filter(
|
|
||||||
PeopleORM.deleted_at.is_(None)
|
|
||||||
).all()
|
|
||||||
# 根据 people_id_list 的顺序对查询结果进行排序
|
|
||||||
order_map = {pid: idx for idx, pid in enumerate(people_id_list)}
|
|
||||||
people_orms.sort(key=lambda orm: order_map.get(orm.id, len(order_map)))
|
|
||||||
for people_orm in people_orms:
|
|
||||||
people = people_orm.to_people()
|
|
||||||
peoples.append(people)
|
|
||||||
return peoples
|
|
||||||
|
|
||||||
def delete(self, people_id: str) -> None:
|
|
||||||
"""
|
|
||||||
删除人物从数据库和向量数据库
|
|
||||||
|
|
||||||
:param people_id: 人物ID
|
|
||||||
"""
|
|
||||||
# 1. 从 SQL 数据库软删除人物
|
|
||||||
with self.session_maker() as session:
|
|
||||||
session.query(PeopleORM).filter(
|
|
||||||
PeopleORM.id == people_id,
|
|
||||||
PeopleORM.deleted_at.is_(None)
|
|
||||||
).update({PeopleORM.deleted_at: func.now()}, synchronize_session=False)
|
|
||||||
session.commit()
|
|
||||||
logging.info(f"人物 {people_id} 标记删除 SQL 成功")
|
|
||||||
|
|
||||||
# 2. 删除向量数据库中的记录
|
|
||||||
self.vsdb.delete(ids=[people_id])
|
|
||||||
logging.info(f"人物 {people_id} 删除向量数据库成功")
|
|
||||||
|
|
||||||
# 3. 删除 OBS 存储中的文件
|
|
||||||
keys = self.obs.List(f"peoples/{people_id}/")
|
|
||||||
for key in keys:
|
|
||||||
self.obs.Del(key)
|
|
||||||
logging.info(f"文件 {key} 删除 OBS 成功")
|
|
||||||
|
|
||||||
def init():
|
|
||||||
global people_store
|
|
||||||
people_store = PeopleStore()
|
|
||||||
|
|
||||||
|
|
||||||
def get_instance() -> PeopleStore:
|
|
||||||
return people_store
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
# 导出utils模块中的子模块
|
|
||||||
from . import config, obs, ocr, vsdb, logger
|
|
||||||
__all__ = ['config', 'obs', 'ocr', 'vsdb', 'logger']
|
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import configparser
|
|||||||
|
|
||||||
config = None
|
config = None
|
||||||
|
|
||||||
|
|
||||||
def init(config_file: str):
|
def init(config_file: str):
|
||||||
global config
|
global config
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
|
|||||||
30
src/utils/error.py
Normal file
30
src/utils/error.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
import logging
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
class ErrorCode(Enum):
|
||||||
|
SUCCESS = 0
|
||||||
|
MODEL_ERROR = 1000
|
||||||
|
RLDB_ERROR = 2100
|
||||||
|
|
||||||
|
class error(Protocol):
|
||||||
|
_error_code: int = 0
|
||||||
|
_error_info: str = ""
|
||||||
|
|
||||||
|
def __init__(self, error_code: ErrorCode, error_info: str):
|
||||||
|
self._error_code = int(error_code.value)
|
||||||
|
self._error_info = error_info
|
||||||
|
logging.info(f"errorcode: {type(self._error_code)}")
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def code(self) -> int:
|
||||||
|
return self._error_code
|
||||||
|
@property
|
||||||
|
def info(self) -> str:
|
||||||
|
return self._error_info
|
||||||
|
@property
|
||||||
|
def success(self) -> bool:
|
||||||
|
return self._error_code == 0
|
||||||
167
src/utils/rldb.py
Normal file
167
src/utils/rldb.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
import uuid
|
||||||
|
from sqlalchemy import Column, DateTime, String, create_engine, func
|
||||||
|
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||||
|
from .config import get_instance as get_config
|
||||||
|
|
||||||
|
SQLAlchemyBase = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class RLDBBaseModel(SQLAlchemyBase):
|
||||||
|
__abstract__ = True
|
||||||
|
id = Column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex)
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||||
|
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||||
|
def __str__(self) -> str:
|
||||||
|
# 遍历所有的field,打印出所有的field和value, id 永远排在第一, 三个时间戳排在最后, 其余字段按定义顺序排序
|
||||||
|
fields = [field for field in self.__dict__ if not field.startswith('_')]
|
||||||
|
fields.remove("id") if "id" in fields else None
|
||||||
|
fields.remove("created_at") if "created_at" in fields else None
|
||||||
|
fields.remove("updated_at") if "updated_at" in fields else None
|
||||||
|
fields.remove("deleted_at") if "deleted_at" in fields else None
|
||||||
|
fields = ["id"] + fields + ["created_at", "updated_at", "deleted_at"]
|
||||||
|
field_values = [f"{field}={getattr(self, field)}" for field in fields]
|
||||||
|
return f"{self.__class__.__name__}({', '.join(field_values)})"
|
||||||
|
|
||||||
|
|
||||||
|
class RelationalDB(Protocol):
|
||||||
|
def insert(self, data: RLDBBaseModel) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
def update(self, data: RLDBBaseModel) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
def upsert(self, data: RLDBBaseModel) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
def delete(self, data: RLDBBaseModel) -> str:
|
||||||
|
...
|
||||||
|
|
||||||
|
def get(self,
|
||||||
|
model: type[RLDBBaseModel],
|
||||||
|
id: str,
|
||||||
|
include_deleted: bool = False
|
||||||
|
) -> RLDBBaseModel:
|
||||||
|
...
|
||||||
|
|
||||||
|
def query(self,
|
||||||
|
model: type[RLDBBaseModel],
|
||||||
|
include_deleted: bool = False,
|
||||||
|
limit: int = None,
|
||||||
|
offset: int = None,
|
||||||
|
**filters
|
||||||
|
) -> list[RLDBBaseModel]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class SqlAlchemyDB():
|
||||||
|
def __init__(self, dsn: str = None) -> None:
|
||||||
|
config = get_config()
|
||||||
|
dsn = dsn if dsn else config.get("sqlalchemy", "database_dsn")
|
||||||
|
self.sqldb_engine = create_engine(dsn)
|
||||||
|
SQLAlchemyBase.metadata.create_all(self.sqldb_engine)
|
||||||
|
self.session_maker = sessionmaker(bind=self.sqldb_engine)
|
||||||
|
|
||||||
|
def insert(self, data: RLDBBaseModel) -> str:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
session.add(data)
|
||||||
|
session.commit()
|
||||||
|
return data.id
|
||||||
|
|
||||||
|
def update(self, data: RLDBBaseModel) -> str:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
session.merge(data)
|
||||||
|
session.commit()
|
||||||
|
return data.id
|
||||||
|
|
||||||
|
def upsert(self, data: RLDBBaseModel) -> str:
|
||||||
|
existed = data.id and data.id != "" and self.get(data.__class__, data.id) is not None
|
||||||
|
with self.session_maker() as session:
|
||||||
|
session.merge(data) if existed else session.add(data)
|
||||||
|
session.commit()
|
||||||
|
return data.id
|
||||||
|
|
||||||
|
def delete(self, data: RLDBBaseModel) -> str:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
session.delete(data)
|
||||||
|
session.commit()
|
||||||
|
return data.id
|
||||||
|
|
||||||
|
def get(self,
|
||||||
|
model: type[RLDBBaseModel],
|
||||||
|
id: str,
|
||||||
|
) -> RLDBBaseModel:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
sel = session.query(model)
|
||||||
|
sel = sel.filter(model.id == id)
|
||||||
|
sel = sel.filter(model.deleted_at.is_(None))
|
||||||
|
result = sel.first()
|
||||||
|
return result
|
||||||
|
|
||||||
|
def query(self,
|
||||||
|
model: type[RLDBBaseModel],
|
||||||
|
limit: int = None,
|
||||||
|
offset: int = None,
|
||||||
|
**filters
|
||||||
|
) -> list[RLDBBaseModel]:
|
||||||
|
results: list[RLDBBaseModel] = []
|
||||||
|
with self.session_maker() as session:
|
||||||
|
sel = session.query(model)
|
||||||
|
sel = sel.filter(model.deleted_at.is_(None))
|
||||||
|
if filters:
|
||||||
|
sel = sel.filter_by(**filters)
|
||||||
|
if limit:
|
||||||
|
sel = sel.limit(limit)
|
||||||
|
if offset:
|
||||||
|
sel = sel.offset(offset)
|
||||||
|
results = sel.all()
|
||||||
|
results.sort(key=lambda x: x.created_at, reverse=True)
|
||||||
|
return results
|
||||||
|
|
||||||
|
_rldb_instance: RelationalDB = None
|
||||||
|
|
||||||
|
def init(type: str = "sqlalchemy", dsn: str = None):
|
||||||
|
global _rldb_instance
|
||||||
|
if type == "sqlalchemy":
|
||||||
|
_rldb_instance = SqlAlchemyDB(dsn)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"RelationalDB type {type} not supported")
|
||||||
|
|
||||||
|
def get_instance() -> RelationalDB:
|
||||||
|
global _rldb_instance
|
||||||
|
return _rldb_instance
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
class TestModel(RLDBBaseModel):
|
||||||
|
__tablename__ = "test_model"
|
||||||
|
name = Column(String(36), nullable=True)
|
||||||
|
conf = Column(String(96), nullable=True)
|
||||||
|
init("sqlalchemy", dsn="sqlite:///./demo_storage/rldb.db")
|
||||||
|
db = get_instance()
|
||||||
|
|
||||||
|
test_data = TestModel(name="test", conf="test.config")
|
||||||
|
print(f"before insert: {test_data}")
|
||||||
|
ret = db.insert(test_data)
|
||||||
|
print(f"after insert: {test_data}")
|
||||||
|
|
||||||
|
print(f"before update: {test_data}")
|
||||||
|
test_data.conf = "test.config.new"
|
||||||
|
ret = db.update(test_data)
|
||||||
|
print(f"after update: {test_data}")
|
||||||
|
|
||||||
|
test2_data = TestModel(name="test", conf="test2.config")
|
||||||
|
print(f"before upsert: {test2_data}")
|
||||||
|
ret = db.upsert(test2_data)
|
||||||
|
print(f"after upsert: {test2_data}")
|
||||||
|
|
||||||
|
get_data = db.get(TestModel, test_data.id)
|
||||||
|
print(f"get data: {get_data}")
|
||||||
|
|
||||||
|
query_data = db.query(TestModel, name="test")
|
||||||
|
for data in query_data:
|
||||||
|
print(data.id, data.name, data.conf)
|
||||||
|
print(f"query data: {data}")
|
||||||
|
ret = db.delete(data)
|
||||||
|
print(f"delete data.id: {ret}")
|
||||||
@@ -1,241 +0,0 @@
|
|||||||
import uuid
|
|
||||||
import chromadb
|
|
||||||
import logging
|
|
||||||
from typing import Protocol
|
|
||||||
from chromadb.config import Settings
|
|
||||||
from chromadb.utils import embedding_functions
|
|
||||||
from .config import get_instance as get_config
|
|
||||||
|
|
||||||
class VectorDB(Protocol):
|
|
||||||
|
|
||||||
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
|
|
||||||
"""
|
|
||||||
插入向量到数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vector (list[float]): 向量
|
|
||||||
metadata (dict): 元数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否插入成功
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def delete(self, ids: list[str]) -> bool:
|
|
||||||
"""
|
|
||||||
Delete documents from a collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ids: List of IDs to delete
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Whether deletion was successful
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def query(self, metadatas: dict, ids: list[str], top_k: int = 5) -> list[dict]:
|
|
||||||
"""
|
|
||||||
查询向量数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_vector (list[float]): 查询向量
|
|
||||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[dict]: 查询结果列表
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
|
||||||
"""
|
|
||||||
搜索向量数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
document: Document to search
|
|
||||||
metadatas: Metadata to filter by
|
|
||||||
ids: List of IDs to filter by
|
|
||||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[dict]: 查询结果列表
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
class ChromaDB:
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Initialize the ChromaDB instance.
|
|
||||||
"""
|
|
||||||
config = get_config()
|
|
||||||
self.embedding_functions = embedding_functions.OpenAIEmbeddingFunction(
|
|
||||||
api_base=config.get("voc-engine_embedding", "api_url"),
|
|
||||||
api_key=config.get("voc-engine_embedding", "api_key"),
|
|
||||||
model_name=config.get("voc-engine_embedding", "endpoint"),
|
|
||||||
)
|
|
||||||
persist_directory = config.get("chroma_vsdb", "database_dir", fallback=None)
|
|
||||||
logging.debug(f"persist_directory: {persist_directory}")
|
|
||||||
if persist_directory:
|
|
||||||
self.client = chromadb.PersistentClient(
|
|
||||||
path=persist_directory,
|
|
||||||
settings=Settings(anonymized_telemetry=False)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.client = chromadb.Client(
|
|
||||||
settings=Settings(anonymized_telemetry=False),
|
|
||||||
)
|
|
||||||
self.collection_name = config.get("chroma_vsdb", "collection_name", fallback="peoples")
|
|
||||||
metadata: dict = kwargs.get('collection_metadata', {'hnsw:space': 'cosine'})
|
|
||||||
metadata['hnsw:space'] = metadata.get('hnsw:space', 'cosine')
|
|
||||||
self.collection = self.client.get_or_create_collection(
|
|
||||||
name=self.collection_name,
|
|
||||||
embedding_function=self.embedding_functions,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
|
|
||||||
"""
|
|
||||||
Insert documents into a collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metadatas: List of metadata corresponding to each document
|
|
||||||
documents: List of documents to insert
|
|
||||||
ids: Optional list of unique IDs for each document. If None, IDs will be generated.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[str]: List of inserted IDs
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not ids:
|
|
||||||
# Generate unique IDs if not provided
|
|
||||||
ids = [str(uuid.uuid4()) for _ in range(len(documents))]
|
|
||||||
|
|
||||||
self.collection.add(
|
|
||||||
documents=documents,
|
|
||||||
metadatas=metadatas,
|
|
||||||
ids=ids
|
|
||||||
)
|
|
||||||
|
|
||||||
return ids
|
|
||||||
|
|
||||||
def delete(self, ids: list[str]) -> bool:
|
|
||||||
"""
|
|
||||||
Delete documents from a collection.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ids: List of IDs to delete
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: Whether deletion was successful
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.collection.delete(ids)
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error deleting documents: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def query(self, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
|
||||||
"""
|
|
||||||
查询向量数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metadatas: Metadata to filter by
|
|
||||||
ids: List of IDs to query
|
|
||||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[dict]: 查询结果列表
|
|
||||||
"""
|
|
||||||
results = self.collection.query(
|
|
||||||
query_embeddings=None,
|
|
||||||
query_texts=None,
|
|
||||||
n_results=top_k,
|
|
||||||
where=metadatas,
|
|
||||||
ids=ids,
|
|
||||||
include=["documents", "metadatas", "distances"],
|
|
||||||
)
|
|
||||||
formatted_results = []
|
|
||||||
for i in range(len(results['ids'][0])):
|
|
||||||
result = {
|
|
||||||
'id': results['ids'][0][i],
|
|
||||||
'distance': results['distances'][0][i],
|
|
||||||
'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
|
|
||||||
'document': results['documents'][0][i] if results['documents'][0] else ''
|
|
||||||
}
|
|
||||||
formatted_results.append(result)
|
|
||||||
return formatted_results
|
|
||||||
|
|
||||||
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
|
||||||
"""
|
|
||||||
搜索向量数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
document: Document to search
|
|
||||||
metadatas: Metadata to filter by
|
|
||||||
ids: List of IDs to filter by
|
|
||||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[dict]: 查询结果列表
|
|
||||||
"""
|
|
||||||
results = self.collection.query(
|
|
||||||
query_embeddings=None,
|
|
||||||
query_texts=[document],
|
|
||||||
n_results=top_k,
|
|
||||||
where=metadatas if metadatas else None,
|
|
||||||
ids=ids,
|
|
||||||
include=["documents", "metadatas", "distances"],
|
|
||||||
)
|
|
||||||
formatted_results = []
|
|
||||||
for i in range(len(results['ids'][0])):
|
|
||||||
logging.info(f"result id: {results['ids'][0][i]}, distance: {results['distances'][0][i]}")
|
|
||||||
result = {
|
|
||||||
'id': results['ids'][0][i],
|
|
||||||
'distance': results['distances'][0][i],
|
|
||||||
'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
|
|
||||||
'document': results['documents'][0][i] if results['documents'][0] else ''
|
|
||||||
}
|
|
||||||
formatted_results.append(result)
|
|
||||||
return formatted_results
|
|
||||||
pass
|
|
||||||
|
|
||||||
_vsdb_instance: VectorDB = None
|
|
||||||
|
|
||||||
def init():
|
|
||||||
global _vsdb_instance
|
|
||||||
_vsdb_instance = ChromaDB()
|
|
||||||
|
|
||||||
def get_instance() -> VectorDB:
|
|
||||||
global _vsdb_instance
|
|
||||||
return _vsdb_instance
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import os
|
|
||||||
|
|
||||||
from logger import init as init_logger
|
|
||||||
init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)
|
|
||||||
|
|
||||||
from config import init as init_config
|
|
||||||
config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
|
|
||||||
init_config(config_file)
|
|
||||||
|
|
||||||
init()
|
|
||||||
vsdb = get_instance()
|
|
||||||
metadatas = [
|
|
||||||
{'name': '丽丽'},
|
|
||||||
{'name': '志刚'},
|
|
||||||
{'name': '张三'},
|
|
||||||
{'name': '李四'},
|
|
||||||
]
|
|
||||||
documents = [
|
|
||||||
'姓名: 丽丽, 性别: 女, 年龄: 23, 爱好: 爬山、骑行、攀岩、跳伞、蹦极',
|
|
||||||
"姓名: 志刚, 性别: 男, 年龄: 25, 爱好: 读书、游戏",
|
|
||||||
"姓名: 张三, 性别: 男, 年龄: 30, 爱好: 画画、写作、阅读、逛展、旅行",
|
|
||||||
"姓名: 李四, 性别: 男, 年龄: 35, 爱好: 做饭、美食、旅游"
|
|
||||||
]
|
|
||||||
search_text = '25岁以下的'
|
|
||||||
ids = vsdb.insert(metadatas, documents)
|
|
||||||
results = vsdb.search(search_text, None, None, top_k=4)
|
|
||||||
for result in results:
|
|
||||||
print(result['document'], ' ', result['distance'])
|
|
||||||
@@ -1,20 +1,16 @@
|
|||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
import logging
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
from fastapi import FastAPI, File, UploadFile, Query
|
from fastapi import FastAPI, UploadFile, File, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ai.agent import ExtractPeopleAgent
|
|
||||||
from models.people import People
|
|
||||||
from utils import obs, ocr, vsdb
|
|
||||||
from storage.people_store import get_instance as get_people_store
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from services.people import get_instance as get_people_service
|
||||||
|
from models.people import People
|
||||||
|
from agents.extract_people_agent import ExtractPeopleAgent
|
||||||
|
from utils import obs, ocr
|
||||||
|
|
||||||
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
|
api = FastAPI(title="Single People Management and Searching", version="0.1")
|
||||||
api.add_middleware(
|
api.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"],
|
allow_origins=["*"],
|
||||||
@@ -28,19 +24,21 @@ class BaseResponse(BaseModel):
|
|||||||
error_info: str
|
error_info: str
|
||||||
data: Optional[Any] = None
|
data: Optional[Any] = None
|
||||||
|
|
||||||
|
@api.post("/ping")
|
||||||
|
async def ping():
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
class PostInputRequest(BaseModel):
|
class PostInputRequest(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
@api.post("/input")
|
@api.post("/recognition/input")
|
||||||
async def post_input(request: PostInputRequest):
|
async def post_input(request: PostInputRequest):
|
||||||
extra_agent = ExtractPeopleAgent()
|
people = extract_people(request.text)
|
||||||
people = extra_agent.extract_people_info(request.text)
|
|
||||||
logging.info(f"people: {people}")
|
|
||||||
resp = BaseResponse(error_code=0, error_info="success")
|
resp = BaseResponse(error_code=0, error_info="success")
|
||||||
resp.data = people.to_dict()
|
resp.data = people.to_dict()
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@api.post("/input_image")
|
@api.post("/recognition/image")
|
||||||
async def post_input_image(image: UploadFile = File(...)):
|
async def post_input_image(image: UploadFile = File(...)):
|
||||||
# 实现上传图片的处理
|
# 实现上传图片的处理
|
||||||
# 保存上传的图片文件
|
# 保存上传的图片文件
|
||||||
@@ -65,20 +63,53 @@ async def post_input_image(image: UploadFile = File(...)):
|
|||||||
ocr_result = ocr_util.recognize_image_text(obs_url)
|
ocr_result = ocr_util.recognize_image_text(obs_url)
|
||||||
logging.info(f"ocr_result: {ocr_result}")
|
logging.info(f"ocr_result: {ocr_result}")
|
||||||
|
|
||||||
post_input_request = PostInputRequest(text=ocr_result)
|
people = extract_people(ocr_result, obs_url)
|
||||||
return await post_input(post_input_request)
|
resp = BaseResponse(error_code=0, error_info="success")
|
||||||
|
resp.data = people.to_dict()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def extract_people(text: str, cover_link: str = None) -> People:
|
||||||
|
extra_agent = ExtractPeopleAgent()
|
||||||
|
people = extra_agent.extract_people_info(text)
|
||||||
|
people.cover = cover_link
|
||||||
|
logging.info(f"people: {people}")
|
||||||
|
return people
|
||||||
|
|
||||||
class PostPeopleRequest(BaseModel):
|
class PostPeopleRequest(BaseModel):
|
||||||
people: dict
|
people: dict
|
||||||
|
|
||||||
@api.post("/peoples")
|
@api.post("/people")
|
||||||
async def post_people(post_people_request: PostPeopleRequest):
|
async def post_people(post_people_request: PostPeopleRequest):
|
||||||
logging.debug(f"post_people_request: {post_people_request}")
|
logging.debug(f"post_people_request: {post_people_request}")
|
||||||
people = People.from_dict(post_people_request.people)
|
people = People.from_dict(post_people_request.people)
|
||||||
store = get_people_store()
|
service = get_people_service()
|
||||||
people.id = store.save(people)
|
people.id, error = service.save(people)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
return BaseResponse(error_code=0, error_info="success", data=people.id)
|
return BaseResponse(error_code=0, error_info="success", data=people.id)
|
||||||
|
|
||||||
|
@api.put("/people/{people_id}")
|
||||||
|
async def update_people(people_id: str, post_people_request: PostPeopleRequest):
|
||||||
|
logging.debug(f"post_people_request: {post_people_request}")
|
||||||
|
people = People.from_dict(post_people_request.people)
|
||||||
|
people.id = people_id
|
||||||
|
service = get_people_service()
|
||||||
|
res, error = service.get(people_id)
|
||||||
|
if not error.success or not res:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
_, error = service.save(people)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
@api.delete("/people/{people_id}")
|
||||||
|
async def delete_people(people_id: str):
|
||||||
|
service = get_people_service()
|
||||||
|
error = service.delete(people_id)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
class GetPeopleRequest(BaseModel):
|
class GetPeopleRequest(BaseModel):
|
||||||
query: Optional[str] = None
|
query: Optional[str] = None
|
||||||
conds: Optional[dict] = None
|
conds: Optional[dict] = None
|
||||||
@@ -93,8 +124,6 @@ async def get_peoples(
|
|||||||
marital_status: Optional[str] = Query(None, description="婚姻状态"),
|
marital_status: Optional[str] = Query(None, description="婚姻状态"),
|
||||||
limit: int = Query(10, description="分页大小"),
|
limit: int = Query(10, description="分页大小"),
|
||||||
offset: int = Query(0, description="分页偏移量"),
|
offset: int = Query(0, description="分页偏移量"),
|
||||||
search: Optional[str] = Query(None, description="搜索内容"),
|
|
||||||
top_k: int = Query(5, description="搜索结果数量"),
|
|
||||||
):
|
):
|
||||||
|
|
||||||
# 解析查询参数为字典
|
# 解析查询参数为字典
|
||||||
@@ -110,21 +139,35 @@ async def get_peoples(
|
|||||||
if marital_status:
|
if marital_status:
|
||||||
conds["marital_status"] = marital_status
|
conds["marital_status"] = marital_status
|
||||||
|
|
||||||
logging.info(f"conds: , limit: {limit}, offset: {offset}, search: {search}, top_k: {top_k}")
|
logging.info(f"conds: , limit: {limit}, offset: {offset}")
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
store = get_people_store()
|
service = get_people_service()
|
||||||
if search:
|
results, error = service.list(conds, limit=limit, offset=offset)
|
||||||
results = store.search(search, conds, ids=None, top_k=top_k)
|
logging.info(f"query results: {results}")
|
||||||
logging.info(f"search results: {results}")
|
if not error.success:
|
||||||
else:
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
results = store.query(conds, limit=limit, offset=offset)
|
|
||||||
logging.info(f"query results: {results}")
|
|
||||||
peoples = [people.to_dict() for people in results]
|
peoples = [people.to_dict() for people in results]
|
||||||
return BaseResponse(error_code=0, error_info="success", data=peoples)
|
return BaseResponse(error_code=0, error_info="success", data=peoples)
|
||||||
|
|
||||||
@api.delete("/peoples/{people_id}")
|
|
||||||
async def delete_people(people_id: str):
|
class RemarkRequest(BaseModel):
|
||||||
store = get_people_store()
|
content: str
|
||||||
store.delete(people_id)
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
@api.post("/people/{people_id}/remark")
|
||||||
|
async def post_remark(people_id: str, request: RemarkRequest):
|
||||||
|
service = get_people_service()
|
||||||
|
error = service.save_remark(people_id, request.content)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
|
||||||
|
@api.delete("/people/{people_id}/remark")
|
||||||
|
async def delete_remark(people_id: str):
|
||||||
|
service = get_people_service()
|
||||||
|
error = service.delete_remark(people_id)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../src'))
|
|
||||||
|
|
||||||
from utils.logger import init
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# 初始化日志
|
|
||||||
init()
|
|
||||||
|
|
||||||
# 测试不同级别的日志
|
|
||||||
logging.debug("这是一条调试信息")
|
|
||||||
logging.info("这是一条普通信息")
|
|
||||||
logging.warning("这是一条警告信息")
|
|
||||||
logging.error("这是一条错误信息")
|
|
||||||
logging.critical("这是一条严重错误信息")
|
|
||||||
Reference in New Issue
Block a user