15 Commits

Author SHA1 Message Date
8174c4cfe5 Release v0.1 2025-11-12 23:54:02 +08:00
4c48d11bfa fix: agent recognize data type of age and height for people wrong 2025-11-12 23:52:45 +08:00
e74279ca5e chore: add some files into git ignore 2025-11-12 17:10:12 +08:00
13b70ba424 fix: exception during the conversion of the model to the RLDB model 2025-11-12 16:29:46 +08:00
fae93b5ab8 feat: add api routers for updating people
- put /people/{people_id}
2025-11-12 10:38:37 +08:00
1a092248eb feat: add api routers for recognition
- post     /recognition/input
- post     /recognition/image
2025-11-12 00:51:11 +08:00
0a749d56e8 refactor: add ai agents and implement extract people agent 2025-11-12 00:27:57 +08:00
3d13aa18ae refactor: add obs util and ocr util in project 2025-11-11 23:35:55 +08:00
d179418e7d feat: add api routers for CURD people
- post     /people
- delete   /people/{people_id}
- get      /peoples
2025-11-11 23:32:20 +08:00
496f35a386 refactor: define people service to CURD people 2025-11-11 23:31:26 +08:00
c99b324b81 refactor: add error util in project 2025-11-11 21:50:15 +08:00
7a189eb631 refactor: define people model and relational db model 2025-11-11 21:38:07 +08:00
98cbc754f6 refactor: add relation db access entry
- use sqlalchemy for adapt different relational database
2025-11-11 21:33:56 +08:00
736d8ed193 refactor: use config and logger in service 2025-11-11 21:22:13 +08:00
268eb8be2b feat: basic web service by fastapi and uvicorn 2025-11-11 21:14:07 +08:00
19 changed files with 896 additions and 2277 deletions

View File

@@ -1,18 +1,18 @@
[project]
name = "service"
version = "0.1.0"
version = "0.1"
description = "This project is the web servcie sub-system for if.u projuect"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"alibabacloud-ocr-api20210707>=3.1.3",
"chromadb>=1.1.1",
"fastapi>=0.118.2",
"langchain>=0.3.27",
"langchain-openai>=0.3.35",
"numpy>=2.3.3",
"alibabacloud-tea-openapi>=0.4.1",
"fastapi>=0.118.3",
"langchain==0.3.27",
"langchain-openai==0.3.35",
"pymysql>=1.1.2",
"python-multipart>=0.0.20",
"qiniu>=7.17.0",
"requests>=2.32.5",
"sqlalchemy>=2.0.44",
"uvicorn>=0.38.0",
]

22
src/agents/base_agent.py Normal file
View File

@@ -0,0 +1,22 @@
from langchain_openai import ChatOpenAI
from utils.config import get_instance as get_config
class BaseAgent:
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
config = get_config()
llm_api_url = api_url or config.get("ai", "llm_api_url")
llm_api_key = api_key or config.get("ai", "llm_api_key")
llm_model_name = model_name or config.get("ai", "llm_model_name")
self.llm = ChatOpenAI(
openai_api_key=llm_api_key,
openai_api_base=llm_api_url,
model_name=llm_model_name,
)
pass
class SummaryPeopleAgent(BaseAgent):
def __init__(self):
super().__init__()
pass

View File

@@ -1,25 +1,19 @@
import datetime
import json
import logging
from langchain_openai import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from .base_agent import BaseAgent
from models.people import People
class BaseAgent:
def __init__(self):
self.llm = ChatOpenAI(
openai_api_key="56d82040-85c7-4701-8f87-734985e27909",
openai_api_base="https://ark.cn-beijing.volces.com/api/v3",
model_name="ep-20250722161445-n9lfq"
)
pass
class ExtractPeopleAgent(BaseAgent):
def __init__(self):
super().__init__()
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
super().__init__(api_url, api_key, model_name)
self.prompt = ChatPromptTemplate.from_messages([
(
"system",
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')}"
"你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n"
"姓名 name\n"
"性别 gender\n"
@@ -28,6 +22,8 @@ class ExtractPeopleAgent(BaseAgent):
"婚姻状况 marital_status\n"
"择偶要求 match_requirement\n"
"以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n"
"其中,'年龄 age''身高(cm) height' 必须是一个整数,不能是一个字符串;\n"
"并且,'性别 gender' 根据识别结果,必须从 男,女,未知 三选一填写。\n"
"除了上述基本信息,还有一个字段\n"
"个人介绍 introduction\n"
"其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n"
@@ -42,13 +38,15 @@ class ExtractPeopleAgent(BaseAgent):
response = self.llm.invoke(prompt)
logging.info(f"llm response: {response.content}")
try:
return People.from_dict(json.loads(response.content))
people = People.from_dict(json.loads(response.content))
err = people.validate()
if not err.success:
raise ValueError(f"Failed to validate people info: {err.info}")
return people
except json.JSONDecodeError:
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
return None
pass
class SummaryPeopleAgent(BaseAgent):
def __init__(self):
super().__init__()
except ValueError as e:
logging.error(f"Failed to validate people info: {e}")
return None
pass

View File

@@ -1,52 +0,0 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-09-27
import logging
from ai.agent import ExtractPeopleAgent
from utils import ocr, vsdb, obs
from models.people import People
from fastapi import FastAPI
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
class App:
def __init__(self):
self.extract_people_agent = ExtractPeopleAgent()
self.ocr = ocr.get_instance()
self.vedb = vsdb.get_instance(db_type='chromadb')
self.obs = obs.get_instance()
def run(self):
pass
class InputForPeople:
image: bytes
text: str
def input_to_people(self, input: InputForPeople) -> People:
if not input.image and not input.text:
return None
if input.image:
content = self.ocr.recognize_image_text(input.image)
else:
content = input.text
print(content)
people = self._trans_text_to_people(content)
return people
def _trans_text_to_people(self, text: str) -> People:
if not text:
return None
person = self.extract_people_agent.extract_people_info(text)
print(person)
return person
def create_people(self, people: People) -> bool:
if not people:
return False
try:
people.save()
except Exception as e:
logging.error(f"保存人物到数据库失败: {e}")
return False
return True

View File

@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-09-27
import logging
import os
import argparse
from venv import logger
import uvicorn
from app.api import api
from utils import obs, ocr, vsdb, logger, config
from storage import people_store
from services import people as people_service
from utils import config, logger, obs, ocr, rldb
from web.api import api
# 主函数
def main():
@@ -15,14 +14,20 @@ def main():
parser = argparse.ArgumentParser(description='IF.u 服务')
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
args = parser.parse_args()
config.init(args.config)
logger.init()
obs.init()
rldb.init()
ocr.init()
vsdb.init()
people_store.init()
obs.init()
people_service.init()
conf = config.get_instance()
host = conf.get('web_service', 'server_host', fallback='127.0.0.1')
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
port = conf.getint('web_service', 'server_port', fallback=8099)
uvicorn.run(api, host=host, port=port)

View File

@@ -1,8 +1,29 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-09-30
import json
import logging
from typing import Dict
from sqlalchemy import Column, Integer, String, Text, DateTime, func
from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class PeopleRLDBModel(RLDBBaseModel):
__tablename__ = 'peoples'
id = Column(String(36), primary_key=True)
name = Column(String(255), index=True)
contact = Column(String(255), index=True)
gender = Column(String(10))
age = Column(Integer)
height = Column(Integer)
marital_status = Column(String(20))
match_requirement = Column(Text)
introduction = Column(Text)
comments = Column(Text)
cover = Column(String(255), nullable=True)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class People:
@@ -18,15 +39,7 @@ class People:
age: int
# 身高(cm)
height: int
# 体重(kg)
# weight: int
# 婚姻状况
# [
# "未婚(single)",
# "已婚(married)",
# "离异(divorced)",
# "丧偶(widowed)"
# ]
marital_status: str
# 择偶要求
match_requirement: str
@@ -34,7 +47,8 @@ class People:
introduction: Dict[str, str]
# 总结评价
comments: Dict[str, str]
# 封面
cover: str = None
def __init__(self, **kwargs):
# 初始化所有属性从kwargs中获取值如果不存在则设置默认值
@@ -48,9 +62,14 @@ class People:
self.match_requirement = kwargs.get('match_requirement', '') if kwargs.get('match_requirement', '') is not None else ''
self.introduction = kwargs.get('introduction', {}) if kwargs.get('introduction', {}) is not None else {}
self.comments = kwargs.get('comments', {}) if kwargs.get('comments', {}) is not None else {}
self.cover = kwargs.get('cover', None) if kwargs.get('cover', None) is not None else None
def __str__(self) -> str:
return self.tonl()
# 返回对象的字符串表示,包含所有属性
return (f"People(id={self.id}, name={self.name}, contact={self.contact}, gender={self.gender}, "
f"age={self.age}, height={self.height}, marital_status={self.marital_status}, "
f"match_requirement={self.match_requirement}, introduction={self.introduction}, "
f"comments={self.comments}, cover={self.cover})")
@classmethod
def from_dict(cls, data: dict):
@@ -65,6 +84,23 @@ class People:
del data['deleted_at']
return cls(**data)
@classmethod
def from_rldb_model(cls, data: PeopleRLDBModel):
# 将关系数据库模型转换为对象
return cls(
id=data.id,
name=data.name,
contact=data.contact,
gender=data.gender,
age=data.age,
height=data.height,
marital_status=data.marital_status,
match_requirement=data.match_requirement,
introduction=json.loads(data.introduction) if data.introduction else {},
comments=json.loads(data.comments) if data.comments else {},
cover=data.cover,
)
def to_dict(self) -> dict:
# 将对象转换为字典格式
return {
@@ -78,44 +114,37 @@ class People:
'match_requirement': self.match_requirement,
'introduction': self.introduction,
'comments': self.comments,
'cover': self.cover,
}
def meta(self) -> Dict[str, str]:
# 返回对象的元数据信息
meta = {
'id': self.id,
'name': self.name,
'gender': self.gender,
'age': self.age,
'height': self.height,
'marital_status': self.marital_status,
}
logging.info(f"people meta: {meta}")
return meta
def to_rldb_model(self) -> PeopleRLDBModel:
# 将对象转换为关系数据库模型
return PeopleRLDBModel(
id=self.id,
name=self.name,
contact=self.contact,
gender=self.gender,
age=self.age,
height=self.height,
marital_status=self.marital_status,
match_requirement=self.match_requirement,
introduction=json.dumps(self.introduction, ensure_ascii=False),
comments=json.dumps(self.comments, ensure_ascii=False),
cover=self.cover,
)
def tonl(self) -> str:
# 将对象转换为文档格式的字符串
doc = []
doc.append(f"姓名: {self.name}")
doc.append(f"性别: {self.gender}")
if self.age:
doc.append(f"年龄: {self.age}")
if self.height:
doc.append(f"身高: {self.height}cm")
if self.marital_status:
doc.append(f"婚姻状况: {self.marital_status}")
if self.match_requirement:
doc.append(f"择偶要求: {self.match_requirement}")
if self.introduction:
doc.append("个人介绍:")
for key, value in self.introduction.items():
doc.append(f" - {key}: {value}")
if self.comments:
doc.append("总结评价:")
for key, value in self.comments.items():
doc.append(f" - {key}: {value}")
return '\n'.join(doc)
def comment(self, comment: Dict[str, str]):
# 添加总结评价
self.comments.update(comment)
def validate(self) -> error:
err = error(ErrorCode.SUCCESS, "")
if not self.name:
logging.error("Name is required")
err = error(ErrorCode.MODEL_ERROR, "Name is required")
if not self.gender in ['', '', '未知']:
logging.error("Gender must be '', '', or '未知'")
err = error(ErrorCode.MODEL_ERROR, "Gender must be '', '', or '未知'")
if not isinstance(self.age, int) or self.age <= 0:
logging.error("Age must be an integer and greater than 0")
err = error(ErrorCode.MODEL_ERROR, "Age must be an integer and greater than 0")
if not isinstance(self.height, int) or self.height <= 0:
logging.error("Height must be an integer and greater than 0")
err = error(ErrorCode.MODEL_ERROR, "Height must be an integer and greater than 0")
return err

78
src/services/people.py Normal file
View File

@@ -0,0 +1,78 @@
import uuid
from models.people import People, PeopleRLDBModel
from utils.error import ErrorCode, error
from utils import rldb
class PeopleService:
def __init__(self):
self.rldb = rldb.get_instance()
def save(self, people: People) -> (str, error):
"""
保存人物到数据库和向量数据库
:param people: 人物对象
:return: 人物ID
"""
# 0. 生成 people id
people.id = people.id if people.id else uuid.uuid4().hex
# 1. 转换模型,并保存到 SQL 数据库
people_orm = people.to_rldb_model()
self.rldb.upsert(people_orm)
return people.id, error(ErrorCode.SUCCESS, "")
def delete(self, people_id: str) -> error:
"""
删除人物从数据库和向量数据库
:param people_id: 人物ID
:return: 错误对象
"""
people_orm = self.rldb.get(PeopleRLDBModel, people_id)
if not people_orm:
return error(ErrorCode.RLDB_ERROR, f"people {people_id} not found")
self.rldb.delete(people_orm)
return error(ErrorCode.SUCCESS, "")
def get(self, people_id: str) -> (People, error):
"""
从数据库获取人物
:param people_id: 人物ID
:return: 人物对象
"""
people_orm = self.rldb.get(PeopleRLDBModel, people_id)
if not people_orm:
return None, error(ErrorCode.MODEL_ERROR, f"people {people_id} not found")
return People.from_rldb_model(people_orm), error(ErrorCode.SUCCESS, "")
def list(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> (list[People], error):
"""
从数据库列出人物
:param conds: 查询条件字典
:param limit: 分页大小
:param offset: 分页偏移量
:return: 人物对象列表
"""
people_orms = self.rldb.query(PeopleRLDBModel, **conds)
peoples = [People.from_rldb_model(people_orm) for people_orm in people_orms]
return peoples, error(ErrorCode.SUCCESS, "")
people_service = None
def init():
global people_service
people_service = PeopleService()
def get_instance() -> PeopleService:
return people_service

View File

@@ -1,216 +0,0 @@
import json
import logging
import uuid
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from utils.config import get_instance as get_config
from utils.vsdb import VectorDB, get_instance as get_vsdb
from utils.obs import OBS, get_instance as get_obs
from models.people import People
people_store = None
Base = declarative_base()
class PeopleORM(Base):
__tablename__ = 'peoples'
id = Column(String(36), primary_key=True)
name = Column(String(255), index=True)
contact = Column(String(255), index=True)
gender = Column(String(10))
age = Column(Integer)
height = Column(Integer)
marital_status = Column(String(20))
match_requirement = Column(Text)
introduction = Column(Text)
comments = Column(Text)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
def parse_from_people(self, people: People):
import json
self.id = people.id
self.name = people.name
self.contact = people.contact
self.gender = people.gender
self.age = people.age
self.height = people.height
self.marital_status = people.marital_status
self.match_requirement = people.match_requirement
# 将字典类型字段序列化为JSON字符串存储
self.introduction = json.dumps(people.introduction, ensure_ascii=False)
self.comments = json.dumps(people.comments, ensure_ascii=False)
def to_people(self) -> People:
import json
people = People()
people.id = self.id
people.name = self.name
people.contact = self.contact
people.gender = self.gender
people.age = self.age
people.height = self.height
people.marital_status = self.marital_status
people.match_requirement = self.match_requirement
# 将JSON字符串反序列化为字典类型字段
try:
people.introduction = json.loads(self.introduction) if self.introduction else {}
except (json.JSONDecodeError, TypeError):
people.introduction = {}
try:
people.comments = json.loads(self.comments) if self.comments else {}
except (json.JSONDecodeError, TypeError):
people.comments = {}
return people
class PeopleStore:
def __init__(self):
config = get_config()
self.sqldb_engine = create_engine(config.get("sqlalchemy", "database_dsn"))
Base.metadata.create_all(self.sqldb_engine)
self.session_maker = sessionmaker(bind=self.sqldb_engine)
self.vsdb: VectorDB = get_vsdb()
self.obs: OBS = get_obs()
def save(self, people: People) -> str:
"""
保存人物到数据库和向量数据库
:param people: 人物对象
:return: 人物ID
"""
# 0. 生成 people id
people.id = people.id if people.id else uuid.uuid4().hex
# 1. 转换模型,并保存到 SQL 数据库
people_orm = PeopleORM()
people_orm.parse_from_people(people)
with self.session_maker() as session:
session.add(people_orm)
session.commit()
# 2. 保存到向量数据库
people_metadata = people.meta()
people_document = people.tonl()
logging.info(f"people: {people}")
logging.info(f"people_metadata: {people_metadata}")
logging.info(f"people_document: {people_document}")
results = self.vsdb.insert(metadatas=[people_metadata], documents=[people_document], ids=[people.id])
logging.info(f"results: {results}")
if len(results) == 0:
raise Exception("insert failed")
# 3. 保存到 OBS 存储
people_dict = people.to_dict()
people_json = json.dumps(people_dict, ensure_ascii=False)
obs_url = self.obs.Put(f"peoples/{people.id}/detail.json", people_json.encode('utf-8'))
logging.info(f"obs_url: {obs_url}")
return people.id
def update(self, people: People) -> None:
raise Exception("update not implemented")
return None
def find(self, people_id: str) -> People:
"""
根据人物ID查询人物
:param people_id: 人物ID
:return: 人物对象
"""
with self.session_maker() as session:
people_orm = session.query(PeopleORM).filter(
PeopleORM.id == people_id,
PeopleORM.deleted_at.is_(None)
).first()
if not people_orm:
raise Exception(f"people not found, people_id: {people_id}")
return people_orm.to_people()
def query(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> list[People]:
"""
根据查询条件查询人物
:param conds: 查询条件字典
:param limit: 分页大小
:param offset: 分页偏移量
:return: 人物对象列表
"""
if conds is None:
conds = {}
with self.session_maker() as session:
people_orms = session.query(PeopleORM).filter_by(**conds).filter(
PeopleORM.deleted_at.is_(None)
).limit(limit).offset(offset).all()
return [people_orm.to_people() for people_orm in people_orms]
def search(self, search: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[People]:
"""
根据搜索内容和查询条件查询人物
:param search: 搜索内容
:param metadatas: 查询条件字典
:param ids: 可选的人物ID列表用于过滤结果
:param top_k: 返回结果数量
:return: 人物对象列表
"""
peoples = []
results = self.vsdb.search(document=search, metadatas=metadatas, ids=ids, top_k=top_k)
logging.info(f"results: {results}")
people_id_list = []
for result in results:
people_id = result.get('id', '')
if not people_id:
continue
people_id_list.append(people_id)
logging.info(f"people_id_list: {people_id_list}")
with self.session_maker() as session:
people_orms = session.query(PeopleORM).filter(PeopleORM.id.in_(people_id_list)).filter(
PeopleORM.deleted_at.is_(None)
).all()
# 根据 people_id_list 的顺序对查询结果进行排序
order_map = {pid: idx for idx, pid in enumerate(people_id_list)}
people_orms.sort(key=lambda orm: order_map.get(orm.id, len(order_map)))
for people_orm in people_orms:
people = people_orm.to_people()
peoples.append(people)
return peoples
def delete(self, people_id: str) -> None:
"""
删除人物从数据库和向量数据库
:param people_id: 人物ID
"""
# 1. 从 SQL 数据库软删除人物
with self.session_maker() as session:
session.query(PeopleORM).filter(
PeopleORM.id == people_id,
PeopleORM.deleted_at.is_(None)
).update({PeopleORM.deleted_at: func.now()}, synchronize_session=False)
session.commit()
logging.info(f"人物 {people_id} 标记删除 SQL 成功")
# 2. 删除向量数据库中的记录
self.vsdb.delete(ids=[people_id])
logging.info(f"人物 {people_id} 删除向量数据库成功")
# 3. 删除 OBS 存储中的文件
keys = self.obs.List(f"peoples/{people_id}/")
for key in keys:
self.obs.Del(key)
logging.info(f"文件 {key} 删除 OBS 成功")
def init():
global people_store
people_store = PeopleStore()
def get_instance() -> PeopleStore:
return people_store

View File

@@ -1,3 +0,0 @@
# 导出utils模块中的子模块
from . import config, obs, ocr, vsdb, logger
__all__ = ['config', 'obs', 'ocr', 'vsdb', 'logger']

View File

@@ -2,7 +2,6 @@ import configparser
config = None
def init(config_file: str):
global config
config = configparser.ConfigParser()

30
src/utils/error.py Normal file
View File

@@ -0,0 +1,30 @@
from enum import Enum
import logging
from typing import Protocol
class ErrorCode(Enum):
SUCCESS = 0
MODEL_ERROR = 1000
RLDB_ERROR = 2100
class error(Protocol):
_error_code: int = 0
_error_info: str = ""
def __init__(self, error_code: ErrorCode, error_info: str):
self._error_code = int(error_code.value)
self._error_info = error_info
logging.info(f"errorcode: {type(self._error_code)}")
def __str__(self) -> str:
return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"
@property
def code(self) -> int:
return self._error_code
@property
def info(self) -> str:
return self._error_info
@property
def success(self) -> bool:
return self._error_code == 0

167
src/utils/rldb.py Normal file
View File

@@ -0,0 +1,167 @@
from typing import Protocol
import uuid
from sqlalchemy import Column, DateTime, String, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker
from .config import get_instance as get_config
SQLAlchemyBase = declarative_base()
class RLDBBaseModel(SQLAlchemyBase):
__abstract__ = True
id = Column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
def __str__(self) -> str:
# 遍历所有的field打印出所有的field和value, id 永远排在第一, 三个时间戳排在最后, 其余字段按定义顺序排序
fields = [field for field in self.__dict__ if not field.startswith('_')]
fields.remove("id") if "id" in fields else None
fields.remove("created_at") if "created_at" in fields else None
fields.remove("updated_at") if "updated_at" in fields else None
fields.remove("deleted_at") if "deleted_at" in fields else None
fields = ["id"] + fields + ["created_at", "updated_at", "deleted_at"]
field_values = [f"{field}={getattr(self, field)}" for field in fields]
return f"{self.__class__.__name__}({', '.join(field_values)})"
class RelationalDB(Protocol):
def insert(self, data: RLDBBaseModel) -> str:
...
def update(self, data: RLDBBaseModel) -> str:
...
def upsert(self, data: RLDBBaseModel) -> str:
...
def delete(self, data: RLDBBaseModel) -> str:
...
def get(self,
model: type[RLDBBaseModel],
id: str,
include_deleted: bool = False
) -> RLDBBaseModel:
...
def query(self,
model: type[RLDBBaseModel],
include_deleted: bool = False,
limit: int = None,
offset: int = None,
**filters
) -> list[RLDBBaseModel]:
...
class SqlAlchemyDB():
def __init__(self, dsn: str = None) -> None:
config = get_config()
dsn = dsn if dsn else config.get("sqlalchemy", "database_dsn")
self.sqldb_engine = create_engine(dsn)
SQLAlchemyBase.metadata.create_all(self.sqldb_engine)
self.session_maker = sessionmaker(bind=self.sqldb_engine)
def insert(self, data: RLDBBaseModel) -> str:
with self.session_maker() as session:
session.add(data)
session.commit()
return data.id
def update(self, data: RLDBBaseModel) -> str:
with self.session_maker() as session:
session.merge(data)
session.commit()
return data.id
def upsert(self, data: RLDBBaseModel) -> str:
existed = data.id and data.id != "" and self.get(data.__class__, data.id) is not None
with self.session_maker() as session:
session.merge(data) if existed else session.add(data)
session.commit()
return data.id
def delete(self, data: RLDBBaseModel) -> str:
with self.session_maker() as session:
session.delete(data)
session.commit()
return data.id
def get(self,
model: type[RLDBBaseModel],
id: str,
) -> RLDBBaseModel:
with self.session_maker() as session:
sel = session.query(model)
sel = sel.filter(model.id == id)
sel = sel.filter(model.deleted_at.is_(None))
result = sel.first()
return result
def query(self,
model: type[RLDBBaseModel],
limit: int = None,
offset: int = None,
**filters
) -> list[RLDBBaseModel]:
results: list[RLDBBaseModel] = []
with self.session_maker() as session:
sel = session.query(model)
sel = sel.filter(model.deleted_at.is_(None))
if filters:
sel = sel.filter_by(**filters)
if limit:
sel = sel.limit(limit)
if offset:
sel = sel.offset(offset)
results = sel.all()
results.sort(key=lambda x: x.created_at, reverse=True)
return results
_rldb_instance: RelationalDB = None
def init(type: str = "sqlalchemy", dsn: str = None):
global _rldb_instance
if type == "sqlalchemy":
_rldb_instance = SqlAlchemyDB(dsn)
else:
raise ValueError(f"RelationalDB type {type} not supported")
def get_instance() -> RelationalDB:
global _rldb_instance
return _rldb_instance
if __name__ == "__main__":
class TestModel(RLDBBaseModel):
__tablename__ = "test_model"
name = Column(String(36), nullable=True)
conf = Column(String(96), nullable=True)
init("sqlalchemy", dsn="sqlite:///./demo_storage/rldb.db")
db = get_instance()
test_data = TestModel(name="test", conf="test.config")
print(f"before insert: {test_data}")
ret = db.insert(test_data)
print(f"after insert: {test_data}")
print(f"before update: {test_data}")
test_data.conf = "test.config.new"
ret = db.update(test_data)
print(f"after update: {test_data}")
test2_data = TestModel(name="test", conf="test2.config")
print(f"before upsert: {test2_data}")
ret = db.upsert(test2_data)
print(f"after upsert: {test2_data}")
get_data = db.get(TestModel, test_data.id)
print(f"get data: {get_data}")
query_data = db.query(TestModel, name="test")
for data in query_data:
print(data.id, data.name, data.conf)
print(f"query data: {data}")
ret = db.delete(data)
print(f"delete data.id: {ret}")

View File

@@ -1,241 +0,0 @@
import uuid
import chromadb
import logging
from typing import Protocol
from chromadb.config import Settings
from chromadb.utils import embedding_functions
from .config import get_instance as get_config
class VectorDB(Protocol):
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
"""
插入向量到数据库
Args:
vector (list[float]): 向量
metadata (dict): 元数据
Returns:
bool: 是否插入成功
"""
...
def delete(self, ids: list[str]) -> bool:
"""
Delete documents from a collection.
Args:
ids: List of IDs to delete
Returns:
bool: Whether deletion was successful
"""
...
def query(self, metadatas: dict, ids: list[str], top_k: int = 5) -> list[dict]:
"""
查询向量数据库
Args:
query_vector (list[float]): 查询向量
top_k (int, optional): 返回Top K结果. Defaults to 5.
Returns:
list[dict]: 查询结果列表
"""
...
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
"""
搜索向量数据库
Args:
document: Document to search
metadatas: Metadata to filter by
ids: List of IDs to filter by
top_k (int, optional): 返回Top K结果. Defaults to 5.
Returns:
list[dict]: 查询结果列表
"""
...
class ChromaDB:
def __init__(self, **kwargs):
"""
Initialize the ChromaDB instance.
"""
config = get_config()
self.embedding_functions = embedding_functions.OpenAIEmbeddingFunction(
api_base=config.get("voc-engine_embedding", "api_url"),
api_key=config.get("voc-engine_embedding", "api_key"),
model_name=config.get("voc-engine_embedding", "endpoint"),
)
persist_directory = config.get("chroma_vsdb", "database_dir", fallback=None)
logging.debug(f"persist_directory: {persist_directory}")
if persist_directory:
self.client = chromadb.PersistentClient(
path=persist_directory,
settings=Settings(anonymized_telemetry=False)
)
else:
self.client = chromadb.Client(
settings=Settings(anonymized_telemetry=False),
)
self.collection_name = config.get("chroma_vsdb", "collection_name", fallback="peoples")
metadata: dict = kwargs.get('collection_metadata', {'hnsw:space': 'cosine'})
metadata['hnsw:space'] = metadata.get('hnsw:space', 'cosine')
self.collection = self.client.get_or_create_collection(
name=self.collection_name,
embedding_function=self.embedding_functions,
metadata=metadata,
)
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
"""
Insert documents into a collection.
Args:
metadatas: List of metadata corresponding to each document
documents: List of documents to insert
ids: Optional list of unique IDs for each document. If None, IDs will be generated.
Returns:
list[str]: List of inserted IDs
"""
if not ids:
# Generate unique IDs if not provided
ids = [str(uuid.uuid4()) for _ in range(len(documents))]
self.collection.add(
documents=documents,
metadatas=metadatas,
ids=ids
)
return ids
def delete(self, ids: list[str]) -> bool:
"""
Delete documents from a collection.
Args:
ids: List of IDs to delete
Returns:
bool: Whether deletion was successful
"""
try:
self.collection.delete(ids)
return True
except Exception as e:
print(f"Error deleting documents: {e}")
return False
def query(self, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
"""
查询向量数据库
Args:
metadatas: Metadata to filter by
ids: List of IDs to query
top_k (int, optional): 返回Top K结果. Defaults to 5.
Returns:
list[dict]: 查询结果列表
"""
results = self.collection.query(
query_embeddings=None,
query_texts=None,
n_results=top_k,
where=metadatas,
ids=ids,
include=["documents", "metadatas", "distances"],
)
formatted_results = []
for i in range(len(results['ids'][0])):
result = {
'id': results['ids'][0][i],
'distance': results['distances'][0][i],
'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
'document': results['documents'][0][i] if results['documents'][0] else ''
}
formatted_results.append(result)
return formatted_results
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
"""
搜索向量数据库
Args:
document: Document to search
metadatas: Metadata to filter by
ids: List of IDs to filter by
top_k (int, optional): 返回Top K结果. Defaults to 5.
Returns:
list[dict]: 查询结果列表
"""
results = self.collection.query(
query_embeddings=None,
query_texts=[document],
n_results=top_k,
where=metadatas if metadatas else None,
ids=ids,
include=["documents", "metadatas", "distances"],
)
formatted_results = []
for i in range(len(results['ids'][0])):
logging.info(f"result id: {results['ids'][0][i]}, distance: {results['distances'][0][i]}")
result = {
'id': results['ids'][0][i],
'distance': results['distances'][0][i],
'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
'document': results['documents'][0][i] if results['documents'][0] else ''
}
formatted_results.append(result)
return formatted_results
pass
_vsdb_instance: VectorDB = None
def init():
global _vsdb_instance
_vsdb_instance = ChromaDB()
def get_instance() -> VectorDB:
global _vsdb_instance
return _vsdb_instance
if __name__ == "__main__":
import os
from logger import init as init_logger
init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)
from config import init as init_config
config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
init_config(config_file)
init()
vsdb = get_instance()
metadatas = [
{'name': '丽丽'},
{'name': '志刚'},
{'name': '张三'},
{'name': '李四'},
]
documents = [
'姓名: 丽丽, 性别: 女, 年龄: 23, 爱好: 爬山、骑行、攀岩、跳伞、蹦极',
"姓名: 志刚, 性别: 男, 年龄: 25, 爱好: 读书、游戏",
"姓名: 张三, 性别: 男, 年龄: 30, 爱好: 画画、写作、阅读、逛展、旅行",
"姓名: 李四, 性别: 男, 年龄: 35, 爱好: 做饭、美食、旅游"
]
search_text = '25岁以下的'
ids = vsdb.insert(metadatas, documents)
results = vsdb.search(search_text, None, None, top_k=4)
for result in results:
print(result['document'], ' ', result['distance'])

View File

@@ -1,20 +1,16 @@
import json
import logging
import os
import uuid
import logging
from typing import Any, Optional
from fastapi import FastAPI, File, UploadFile, Query
from fastapi import FastAPI, UploadFile, File, Query
from pydantic import BaseModel
from ai.agent import ExtractPeopleAgent
from models.people import People
from utils import obs, ocr, vsdb
from storage.people_store import get_instance as get_people_store
from fastapi.middleware.cors import CORSMiddleware
from services.people import get_instance as get_people_service
from models.people import People
from agents.extract_people_agent import ExtractPeopleAgent
from utils import obs, ocr
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
api = FastAPI(title="Single People Management and Searching", version="0.1")
api.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -28,19 +24,21 @@ class BaseResponse(BaseModel):
error_info: str
data: Optional[Any] = None
@api.post("/ping")
async def ping():
return BaseResponse(error_code=0, error_info="success")
class PostInputRequest(BaseModel):
text: str
@api.post("/input")
@api.post("/recognition/input")
async def post_input(request: PostInputRequest):
extra_agent = ExtractPeopleAgent()
people = extra_agent.extract_people_info(request.text)
logging.info(f"people: {people}")
people = extract_people(request.text)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
@api.post("/input_image")
@api.post("/recognition/image")
async def post_input_image(image: UploadFile = File(...)):
# 实现上传图片的处理
# 保存上传的图片文件
@@ -65,20 +63,53 @@ async def post_input_image(image: UploadFile = File(...)):
ocr_result = ocr_util.recognize_image_text(obs_url)
logging.info(f"ocr_result: {ocr_result}")
post_input_request = PostInputRequest(text=ocr_result)
return await post_input(post_input_request)
people = extract_people(ocr_result, obs_url)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
def extract_people(text: str, cover_link: str = None) -> People:
extra_agent = ExtractPeopleAgent()
people = extra_agent.extract_people_info(text)
people.cover = cover_link
logging.info(f"people: {people}")
return people
class PostPeopleRequest(BaseModel):
people: dict
@api.post("/peoples")
@api.post("/people")
async def post_people(post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
store = get_people_store()
people.id = store.save(people)
service = get_people_service()
people.id, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=people.id)
@api.put("/people/{people_id}")
async def update_people(people_id: str, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.id = people_id
service = get_people_service()
res, error = service.get(people_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
_, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@api.delete("/people/{people_id}")
async def delete_people(people_id: str):
service = get_people_service()
error = service.delete(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
class GetPeopleRequest(BaseModel):
query: Optional[str] = None
conds: Optional[dict] = None
@@ -93,8 +124,6 @@ async def get_peoples(
marital_status: Optional[str] = Query(None, description="婚姻状态"),
limit: int = Query(10, description="分页大小"),
offset: int = Query(0, description="分页偏移量"),
search: Optional[str] = Query(None, description="搜索内容"),
top_k: int = Query(5, description="搜索结果数量"),
):
# 解析查询参数为字典
@@ -110,21 +139,14 @@ async def get_peoples(
if marital_status:
conds["marital_status"] = marital_status
logging.info(f"conds: , limit: {limit}, offset: {offset}, search: {search}, top_k: {top_k}")
logging.info(f"conds: , limit: {limit}, offset: {offset}")
results = []
store = get_people_store()
if search:
results = store.search(search, conds, ids=None, top_k=top_k)
logging.info(f"search results: {results}")
else:
results = store.query(conds, limit=limit, offset=offset)
service = get_people_service()
results, error = service.list(conds, limit=limit, offset=offset)
logging.info(f"query results: {results}")
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
peoples = [people.to_dict() for people in results]
return BaseResponse(error_code=0, error_info="success", data=peoples)
@api.delete("/peoples/{people_id}")
async def delete_people(people_id: str):
store = get_people_store()
store.delete(people_id)
return BaseResponse(error_code=0, error_info="success")

View File

@@ -1,16 +0,0 @@
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '../src'))
from utils.logger import init
import logging
# 初始化日志
init()
# 测试不同级别的日志
logging.debug("这是一条调试信息")
logging.info("这是一条普通信息")
logging.warning("这是一条警告信息")
logging.error("这是一条错误信息")
logging.critical("这是一条严重错误信息")

2053
uv.lock generated

File diff suppressed because it is too large Load Diff