Release v0.1

This commit is contained in:
2025-11-12 23:54:02 +08:00
20 changed files with 2885 additions and 0 deletions

8
.gitignore vendored
View File

@@ -205,3 +205,11 @@ cython_debug/
marimo/_static/
marimo/_lsp/
__marimo__/
# Other
configuration/
logs/
.DS_Store
# Test
localstore/

18
pyproject.toml Normal file
View File

@@ -0,0 +1,18 @@
[project]
name = "service"
version = "0.1"
description = "This project is the web service sub-system for the if.u project"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"alibabacloud-ocr-api20210707>=3.1.3",
"alibabacloud-tea-openapi>=0.4.1",
"fastapi>=0.118.3",
"langchain==0.3.27",
"langchain-openai==0.3.35",
"pymysql>=1.1.2",
"python-multipart>=0.0.20",
"qiniu>=7.17.0",
"sqlalchemy>=2.0.44",
"uvicorn>=0.38.0",
]

0
src/agents/__init__.py Normal file
View File

22
src/agents/base_agent.py Normal file
View File

@@ -0,0 +1,22 @@
from langchain_openai import ChatOpenAI
from utils.config import get_instance as get_config
class BaseAgent:
    """Common base for LLM-backed agents.

    Builds a ``ChatOpenAI`` client and exposes it as ``self.llm``.
    """

    def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
        """Create the chat client.

        Args:
            api_url: Value passed as ``openai_api_base``. When falsy, the
                ``llm_api_url`` option of the ``[ai]`` config section is used.
            api_key: Value passed as ``openai_api_key``. When falsy, the
                ``llm_api_key`` option of the ``[ai]`` config section is used.
            model_name: Value passed as ``model_name``. When falsy, the
                ``llm_model_name`` option of the ``[ai]`` config section is used.
        """
        settings = get_config()

        def resolve(explicit, option):
            # An explicit argument wins; the config is consulted only otherwise.
            if explicit:
                return explicit
            return settings.get("ai", option)

        endpoint = resolve(api_url, "llm_api_url")
        secret = resolve(api_key, "llm_api_key")
        chosen_model = resolve(model_name, "llm_model_name")
        self.llm = ChatOpenAI(
            openai_api_key=secret,
            openai_api_base=endpoint,
            model_name=chosen_model,
        )
class SummaryPeopleAgent(BaseAgent):
    """Agent placeholder that only sets up the LLM client from configuration."""

    def __init__(self):
        # No overrides: every LLM setting comes from the configuration.
        super().__init__()

View File

@@ -0,0 +1,52 @@
import datetime
import json
import logging
from langchain.prompts import ChatPromptTemplate
from .base_agent import BaseAgent
from models.people import People
class ExtractPeopleAgent(BaseAgent):
    """Agent that turns a free-text self-introduction into a ``People`` object."""

    def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
        """Set up the LLM client and the extraction prompt.

        Args:
            api_url: Optional LLM endpoint override, forwarded to ``BaseAgent``.
            api_key: Optional LLM API key override, forwarded to ``BaseAgent``.
            model_name: Optional model name override, forwarded to ``BaseAgent``.
        """
        super().__init__(api_url, api_key, model_name)
        # The current date is rendered into the system prompt once, when the
        # agent is constructed.
        self.prompt = ChatPromptTemplate.from_messages([
            (
                "system",
                f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')}"
                "你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n"
                "姓名 name\n"
                "性别 gender\n"
                "年龄 age\n"
                "身高(cm) height\n"
                "婚姻状况 marital_status\n"
                "择偶要求 match_requirement\n"
                "以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n"
                "其中,'年龄 age' 和 '身高(cm) height' 必须是一个整数,不能是一个字符串;\n"
                "并且,'性别 gender' 根据识别结果,必须从 男,女,未知 三选一填写。\n"
                "除了上述基本信息,还有一个字段\n"
                "个人介绍 introduction\n"
                "其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n"
                "个人介绍的字典的 key 需要使用提炼好的中文。\n"
            ),
            ("human", "{input}")
        ])

    def extract_people_info(self, text: str) -> People | None:
        """Extract personal information from ``text``.

        Args:
            text: Free-form description of a person.

        Returns:
            The validated ``People`` object, or ``None`` when the LLM answer
            is not a JSON object or fails ``People.validate``.
        """
        prompt = self.prompt.format_prompt(input=text)
        response = self.llm.invoke(prompt)
        logging.info(f"llm response: {response.content}")

        try:
            payload = json.loads(response.content)
        except (json.JSONDecodeError, TypeError):
            # TypeError covers a non-text ``content`` value.
            logging.error(f"Failed to parse JSON from LLM response: {response.content}")
            return None

        # Valid JSON that is not an object (list, string, number) cannot be
        # expanded into People keyword arguments.
        if not isinstance(payload, dict):
            logging.error(f"LLM response is not a JSON object: {response.content}")
            return None

        people = People.from_dict(payload)
        err = people.validate()
        if not err.success:
            logging.error(f"Failed to validate people info: {err.info}")
            return None
        return people

35
src/main.py Normal file
View File

@@ -0,0 +1,35 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-09-27
import os
import argparse
import uvicorn
from services import people as people_service
from utils import config, logger, obs, ocr, rldb
from web.api import api
# Service entry point
def main():
    """Parse the command line, initialise every subsystem and start the web server."""
    src_dir = os.path.dirname(os.path.abspath(__file__))
    default_config = os.path.join(src_dir, '../configuration/test_conf.ini')

    arg_parser = argparse.ArgumentParser(description='IF.u 服务')
    arg_parser.add_argument('--config', type=str, default=default_config, help='配置文件路径')
    cli_args = arg_parser.parse_args()

    # Configuration must be loaded first: the other subsystems read it.
    config.init(cli_args.config)
    for component in (logger, rldb, ocr, obs, people_service):
        component.init()

    settings = config.get_instance()
    bind_host = settings.get('web_service', 'server_host', fallback='0.0.0.0')
    bind_port = settings.getint('web_service', 'server_port', fallback=8099)
    uvicorn.run(api, host=bind_host, port=bind_port)


if __name__ == "__main__":
    main()

0
src/models/__init__.py Normal file
View File

150
src/models/people.py Normal file
View File

@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-09-30
import json
import logging
from typing import Dict
from sqlalchemy import Column, Integer, String, Text, DateTime, func
from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class PeopleRLDBModel(RLDBBaseModel):
    """SQLAlchemy model for the ``peoples`` table.

    ``People.to_rldb_model`` writes ``introduction`` and ``comments`` as JSON
    text, and ``People.from_rldb_model`` parses them back.
    """
    __tablename__ = 'peoples'
    # NOTE(review): id and the three timestamp columns are also declared on
    # RLDBBaseModel; this id is redeclared without the uuid default.
    id = Column(String(36), primary_key=True)
    name = Column(String(255), index=True)
    contact = Column(String(255), index=True)
    gender = Column(String(10))
    age = Column(Integer)
    # Height in centimetres (per the People field comment).
    height = Column(Integer)
    marital_status = Column(String(20))
    match_requirement = Column(Text)
    # JSON-encoded dictionaries.
    introduction = Column(Text)
    comments = Column(Text)
    cover = Column(String(255), nullable=True)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    # Filtered on with ``IS NULL`` by the SqlAlchemyDB read methods.
    deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class People:
    """Plain domain object describing one person.

    Converts to and from plain dictionaries and the ``PeopleRLDBModel``
    database row, and can validate its own basic fields.
    """

    # Database ID
    id: str
    # Name
    name: str
    # Contact
    contact: str
    # Gender
    gender: str
    # Age
    age: int
    # Height (cm)
    height: int
    # Marital status
    marital_status: str
    # Requirements for a partner
    match_requirement: str
    # Personal introduction
    introduction: Dict[str, str]
    # Summary / evaluation
    comments: Dict[str, str]
    # Cover image
    cover: str = None

    # Database-managed columns that must never reach the constructor.
    _TIMESTAMP_FIELDS = ("created_at", "updated_at", "deleted_at")
    # The three gender values the extraction prompt asks the LLM to choose from.
    _VALID_GENDERS = ("男", "女", "未知")

    def __init__(self, **kwargs):
        """Initialise every attribute from ``kwargs``.

        A missing key and an explicit ``None`` are treated alike and replaced
        by the field default (``''``, ``0``, ``{}`` or ``None`` for cover).
        """
        def pick(key, default):
            value = kwargs.get(key)
            return default if value is None else value

        self.id = pick('id', '')
        self.name = pick('name', '')
        self.contact = pick('contact', '')
        self.gender = pick('gender', '')
        self.age = pick('age', 0)
        self.height = pick('height', 0)
        self.marital_status = pick('marital_status', '')
        self.match_requirement = pick('match_requirement', '')
        # A fresh dict per instance, so defaults are never shared.
        self.introduction = pick('introduction', {})
        self.comments = pick('comments', {})
        self.cover = kwargs.get('cover')

    def __str__(self) -> str:
        """Return a readable representation listing every attribute."""
        return (f"People(id={self.id}, name={self.name}, contact={self.contact}, gender={self.gender}, "
                f"age={self.age}, height={self.height}, marital_status={self.marital_status}, "
                f"match_requirement={self.match_requirement}, introduction={self.introduction}, "
                f"comments={self.comments}, cover={self.cover})")

    @classmethod
    def from_dict(cls, data: dict):
        """Build a ``People`` from a plain dictionary.

        The timestamp keys are ignored. The caller's dictionary is left
        untouched: a filtered copy is passed to the constructor.
        """
        fields = {key: value for key, value in data.items()
                  if key not in cls._TIMESTAMP_FIELDS}
        return cls(**fields)

    @classmethod
    def from_rldb_model(cls, data: "PeopleRLDBModel"):
        """Convert a database row into a ``People``, decoding the JSON text columns."""
        return cls(
            id=data.id,
            name=data.name,
            contact=data.contact,
            gender=data.gender,
            age=data.age,
            height=data.height,
            marital_status=data.marital_status,
            match_requirement=data.match_requirement,
            introduction=json.loads(data.introduction) if data.introduction else {},
            comments=json.loads(data.comments) if data.comments else {},
            cover=data.cover,
        )

    def to_dict(self) -> dict:
        """Return the object as a plain dictionary."""
        return {
            'id': self.id,
            'name': self.name,
            'contact': self.contact,
            'gender': self.gender,
            'age': self.age,
            'height': self.height,
            'marital_status': self.marital_status,
            'match_requirement': self.match_requirement,
            'introduction': self.introduction,
            'comments': self.comments,
            'cover': self.cover,
        }

    def to_rldb_model(self) -> "PeopleRLDBModel":
        """Convert the object into a database row, JSON-encoding the dict fields."""
        return PeopleRLDBModel(
            id=self.id,
            name=self.name,
            contact=self.contact,
            gender=self.gender,
            age=self.age,
            height=self.height,
            marital_status=self.marital_status,
            match_requirement=self.match_requirement,
            introduction=json.dumps(self.introduction, ensure_ascii=False),
            comments=json.dumps(self.comments, ensure_ascii=False),
            cover=self.cover,
        )

    def validate(self) -> "error":
        """Check the basic fields.

        Returns:
            A success ``error`` when everything is valid; otherwise a
            ``MODEL_ERROR`` whose info lists every failed check joined by
            ``"; "``.
        """
        problems = []
        if not self.name:
            problems.append("Name is required")
        if self.gender not in self._VALID_GENDERS:
            problems.append("Gender must be '男', '女', or '未知'")
        if not isinstance(self.age, int) or self.age <= 0:
            problems.append("Age must be an integer and greater than 0")
        if not isinstance(self.height, int) or self.height <= 0:
            problems.append("Height must be an integer and greater than 0")

        if not problems:
            return error(ErrorCode.SUCCESS, "")
        for problem in problems:
            logging.error(problem)
        return error(ErrorCode.MODEL_ERROR, "; ".join(problems))

0
src/services/__init__.py Normal file
View File

78
src/services/people.py Normal file
View File

@@ -0,0 +1,78 @@
import uuid
from models.people import People, PeopleRLDBModel
from utils.error import ErrorCode, error
from utils import rldb
class PeopleService:
    """CRUD operations for ``People`` on top of the relational database."""

    def __init__(self):
        # Uses the shared database instance; rldb.init() must already have run.
        self.rldb = rldb.get_instance()

    def save(self, people: People) -> tuple[str, error]:
        """Insert or update a person in the relational database.

        Args:
            people: The person to store. When ``people.id`` is empty a new
                uuid4 hex id is assigned to it.

        Returns:
            ``(people_id, error)``; the error is always a success here.
        """
        # 0. Make sure the record has an id.
        people.id = people.id if people.id else uuid.uuid4().hex
        # 1. Convert to the ORM model and store it.
        people_orm = people.to_rldb_model()
        self.rldb.upsert(people_orm)
        return people.id, error(ErrorCode.SUCCESS, "")

    def delete(self, people_id: str) -> error:
        """Delete a person from the relational database.

        Args:
            people_id: Id of the person to delete.

        Returns:
            A success error, or ``RLDB_ERROR`` when the id is unknown.
        """
        people_orm = self.rldb.get(PeopleRLDBModel, people_id)
        if not people_orm:
            return error(ErrorCode.RLDB_ERROR, f"people {people_id} not found")
        self.rldb.delete(people_orm)
        return error(ErrorCode.SUCCESS, "")

    def get(self, people_id: str) -> tuple[People, error]:
        """Load one person.

        Args:
            people_id: Id of the person to load.

        Returns:
            ``(people, error)``; ``(None, MODEL_ERROR)`` when the id is unknown.
        """
        people_orm = self.rldb.get(PeopleRLDBModel, people_id)
        if not people_orm:
            return None, error(ErrorCode.MODEL_ERROR, f"people {people_id} not found")
        return People.from_rldb_model(people_orm), error(ErrorCode.SUCCESS, "")

    def list(self, conds: dict = None, limit: int = 10, offset: int = 0) -> tuple[list[People], error]:
        """List people matching ``conds``.

        Args:
            conds: Equality filters as ``{column_name: value}``; ``None`` or
                empty means no filter.
            limit: Page size.
            offset: Page offset.

        Returns:
            ``(people_list, error)``.
        """
        # Forward the paging arguments to the database query.
        people_orms = self.rldb.query(PeopleRLDBModel, limit=limit, offset=offset, **(conds or {}))
        peoples = [People.from_rldb_model(people_orm) for people_orm in people_orms]
        return peoples, error(ErrorCode.SUCCESS, "")
# Process-wide PeopleService singleton; stays None until init() runs.
people_service = None


def init():
    """Create the shared PeopleService and publish it as the module singleton."""
    global people_service
    instance = PeopleService()
    people_service = instance


def get_instance() -> PeopleService:
    """Return the shared PeopleService (None when init() has not been called)."""
    return people_service

0
src/utils/__init__.py Normal file
View File

24
src/utils/config.py Normal file
View File

@@ -0,0 +1,24 @@
import configparser
# Process-wide parsed configuration; None until init() succeeds.
config = None


def init(config_file: str):
    """Load ``config_file`` and publish it as the module-wide configuration.

    Args:
        config_file: Path of the INI file to read.

    Raises:
        FileNotFoundError: If the file is missing or unreadable.
            ``ConfigParser.read`` skips such files silently, which would
            otherwise surface later as a missing-section error. The
            previously loaded configuration, if any, is kept.
    """
    global config
    parser = configparser.ConfigParser()
    loaded_files = parser.read(config_file)
    if not loaded_files:
        raise FileNotFoundError(f"config file not found or unreadable: {config_file}")
    config = parser


def get_instance() -> configparser.ConfigParser:
    """Return the loaded configuration, or None when init() has not run."""
    return config
if __name__ == "__main__":
    # Manual check: load the test configuration and dump every option.
    import os

    # Path is resolved relative to this file.
    here = os.path.dirname(__file__)
    init(os.path.join(here, "../../configuration/test_conf.ini"))
    parser = get_instance()
    section_names = parser.sections()
    print(section_names)
    for name in section_names:
        option_names = parser.options(name)
        print(option_names)
        for key in option_names:
            print(f"{name}.{key}={parser.get(name, key)}")

30
src/utils/error.py Normal file
View File

@@ -0,0 +1,30 @@
from enum import Enum
import logging
from typing import Protocol
class ErrorCode(Enum):
    """Numeric error codes returned by the service layers."""
    SUCCESS = 0
    MODEL_ERROR = 1000
    RLDB_ERROR = 2100


class error:
    """Result value pairing a numeric error code with a message.

    This is a concrete class that callers instantiate directly, so it is a
    plain class rather than a ``typing.Protocol``.
    """
    _error_code: int = 0
    _error_info: str = ""

    def __init__(self, error_code: ErrorCode, error_info: str):
        """Store the numeric value of ``error_code`` and the message.

        Args:
            error_code: Member of ``ErrorCode``.
            error_info: Human-readable detail; empty for success.
        """
        self._error_code = int(error_code.value)
        self._error_info = error_info

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"

    @property
    def code(self) -> int:
        """Numeric error code."""
        return self._error_code

    @property
    def info(self) -> str:
        """Error message."""
        return self._error_info

    @property
    def success(self) -> bool:
        """True when the code is 0 (``ErrorCode.SUCCESS``)."""
        return self._error_code == 0

91
src/utils/logger.py Normal file
View File

@@ -0,0 +1,91 @@
import logging
import os
from datetime import datetime
from .config import get_instance as get_config
# ANSI colour escape codes
class Colors:
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'
    RESET = '\033[0m'  # reset colour


# Console handler that colours each record according to its level
class ColoredConsoleHandler(logging.StreamHandler):
    # Built once instead of on every emit() call.
    _LEVEL_COLORS = {
        logging.DEBUG: Colors.CYAN,
        logging.INFO: Colors.GREEN,
        logging.WARNING: Colors.YELLOW,
        logging.ERROR: Colors.RED,
        logging.CRITICAL: Colors.MAGENTA,
    }

    def emit(self, record):
        """Write the formatted record wrapped in the colour of its level.

        Unknown levels are printed in white. Failures while formatting or
        writing go to ``handleError`` rather than propagating into the code
        that issued the log call, as the standard ``StreamHandler`` does.
        """
        try:
            color = self._LEVEL_COLORS.get(record.levelno, Colors.WHITE)
            message = self.format(record)
            self.stream.write(f"{color}{message}{Colors.RESET}\n")
            self.flush()
        except RecursionError:
            raise
        except Exception:
            self.handleError(record)
def init(log_dir: str = None, log_file: str = None, log_level=None, console_log_level=None):
    """Configure the root logger with a coloured console handler and a daily file handler.

    Each setting is resolved in this order: explicit argument, the ``[log]``
    section of the loaded configuration, built-in default. The function
    therefore also works when the configuration has not been initialised.

    Args:
        log_dir: Directory for log files (default ``"logs"``).
        log_file: Log file name prefix (default ``"if.u.service"``).
        log_level: Level of the file handler (default ``logging.INFO``).
        console_log_level: Level of the console handler (default ``logging.DEBUG``).
    """
    config = get_config()

    def setting(option, override, default):
        if override is not None:
            return override
        if config is None:
            return default
        return config.get("log", option, fallback=default)

    def as_level(value):
        # Config values arrive as text: accept "info", "INFO" or "20".
        if isinstance(value, str):
            text = value.strip()
            return int(text) if text.isdigit() else text.upper()
        return value

    log_dir = setting("log_dir", log_dir, "logs")
    log_file = setting("log_file", log_file, "if.u.service")
    log_level = as_level(setting("log_level", log_level, logging.INFO))
    console_log_level = as_level(setting("console_log_level", console_log_level, logging.DEBUG))

    # Create the log directory if needed.
    os.makedirs(log_dir, exist_ok=True)

    # Log line format
    log_format = "[%(asctime)s.%(msecs)03d][%(filename)s:%(lineno)d][%(levelname)s] %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    formatter = logging.Formatter(log_format, datefmt=date_format)

    # The root logger lets everything through; the handlers do the filtering.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.NOTSET)

    # Drop existing handlers so repeated calls do not duplicate output.
    root_logger.handlers.clear()

    # Coloured console handler
    console_handler = ColoredConsoleHandler()
    console_handler.setFormatter(formatter)
    console_handler.setLevel(console_log_level)
    root_logger.addHandler(console_handler)

    # File handler; the file name carries the date at initialisation time.
    log_filename = os.path.join(log_dir, f"{log_file}_{datetime.now().strftime('%Y%m%d')}.log")
    file_handler = logging.FileHandler(log_filename, encoding='utf-8')
    file_handler.setFormatter(formatter)
    file_handler.setLevel(log_level)
    root_logger.addHandler(file_handler)

    # Single-letter level names keep the log lines compact.
    logging.addLevelName(logging.DEBUG, "D")
    logging.addLevelName(logging.INFO, "I")
    logging.addLevelName(logging.WARNING, "W")
    logging.addLevelName(logging.ERROR, "E")
    logging.addLevelName(logging.CRITICAL, "C")


if __name__ == "__main__":
    # Manual smoke test: explicit settings, one record per level.
    init(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)
    logging.debug("debug log")
    logging.info("info log")
    logging.warning("warning log")
    logging.error("error log")
    logging.critical("critical log")

220
src/utils/obs.py Normal file
View File

@@ -0,0 +1,220 @@
import json
import logging
from typing import Protocol
import qiniu
import requests
from .config import get_instance as get_config
class OBS(Protocol):
    """Structural interface of an object-storage backend."""

    def Put(self, obs_path: str, content: bytes) -> str:
        """
        Upload a file to OBS.
        Args:
            obs_path (str): target path in OBS
            content (bytes): file content
        Returns:
            str: OBS file path
        """
        ...

    def Get(self, obs_path: str) -> bytes:
        """
        Download a file from OBS.
        Args:
            obs_path (str): OBS file path
        Returns:
            bytes: file content
        """
        ...

    def List(self, obs_path: str) -> list:
        """
        List all files under an OBS directory.
        Args:
            obs_path (str): OBS directory path
        Returns:
            list: list of all file paths
        """
        ...

    def Del(self, obs_path: str) -> bool:
        """
        Delete a file from OBS.
        Args:
            obs_path (str): OBS file path
        Returns:
            bool: whether the deletion succeeded
        """
        ...

    def Link(self, obs_path: str) -> str:
        """
        Get the link of an OBS file.
        Args:
            obs_path (str): OBS file path
        Returns:
            str: OBS file link
        """
        ...
class Koodo:
    """Object-storage backend built on the qiniu SDK.

    Every object key is ``prefix_path + obs_path``; public links are
    ``outer_domain + "/" + key``. Settings come from the ``[koodo_obs]``
    config section.
    """

    # Seconds to wait for the download endpoint before giving up.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        config = get_config()
        self.bucket_name = config.get('koodo_obs', 'bucket_name')
        self.prefix_path = config.get('koodo_obs', 'prefix_path')
        self.access_key = config.get('koodo_obs', 'access_key')
        self.secret_key = config.get('koodo_obs', 'secret_key')
        self.outer_domain = config.get('koodo_obs', 'outer_domain')
        self.auth = qiniu.Auth(self.access_key, self.secret_key)
        self.bucket = qiniu.BucketManager(self.auth)

    def Put(self, obs_path: str, content: bytes) -> str:
        """
        Upload a file to OBS.
        Args:
            obs_path (str): target path in OBS (without the configured prefix)
            content (bytes): file content
        Returns:
            str: public link of the uploaded file, or "" when the upload failed
        """
        full_path = f"{self.prefix_path}{obs_path}"
        token = self.auth.upload_token(self.bucket_name, full_path)
        ret, info = qiniu.put_data(token, full_path, content)
        logging.debug(f"文件 {obs_path} 上传到 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
        if ret is None or info.status_code != 200:
            logging.error(f"文件 {obs_path} 上传失败, 错误信息: {info.text_body}")
            return ""
        logging.info(f"文件 {obs_path} 上传成功, OBS路径: {full_path}")
        return f"{self.outer_domain}/{full_path}"

    def Get(self, obs_path: str) -> bytes:
        """
        Download a file from OBS.
        Args:
            obs_path (str): OBS file path (without the configured prefix)
        Returns:
            bytes: file content, or None when the download failed
        """
        link = f"{self.outer_domain}/{self.prefix_path}{obs_path}"
        try:
            resp = requests.get(link, timeout=self.REQUEST_TIMEOUT)
        except requests.RequestException as exc:
            logging.error(f"从 OBS {obs_path} 下载文件失败, 错误信息: {exc}")
            return None
        # Decide by HTTP status: the body of a successful download is the
        # file itself and is usually not JSON.
        if resp.status_code != 200:
            logging.error(f"从 OBS {obs_path} 下载文件失败, 错误信息: {resp.text}")
            return None
        return resp.content

    def List(self, prefix: str = "") -> list[str]:
        """
        List all files under an OBS directory.
        Args:
            prefix (str, optional): OBS directory path prefix. Defaults to "".
        Returns:
            list: file paths with the full prefix stripped; on a listing
            error the keys collected so far are returned
        """
        full_prefix = f"{self.prefix_path}{prefix}"
        keys = []
        marker = None
        while True:
            ret, eof, info = self.bucket.list(self.bucket_name, full_prefix, marker)
            if ret is None:
                logging.error(f"列出 OBS 目录 {prefix} 失败, 错误信息: {info.text_body}")
                break
            # Strip only the leading prefix, not later occurrences in the key.
            keys.extend(item['key'].removeprefix(full_prefix) for item in ret.get('items', []))
            # A marker in the response means more pages follow.
            marker = ret.get('marker')
            if eof or not marker:
                break
        return keys

    def Del(self, obs_path: str) -> bool:
        """
        Delete a file from OBS.
        Args:
            obs_path (str): OBS file path (without the configured prefix)
        Returns:
            bool: whether the deletion succeeded
        """
        ret, info = self.bucket.delete(self.bucket_name, f"{self.prefix_path}{obs_path}")
        logging.debug(f"文件 {obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
        if ret is None or info.status_code != 200:
            logging.error(f"文件 {obs_path} 删除 OBS 失败, 错误信息: {info.text_body}")
            return False
        logging.info(f"文件 {obs_path} 删除 OBS 成功")
        return True

    def Link(self, obs_path: str) -> str:
        """
        Get the public link of an OBS file.
        Args:
            obs_path (str): OBS file path (without the configured prefix)
        Returns:
            str: OBS file link
        """
        return f"{self.outer_domain}/{self.prefix_path}{obs_path}"
# Process-wide object-storage client; None until init() runs.
_obs_instance: OBS = None


def init():
    """Create the Koodo client and publish it as the module singleton."""
    global _obs_instance
    client = Koodo()
    _obs_instance = client


def get_instance() -> OBS:
    """Return the shared object-storage client (None before init())."""
    return _obs_instance
if __name__ == "__main__":
    # Manual test script. The relative imports below mean it must be run as
    # a module of its package, not as a plain file.
    import os
    from .logger import init as init_logger
    init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)
    from .config import init as init_config, get_instance as get_config
    config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
    init_config(config_file)
    init()
    obs = get_instance()
    # Sample: upload a local test image and print its link (disabled)
    # obs_path = "test111.PNG"
    # local_path = os.path.join(os.path.dirname(__file__), "../../test/9e03ad5eb8b1a51e752fb79cd8f98169.PNG")
    # content = None
    # with open(local_path, "rb") as f:
    # content = f.read()
    # if content is None:
    # print(f"文件 {local_path} 读取失败")
    # exit(1)
    # obs.Put(obs_path, content)
    # link = obs.Link(obs_path)
    # print(f"文件 {obs_path} 链接: {link}")
    # List every file under the configured prefix
    keys = obs.List("")
    print(f"OBS 目录下的所有文件: {keys}")
    # WARNING: this deletes every listed object under the configured prefix.
    for key in keys:
        link = obs.Del(key)
        print(f"文件 {key} 删除 OBS 成功: {link}")

105
src/utils/ocr.py Normal file
View File

@@ -0,0 +1,105 @@
import json
import logging
from typing import Protocol
from alibabacloud_ocr_api20210707.client import Client as OcrClient
from alibabacloud_tea_openapi import models as open_api_models
from alibabacloud_ocr_api20210707 import models as ocr_models
from alibabacloud_tea_util import models as util_models
from alibabacloud_tea_util.client import Client as UtilClient
from .config import get_instance as get_config
class OCR(Protocol):
    """Structural interface of an OCR backend."""

    def recognize_image_text(self, image_link: str) -> str:
        """
        Extract text from an image.
        Args:
            image_link (str): image link
        Returns:
            str: the extracted text
        """
        ...
class AliOCR:
    """OCR backend using the Alibaba Cloud OCR client.

    Credentials and endpoint come from the ``[ali_ocr]`` config section.
    """

    def __init__(self):
        config = get_config()
        self.access_key = config.get("ali_ocr", "access_key")
        self.secret_key = config.get("ali_ocr", "secret_key")
        self.endpoint = config.get("ali_ocr", "endpoint")
        self.client = self._create_client()

    def _create_client(self):
        """Build the OCR client from the stored credentials and endpoint."""
        config = open_api_models.Config(
            access_key_id=self.access_key,
            access_key_secret=self.secret_key,
        )
        config.endpoint = self.endpoint
        return OcrClient(config)

    def recognize_image_text(self, image_link: str) -> str:
        """
        Extract text from an image link with a single general-OCR request.
        Args:
            image_link (str): image link
        Returns:
            str: the extracted text, or "" when the response carries none
        Raises:
            Exception: any error raised by the OCR client is logged and re-raised
        """
        recognize_general_request = ocr_models.RecognizeGeneralRequest(url=image_link)
        runtime = util_models.RuntimeOptions()
        try:
            # One request only; its response is reused below.
            response = self.client.recognize_general_with_options(recognize_general_request, runtime)
        except Exception as exc:
            # Not every exception carries the SDK's ``message``/``data``
            # attributes, so read them defensively.
            logging.error(getattr(exc, "message", str(exc)))
            details = getattr(exc, "data", None)
            if hasattr(details, "get"):
                # Diagnostic URL
                logging.error(details.get("Recommend"))
            raise

        if response.status_code == 200 and response.body:
            result_data = response.body.data
            logging.debug(result_data)
            result_body = json.loads(result_data)
            if result_body and 'content' in result_body:
                return result_body['content']
        return ""
# Module-wide OCR instance
_ocr_instance = None


def init():
    """Create the AliOCR backend and store it as the module-wide instance."""
    global _ocr_instance
    backend = AliOCR()
    _ocr_instance = backend


def get_instance() -> OCR:
    """Return the module-wide OCR instance.

    Raises:
        RuntimeError: If init() has not been called yet.
    """
    if _ocr_instance is not None:
        return _ocr_instance
    raise RuntimeError("OCR模块未初始化请先调用init()函数")
if __name__ == "__main__":
    # Manual test script: run OCR on a sample image link and print the text.
    import os
    # NOTE(review): these are absolute imports, while the module header uses
    # the relative ``.config`` — confirm how this script is meant to be run.
    from logger import init as init_logger
    # NOTE(review): logging is initialised before the config is loaded —
    # confirm logger.init tolerates that.
    init_logger(console_log_level=logging.DEBUG)
    from config import init as init_config
    config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
    init_config(config_file)
    init()
    ocr = get_instance()
    text = ocr.recognize_image_text(image_link="https://pic.mamamiyear.site/test.if.u/test111.PNG")
    print(text)

167
src/utils/rldb.py Normal file
View File

@@ -0,0 +1,167 @@
from typing import Protocol
import uuid
from sqlalchemy import Column, DateTime, String, create_engine, func
from sqlalchemy.orm import declarative_base, sessionmaker
from .config import get_instance as get_config
# Declarative base shared by every ORM model of the service.
SQLAlchemyBase = declarative_base()


class RLDBBaseModel(SQLAlchemyBase):
    """Abstract base model providing an id and three timestamp columns."""
    __abstract__ = True
    id = Column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
    deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)

    def __str__(self) -> str:
        """Render ``ClassName(field=value, ...)``.

        ``id`` always comes first and the three timestamps last; the other
        public instance attributes follow the order of ``self.__dict__``.
        """
        leading = ["id"]
        trailing = ["created_at", "updated_at", "deleted_at"]
        pinned = set(leading) | set(trailing)
        middle = [
            attr for attr in self.__dict__
            if not attr.startswith('_') and attr not in pinned
        ]
        rendered = []
        for attr in leading + middle + trailing:
            rendered.append(f"{attr}={getattr(self, attr)}")
        return f"{self.__class__.__name__}({', '.join(rendered)})"
class RelationalDB(Protocol):
    """Structural interface of the relational storage backend.

    The write methods are annotated to return ``str``; the SqlAlchemyDB
    implementation in this file returns the record id.
    """

    def insert(self, data: RLDBBaseModel) -> str:
        """Store a new record."""
        ...

    def update(self, data: RLDBBaseModel) -> str:
        """Update an existing record."""
        ...

    def upsert(self, data: RLDBBaseModel) -> str:
        """Insert the record, or update it when it already exists."""
        ...

    def delete(self, data: RLDBBaseModel) -> str:
        """Delete the record."""
        ...

    def get(self,
            model: type[RLDBBaseModel],
            id: str,
            include_deleted: bool = False
            ) -> RLDBBaseModel:
        """Fetch one record of ``model`` by id."""
        ...

    def query(self,
              model: type[RLDBBaseModel],
              include_deleted: bool = False,
              limit: int = None,
              offset: int = None,
              **filters
              ) -> list[RLDBBaseModel]:
        """Fetch records of ``model`` matching the keyword ``filters``, optionally paged."""
        ...
class SqlAlchemyDB():
    """``RelationalDB`` implementation backed by SQLAlchemy.

    Each operation runs in its own short-lived session.
    """

    def __init__(self, dsn: str = None) -> None:
        """Create the engine, the tables and the session factory.

        Args:
            dsn: Database URL. When falsy, ``database_dsn`` from the
                ``[sqlalchemy]`` config section is used.
        """
        if not dsn:
            dsn = get_config().get("sqlalchemy", "database_dsn")
        # pool_pre_ping tests pooled connections before use, so connections
        # dropped by the server while idle are replaced transparently.
        self.sqldb_engine = create_engine(dsn, pool_pre_ping=True)
        SQLAlchemyBase.metadata.create_all(self.sqldb_engine)
        self.session_maker = sessionmaker(bind=self.sqldb_engine)

    def insert(self, data: RLDBBaseModel) -> str:
        """Insert ``data`` and return its id."""
        with self.session_maker() as session:
            session.add(data)
            session.commit()
            # Read the id while the session is still open.
            return data.id

    def update(self, data: RLDBBaseModel) -> str:
        """Merge ``data`` into the stored row and return its id."""
        with self.session_maker() as session:
            session.merge(data)
            session.commit()
            return data.id

    def upsert(self, data: RLDBBaseModel) -> str:
        """Update the row with ``data.id`` when it exists, otherwise insert ``data``.

        Soft-deleted rows count as existing so that an insert cannot collide
        with their primary key.
        """
        existed = bool(data.id) and self.get(data.__class__, data.id, include_deleted=True) is not None
        with self.session_maker() as session:
            if existed:
                session.merge(data)
            else:
                session.add(data)
            session.commit()
            return data.id

    def delete(self, data: RLDBBaseModel) -> str:
        """Delete the row of ``data`` and return its id."""
        with self.session_maker() as session:
            session.delete(data)
            session.commit()
            return data.id

    def get(self,
            model: type[RLDBBaseModel],
            id: str,
            include_deleted: bool = False,
            ) -> RLDBBaseModel:
        """Return the row of ``model`` with the given id, or None.

        Args:
            model: ORM class to query.
            id: Primary key value.
            include_deleted: When False (default), rows with a non-null
                ``deleted_at`` are ignored.
        """
        with self.session_maker() as session:
            sel = session.query(model).filter(model.id == id)
            if not include_deleted:
                sel = sel.filter(model.deleted_at.is_(None))
            return sel.first()

    def query(self,
              model: type[RLDBBaseModel],
              limit: int = None,
              offset: int = None,
              include_deleted: bool = False,
              **filters
              ) -> list[RLDBBaseModel]:
        """Return rows of ``model`` ordered by ``created_at``, newest first.

        Args:
            model: ORM class to query.
            limit: Maximum number of rows; falsy means no limit.
            offset: Number of rows to skip; falsy means none.
            include_deleted: When False (default), rows with a non-null
                ``deleted_at`` are ignored.
            **filters: Equality filters passed to ``filter_by``.
        """
        with self.session_maker() as session:
            sel = session.query(model)
            if not include_deleted:
                sel = sel.filter(model.deleted_at.is_(None))
            if filters:
                sel = sel.filter_by(**filters)
            # Order in SQL before paging, so that LIMIT/OFFSET select a
            # stable page of the newest-first sequence.
            sel = sel.order_by(model.created_at.desc())
            if limit:
                sel = sel.limit(limit)
            if offset:
                sel = sel.offset(offset)
            return sel.all()
# Process-wide database instance; None until init() runs.
_rldb_instance: RelationalDB = None


def init(type: str = "sqlalchemy", dsn: str = None):
    """Create the database backend and publish it as the module singleton.

    Args:
        type: Backend name; only ``"sqlalchemy"`` is supported.
        dsn: Optional database URL forwarded to the backend.

    Raises:
        ValueError: If ``type`` names an unsupported backend.
    """
    global _rldb_instance
    if type != "sqlalchemy":
        raise ValueError(f"RelationalDB type {type} not supported")
    _rldb_instance = SqlAlchemyDB(dsn)


def get_instance() -> RelationalDB:
    """Return the shared database instance (None before init())."""
    return _rldb_instance
if __name__ == "__main__":
    # Manual demo against a local SQLite file: insert, update, upsert, get,
    # query and delete on a throw-away model.
    class TestModel(RLDBBaseModel):
        __tablename__ = "test_model"
        name = Column(String(36), nullable=True)
        conf = Column(String(96), nullable=True)
    # NOTE(review): presumably the ./demo_storage directory must already exist — confirm.
    init("sqlalchemy", dsn="sqlite:///./demo_storage/rldb.db")
    db = get_instance()
    test_data = TestModel(name="test", conf="test.config")
    print(f"before insert: {test_data}")
    ret = db.insert(test_data)
    print(f"after insert: {test_data}")
    print(f"before update: {test_data}")
    test_data.conf = "test.config.new"
    ret = db.update(test_data)
    print(f"after update: {test_data}")
    test2_data = TestModel(name="test", conf="test2.config")
    print(f"before upsert: {test2_data}")
    ret = db.upsert(test2_data)
    print(f"after upsert: {test2_data}")
    get_data = db.get(TestModel, test_data.id)
    print(f"get data: {get_data}")
    query_data = db.query(TestModel, name="test")
    # Print and then delete every row returned by the query.
    for data in query_data:
        print(data.id, data.name, data.conf)
        print(f"query data: {data}")
        ret = db.delete(data)
        print(f"delete data.id: {ret}")

0
src/web/__init__.py Normal file
View File

152
src/web/api.py Normal file
View File

@@ -0,0 +1,152 @@
import os
import uuid
import logging
from typing import Any, Optional
from fastapi import FastAPI, UploadFile, File, Query
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from services.people import get_instance as get_people_service
from models.people import People
from agents.extract_people_agent import ExtractPeopleAgent
from utils import obs, ocr
# FastAPI application object; main.py hands it to uvicorn.
api = FastAPI(title="Single People Management and Searching", version="0.1")
# CORS is fully open: any origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True —
# confirm this is intended outside development.
api.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Response envelope returned by every endpoint in this module.
class BaseResponse(BaseModel):
    # 0 on success, otherwise the numeric code of the service error.
    error_code: int
    # "success" or the error message.
    error_info: str
    # Endpoint-specific payload; omitted on errors.
    data: Optional[Any] = None
# Liveness probe: always answers with the success envelope.
@api.post("/ping")
async def ping():
    pong = BaseResponse(error_code=0, error_info="success")
    return pong
# Request body of POST /recognition/input.
class PostInputRequest(BaseModel):
    # Free-text description of a person.
    text: str


# Extract a person from free text and return it as a dictionary.
@api.post("/recognition/input")
async def post_input(request: PostInputRequest):
    people = extract_people(request.text)
    return _recognition_response(people)


# Upload an image to object storage, OCR it, and extract a person from the
# recognised text. The image link becomes the person's cover.
@api.post("/recognition/image")
async def post_input_image(image: UploadFile = File(...)):
    # Build a unique object name; the upload may arrive without a file name.
    file_extension = os.path.splitext(image.filename or "")[1]
    unique_filename = f"{uuid.uuid4()}{file_extension}"

    # The file goes straight to object storage, so no local directory is needed.
    file_path = f"uploads/{unique_filename}"
    obs_util = obs.get_instance()
    # NOTE(review): the result of Put is not checked here; a failed upload
    # shows up later as an OCR or extraction failure.
    obs_util.Put(file_path, await image.read())

    # Public link of the stored image
    obs_url = obs_util.Link(file_path)
    logging.info(f"obs_url: {obs_url}")

    # Run OCR on the stored image
    ocr_util = ocr.get_instance()
    ocr_result = ocr_util.recognize_image_text(obs_url)
    logging.info(f"ocr_result: {ocr_result}")

    people = extract_people(ocr_result, obs_url)
    return _recognition_response(people)


def extract_people(text: str, cover_link: str = None) -> People:
    """Run the extraction agent on ``text``.

    Args:
        text: Text describing a person.
        cover_link: Optional image link stored as the person's cover.

    Returns:
        The extracted ``People``, or ``None`` when the agent could not
        produce a valid person.
    """
    extra_agent = ExtractPeopleAgent()
    people = extra_agent.extract_people_info(text)
    if people is None:
        # The agent returns None on unparsable or invalid LLM output.
        logging.error("failed to extract people info from text")
        return None
    people.cover = cover_link
    logging.info(f"people: {people}")
    return people


def _recognition_response(people) -> BaseResponse:
    """Wrap an extraction result, or its absence, in the response envelope."""
    if people is None:
        # Imported locally: this module has no top-level ErrorCode import.
        from utils.error import ErrorCode
        return BaseResponse(
            error_code=ErrorCode.MODEL_ERROR.value,
            error_info="failed to extract people info",
        )
    resp = BaseResponse(error_code=0, error_info="success")
    resp.data = people.to_dict()
    return resp
# Request body shared by POST /people and PUT /people/{people_id}.
class PostPeopleRequest(BaseModel):
    # Person fields as a plain dictionary; passed to People.from_dict.
    people: dict
# Create (or overwrite) a person and return its id in ``data``.
@api.post("/people")
async def post_people(post_people_request: PostPeopleRequest):
    logging.debug(f"post_people_request: {post_people_request}")
    candidate = People.from_dict(post_people_request.people)
    saved_id, outcome = get_people_service().save(candidate)
    candidate.id = saved_id
    if outcome.success:
        return BaseResponse(error_code=0, error_info="success", data=candidate.id)
    return BaseResponse(error_code=outcome.code, error_info=outcome.info)
# Replace an existing person; fails when the id is unknown.
@api.put("/people/{people_id}")
async def update_people(people_id: str, post_people_request: PostPeopleRequest):
    logging.debug(f"post_people_request: {post_people_request}")
    replacement = People.from_dict(post_people_request.people)
    # The id from the URL wins over any id in the body.
    replacement.id = people_id

    people_service = get_people_service()
    existing, lookup = people_service.get(people_id)
    if not (lookup.success and existing):
        return BaseResponse(error_code=lookup.code, error_info=lookup.info)

    _, saved = people_service.save(replacement)
    if saved.success:
        return BaseResponse(error_code=0, error_info="success")
    return BaseResponse(error_code=saved.code, error_info=saved.info)
# Delete a person by id.
@api.delete("/people/{people_id}")
async def delete_people(people_id: str):
    outcome = get_people_service().delete(people_id)
    if outcome.success:
        return BaseResponse(error_code=0, error_info="success")
    return BaseResponse(error_code=outcome.code, error_info=outcome.info)
# NOTE(review): no endpoint visible in this module uses this model —
# confirm whether it is still needed.
class GetPeopleRequest(BaseModel):
    query: Optional[str] = None
    conds: Optional[dict] = None
    top_k: int = 5
# List people, optionally filtered by exact field values and paged.
@api.get("/peoples")
async def get_peoples(
    name: Optional[str] = Query(None, description="姓名"),
    gender: Optional[str] = Query(None, description="性别"),
    age: Optional[int] = Query(None, description="年龄"),
    height: Optional[int] = Query(None, description="身高"),
    marital_status: Optional[str] = Query(None, description="婚姻状态"),
    limit: int = Query(10, description="分页大小"),
    offset: int = Query(0, description="分页偏移量"),
):
    # Keep only the filters that were supplied; empty strings and zeros
    # are treated as "not set".
    candidates = {
        "name": name,
        "gender": gender,
        "age": age,
        "height": height,
        "marital_status": marital_status,
    }
    conds = {field: value for field, value in candidates.items() if value}
    logging.info(f"conds: {conds}, limit: {limit}, offset: {offset}")

    service = get_people_service()
    results, error = service.list(conds, limit=limit, offset=offset)
    logging.info(f"query results: {results}")
    if not error.success:
        return BaseResponse(error_code=error.code, error_info=error.info)
    peoples = [people.to_dict() for people in results]
    return BaseResponse(error_code=0, error_info="success", data=peoples)

1733
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff