feat: support custom management
- add custom model and rldb model - add service for custom to operate rldb - add apis to CURD custom and image upload and delete - support to recognize custom from text or image - refactor web servcie start mode and api group - - group the apis - - support uvicorn start service in terminal - - refactor recognization api for both people and custom
This commit is contained in:
85
src/agents/extract_custom_agent.py
Normal file
85
src/agents/extract_custom_agent.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from langchain.prompts import ChatPromptTemplate
|
||||||
|
|
||||||
|
from .base_agent import BaseAgent
|
||||||
|
from models.custom import Custom
|
||||||
|
|
||||||
|
class ExtractCustomAgent(BaseAgent):
|
||||||
|
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
|
||||||
|
super().__init__(api_url, api_key, model_name)
|
||||||
|
self.prompt = ChatPromptTemplate.from_messages([
|
||||||
|
(
|
||||||
|
"system",
|
||||||
|
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')},"
|
||||||
|
"你是一个专业的客户信息录入助手,善于从一段文字描述中,精确获取客户的以下信息:\n"
|
||||||
|
"姓名 name\n"
|
||||||
|
"性别 gender (男/女/未知)\n"
|
||||||
|
"出生年份 birth (整数年份,如 1990;若文本只提供了年龄,请根据当前日期计算出出生年份)\n"
|
||||||
|
"手机号 phone\n"
|
||||||
|
"邮箱 email\n"
|
||||||
|
"身高(cm) height (整数)\n"
|
||||||
|
"体重(kg) weight (整数)\n"
|
||||||
|
"学历 degree\n"
|
||||||
|
"毕业院校 academy\n"
|
||||||
|
"职业 occupation\n"
|
||||||
|
"年收入(万) income (整数)\n"
|
||||||
|
"资产(万) assets (整数)\n"
|
||||||
|
"流动资产(万) current_assets (整数)\n"
|
||||||
|
"房产情况 house (必须为以下之一: '有房无贷', '有房有贷', '无自有房', 若未提及则不填)\n"
|
||||||
|
"车辆情况 car (必须为以下之一: '有车无贷', '有车有贷', '无自有车', 若未提及则不填)\n"
|
||||||
|
"户口城市 registered_city\n"
|
||||||
|
"居住城市 live_city\n"
|
||||||
|
"籍贯 native_place\n"
|
||||||
|
"原生家庭情况 original_family\n"
|
||||||
|
"是否独生子女 is_single_child (true/false)\n"
|
||||||
|
"择偶要求 match_requirement\n"
|
||||||
|
"\n"
|
||||||
|
"以上信息需要严格按照 JSON 格式输出,字段名与条目中英文保持一致。\n"
|
||||||
|
"若未识别到某项,则不返回该字段,不要自行填写“未知”、“未填写”等。\n"
|
||||||
|
"\n"
|
||||||
|
"除了上述基本信息,还有一个字段:\n"
|
||||||
|
"其他介绍 introductions\n"
|
||||||
|
"其余的信息需要按照字典的方式进行提炼和总结,都放在 introductions 字段中,key 使用提炼好的中文。\n"
|
||||||
|
),
|
||||||
|
("human", "{input}")
|
||||||
|
])
|
||||||
|
|
||||||
|
def extract_custom_info(self, text: str) -> Custom:
|
||||||
|
"""从文本中提取客户信息"""
|
||||||
|
prompt = self.prompt.format_prompt(input=text)
|
||||||
|
response = self.llm.invoke(prompt)
|
||||||
|
logging.info(f"llm response: {response.content}")
|
||||||
|
try:
|
||||||
|
custom_dict = json.loads(response.content)
|
||||||
|
|
||||||
|
# 类型安全转换,防止LLM返回字符串类型的数字
|
||||||
|
int_fields = ['birth', 'height', 'weight', 'income', 'assets', 'scores', 'current_assets']
|
||||||
|
for field in int_fields:
|
||||||
|
if field in custom_dict and isinstance(custom_dict[field], str):
|
||||||
|
try:
|
||||||
|
# 尝试提取数字,简单处理
|
||||||
|
import re
|
||||||
|
num = re.findall(r'\d+', custom_dict[field])
|
||||||
|
if num:
|
||||||
|
custom_dict[field] = int(num[0])
|
||||||
|
else:
|
||||||
|
del custom_dict[field] # 无法转换则移除
|
||||||
|
except:
|
||||||
|
del custom_dict[field]
|
||||||
|
|
||||||
|
custom = Custom.from_dict(custom_dict)
|
||||||
|
err = custom.validate()
|
||||||
|
if not err.success:
|
||||||
|
logging.warning(f"Validation warning: {err.info}")
|
||||||
|
# 即使校验失败(如某些必填项缺失),也尽可能返回已提取的对象,
|
||||||
|
# 让上层业务逻辑决定是否接受或需要补充
|
||||||
|
return custom
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to process custom info: {e}")
|
||||||
|
return None
|
||||||
41
src/main.py
41
src/main.py
@@ -1,22 +1,23 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# created by mmmy on 2025-09-27
|
# created by mmmy on 2025-09-27
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Add src directory to sys.path to ensure modules can be imported correctly when running with uvicorn
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from services import people as people_service
|
from services import people as people_service
|
||||||
from services import user as user_service
|
from services import user as user_service
|
||||||
|
from services import custom as custom_service
|
||||||
from utils import config, logger, obs, ocr, rldb, sms, mailer
|
from utils import config, logger, obs, ocr, rldb, sms, mailer
|
||||||
|
|
||||||
from web.api import api
|
from web.api import api
|
||||||
|
|
||||||
# 主函数
|
def initialize_app(config_path):
|
||||||
def main():
|
"""Initialize application components with the given config path."""
|
||||||
main_path = os.path.dirname(os.path.abspath(__file__))
|
config.init(config_path)
|
||||||
parser = argparse.ArgumentParser(description='IF.u 服务')
|
|
||||||
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
config.init(args.config)
|
|
||||||
conf = config.get_instance()
|
conf = config.get_instance()
|
||||||
|
|
||||||
logger.init()
|
logger.init()
|
||||||
@@ -28,12 +29,28 @@ def main():
|
|||||||
|
|
||||||
people_service.init()
|
people_service.init()
|
||||||
user_service.init()
|
user_service.init()
|
||||||
|
custom_service.init()
|
||||||
|
|
||||||
|
|
||||||
|
# 主函数
|
||||||
|
def main():
|
||||||
|
main_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
parser = argparse.ArgumentParser(description='IF.u 服务')
|
||||||
|
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
initialize_app(args.config)
|
||||||
|
|
||||||
|
conf = config.get_instance()
|
||||||
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
|
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
|
||||||
port = conf.getint('web_service', 'server_port', fallback=8099)
|
port = conf.getint('web_service', 'server_port', fallback=8099)
|
||||||
uvicorn.run(api, host=host, port=port)
|
uvicorn.run("src.main:api", host=host, port=port, reload=True) # Modified to string import for reload support in main too, though api object also works
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
else:
|
||||||
|
# Support for running via 'uvicorn src.main:api'
|
||||||
|
# Use environment variable for config path or default
|
||||||
|
main_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
default_config_path = os.path.join(main_path, '../configuration/test_conf.ini')
|
||||||
|
config_path = os.environ.get('IFU_CONFIG_PATH', default_config_path)
|
||||||
|
initialize_app(config_path)
|
||||||
|
|||||||
291
src/models/custom.py
Normal file
291
src/models/custom.py
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# created by mmmy on 2025-11-27
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlalchemy import Column, Integer, String, Text, DateTime, func, Boolean
|
||||||
|
from utils.rldb import RLDBBaseModel
|
||||||
|
from utils.error import ErrorCode, error
|
||||||
|
|
||||||
|
class CustomRLDBModel(RLDBBaseModel):
|
||||||
|
"""
|
||||||
|
客户数据的数据库模型 (SQLAlchemy Model) - 更新版
|
||||||
|
"""
|
||||||
|
__tablename__ = 'customs'
|
||||||
|
id = Column(String(36), primary_key=True)
|
||||||
|
user_id = Column(String(36), index=True, nullable=False)
|
||||||
|
|
||||||
|
# 基本信息
|
||||||
|
name = Column(String(255), index=True, nullable=False)
|
||||||
|
gender = Column(String(10), nullable=False)
|
||||||
|
birth = Column(Integer, nullable=False) # 出生年份
|
||||||
|
phone = Column(String(50), index=True)
|
||||||
|
email = Column(String(255), index=True)
|
||||||
|
|
||||||
|
# 外貌信息
|
||||||
|
height = Column(Integer)
|
||||||
|
weight = Column(Integer)
|
||||||
|
images = Column(Text) # JSON string for list[str]
|
||||||
|
scores = Column(Integer)
|
||||||
|
|
||||||
|
# 学历职业
|
||||||
|
degree = Column(String(255))
|
||||||
|
academy = Column(String(255))
|
||||||
|
occupation = Column(String(255))
|
||||||
|
income = Column(Integer) # 单位:万
|
||||||
|
assets = Column(Integer) # 单位:万
|
||||||
|
current_assets = Column(Integer) # 单位:万
|
||||||
|
house = Column(String(50))
|
||||||
|
car = Column(String(50))
|
||||||
|
|
||||||
|
# 户口家庭
|
||||||
|
registered_city = Column(String(255))
|
||||||
|
live_city = Column(String(255))
|
||||||
|
native_place = Column(String(255))
|
||||||
|
original_family = Column(Text)
|
||||||
|
is_single_child = Column(Boolean, default=False)
|
||||||
|
|
||||||
|
match_requirement = Column(Text)
|
||||||
|
|
||||||
|
introductions = Column(Text) # JSON string for Dict[str, str]
|
||||||
|
|
||||||
|
# 客户信息
|
||||||
|
custom_level = Column(String(255))
|
||||||
|
comments = Column(Text) # JSON string for Dict[str, str]
|
||||||
|
is_public = Column(Boolean, default=False)
|
||||||
|
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||||
|
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||||
|
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||||
|
|
||||||
|
class Custom:
|
||||||
|
"""
|
||||||
|
客户数据的业务逻辑模型 (Business Logic Model) - 更新版
|
||||||
|
"""
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
|
||||||
|
# 基本信息
|
||||||
|
name: str
|
||||||
|
gender: str
|
||||||
|
birth: int
|
||||||
|
phone: str
|
||||||
|
email: str
|
||||||
|
|
||||||
|
# 外貌信息
|
||||||
|
height: int
|
||||||
|
weight: int
|
||||||
|
images: List[str]
|
||||||
|
scores: int
|
||||||
|
|
||||||
|
# 学历职业
|
||||||
|
degree: str
|
||||||
|
academy: str
|
||||||
|
occupation: str
|
||||||
|
income: int
|
||||||
|
assets: int
|
||||||
|
current_assets: int
|
||||||
|
house: str
|
||||||
|
car: str
|
||||||
|
|
||||||
|
# 户口家庭
|
||||||
|
registered_city: str
|
||||||
|
live_city: str
|
||||||
|
native_place: str
|
||||||
|
original_family: str
|
||||||
|
is_single_child: bool
|
||||||
|
|
||||||
|
match_requirement: str
|
||||||
|
|
||||||
|
introductions: Dict[str, str]
|
||||||
|
|
||||||
|
# 客户信息
|
||||||
|
custom_level: str
|
||||||
|
comments: Dict[str, str]
|
||||||
|
is_public: bool
|
||||||
|
created_at: datetime = None
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
# 初始化所有属性
|
||||||
|
self.id = kwargs.get('id', '')
|
||||||
|
self.user_id = kwargs.get('user_id', '')
|
||||||
|
self.name = kwargs.get('name', '')
|
||||||
|
self.gender = kwargs.get('gender', '未知')
|
||||||
|
self.birth = kwargs.get('birth', 0)
|
||||||
|
self.phone = kwargs.get('phone', '')
|
||||||
|
self.email = kwargs.get('email', '')
|
||||||
|
self.height = kwargs.get('height', 0)
|
||||||
|
self.weight = kwargs.get('weight', 0)
|
||||||
|
self.images = kwargs.get('images', [])
|
||||||
|
self.scores = kwargs.get('scores', 0)
|
||||||
|
self.degree = kwargs.get('degree', '')
|
||||||
|
self.academy = kwargs.get('academy', '')
|
||||||
|
self.occupation = kwargs.get('occupation', '')
|
||||||
|
self.income = kwargs.get('income', 0)
|
||||||
|
self.assets = kwargs.get('assets', 0)
|
||||||
|
self.current_assets = kwargs.get('current_assets', 0)
|
||||||
|
self.house = kwargs.get('house', '')
|
||||||
|
self.car = kwargs.get('car', '')
|
||||||
|
self.registered_city = kwargs.get('registered_city', '')
|
||||||
|
self.live_city = kwargs.get('live_city', '')
|
||||||
|
self.native_place = kwargs.get('native_place', '')
|
||||||
|
self.original_family = kwargs.get('original_family', '')
|
||||||
|
self.is_single_child = kwargs.get('is_single_child', False)
|
||||||
|
self.match_requirement = kwargs.get('match_requirement', '')
|
||||||
|
self.introductions = kwargs.get('introductions', {})
|
||||||
|
self.custom_level = kwargs.get('custom_level', '')
|
||||||
|
self.comments = kwargs.get('comments', {})
|
||||||
|
self.is_public = kwargs.get('is_public', False)
|
||||||
|
self.created_at = kwargs.get('created_at')
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
# 返回对象的字符串表示
|
||||||
|
attributes = ", ".join(f"{k}={v}" for k, v in self.to_dict().items())
|
||||||
|
return f"Custom({attributes})"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict):
|
||||||
|
# 从字典创建对象实例
|
||||||
|
if 'created_at' in data and data['created_at'] is not None:
|
||||||
|
data['created_at'] = datetime.fromtimestamp(data['created_at'])
|
||||||
|
# 移除ORM特有的时间戳,避免初始化错误
|
||||||
|
data.pop('updated_at', None)
|
||||||
|
data.pop('deleted_at', None)
|
||||||
|
return cls(**data)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
# 将对象转换为字典
|
||||||
|
return {
|
||||||
|
'id': self.id,
|
||||||
|
'user_id': self.user_id,
|
||||||
|
'name': self.name,
|
||||||
|
'gender': self.gender,
|
||||||
|
'birth': self.birth,
|
||||||
|
'phone': self.phone,
|
||||||
|
'email': self.email,
|
||||||
|
'height': self.height,
|
||||||
|
'weight': self.weight,
|
||||||
|
'images': self.images,
|
||||||
|
'scores': self.scores,
|
||||||
|
'degree': self.degree,
|
||||||
|
'academy': self.academy,
|
||||||
|
'occupation': self.occupation,
|
||||||
|
'income': self.income,
|
||||||
|
'assets': self.assets,
|
||||||
|
'current_assets': self.current_assets,
|
||||||
|
'house': self.house,
|
||||||
|
'car': self.car,
|
||||||
|
'registered_city': self.registered_city,
|
||||||
|
'live_city': self.live_city,
|
||||||
|
'native_place': self.native_place,
|
||||||
|
'original_family': self.original_family,
|
||||||
|
'is_single_child': self.is_single_child,
|
||||||
|
'match_requirement': self.match_requirement,
|
||||||
|
'introductions': self.introductions,
|
||||||
|
'custom_level': self.custom_level,
|
||||||
|
'comments': self.comments,
|
||||||
|
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_rldb_model(cls, data: CustomRLDBModel):
|
||||||
|
# 从数据库模型转换
|
||||||
|
return cls(
|
||||||
|
id=data.id,
|
||||||
|
user_id=data.user_id,
|
||||||
|
name=data.name,
|
||||||
|
gender=data.gender,
|
||||||
|
birth=data.birth,
|
||||||
|
phone=data.phone,
|
||||||
|
email=data.email,
|
||||||
|
height=data.height,
|
||||||
|
weight=data.weight,
|
||||||
|
images=json.loads(data.images) if data.images else [],
|
||||||
|
scores=data.scores,
|
||||||
|
degree=data.degree,
|
||||||
|
academy=data.academy,
|
||||||
|
occupation=data.occupation,
|
||||||
|
income=data.income,
|
||||||
|
assets=data.assets,
|
||||||
|
house=data.house,
|
||||||
|
car=data.car,
|
||||||
|
registered_city=data.registered_city,
|
||||||
|
live_city=data.live_city,
|
||||||
|
native_place=data.native_place,
|
||||||
|
original_family=data.original_family,
|
||||||
|
is_single_child=data.is_single_child,
|
||||||
|
match_requirement=data.match_requirement,
|
||||||
|
introductions=json.loads(data.introductions) if data.introductions else {},
|
||||||
|
custom_level=data.custom_level,
|
||||||
|
comments=json.loads(data.comments) if data.comments else {},
|
||||||
|
is_public=data.is_public,
|
||||||
|
created_at=data.created_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_rldb_model(self) -> CustomRLDBModel:
|
||||||
|
# 转换为数据库模型
|
||||||
|
return CustomRLDBModel(
|
||||||
|
id=self.id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
name=self.name,
|
||||||
|
gender=self.gender,
|
||||||
|
birth=self.birth,
|
||||||
|
phone=self.phone,
|
||||||
|
email=self.email,
|
||||||
|
height=self.height,
|
||||||
|
weight=self.weight,
|
||||||
|
images=json.dumps(self.images, ensure_ascii=False),
|
||||||
|
scores=self.scores,
|
||||||
|
degree=self.degree,
|
||||||
|
academy=self.academy,
|
||||||
|
occupation=self.occupation,
|
||||||
|
income=self.income,
|
||||||
|
assets=self.assets,
|
||||||
|
current_assets=self.current_assets,
|
||||||
|
house=self.house,
|
||||||
|
car=self.car,
|
||||||
|
registered_city=self.registered_city,
|
||||||
|
live_city=self.live_city,
|
||||||
|
native_place=self.native_place,
|
||||||
|
original_family=self.original_family,
|
||||||
|
is_single_child=self.is_single_child,
|
||||||
|
match_requirement=self.match_requirement,
|
||||||
|
introductions=json.dumps(self.introductions, ensure_ascii=False),
|
||||||
|
custom_level=self.custom_level,
|
||||||
|
comments=json.dumps(self.comments, ensure_ascii=False),
|
||||||
|
is_public=self.is_public,
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate(self) -> error:
|
||||||
|
# 数据校验逻辑
|
||||||
|
if not self.name:
|
||||||
|
return error(ErrorCode.INVALID_PARAMS, "Name cannot be empty.")
|
||||||
|
|
||||||
|
if not self.gender:
|
||||||
|
return error(ErrorCode.INVALID_PARAMS, "Gender cannot be empty.")
|
||||||
|
|
||||||
|
if self.gender not in ['男', '女']:
|
||||||
|
return error(ErrorCode.INVALID_PARAMS, "Gender must be '男' or '女'.")
|
||||||
|
|
||||||
|
current_year = datetime.now().year
|
||||||
|
min_birth_year = 1950
|
||||||
|
max_birth_year = current_year - 18
|
||||||
|
|
||||||
|
if not isinstance(self.birth, int):
|
||||||
|
return error(ErrorCode.INVALID_PARAMS, "Birth year must be an integer.")
|
||||||
|
|
||||||
|
if self.birth < min_birth_year or self.birth > max_birth_year:
|
||||||
|
return error(ErrorCode.INVALID_PARAMS, f"Birth year must be between {min_birth_year} and {max_birth_year}.")
|
||||||
|
|
||||||
|
valid_houses = ["", "有房无贷", "有房有贷", "无自有房"]
|
||||||
|
if self.house not in valid_houses:
|
||||||
|
return error(ErrorCode.INVALID_PARAMS, f"House must be one of {valid_houses}")
|
||||||
|
|
||||||
|
valid_cars = ["", "有车无贷", "有车有贷", "无自有车"]
|
||||||
|
if self.car not in valid_cars:
|
||||||
|
return error(ErrorCode.INVALID_PARAMS, f"Car must be one of {valid_cars}")
|
||||||
|
|
||||||
|
# ... 可根据需要添加更多校验 ...
|
||||||
|
return error(ErrorCode.SUCCESS, "")
|
||||||
98
src/services/custom.py
Normal file
98
src/services/custom.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# created by mmmy on 2025-11-27
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from models.custom import Custom, CustomRLDBModel
|
||||||
|
from utils.error import ErrorCode, error
|
||||||
|
from utils import rldb
|
||||||
|
|
||||||
|
class CustomService:
|
||||||
|
def __init__(self):
|
||||||
|
self.rldb = rldb.get_instance()
|
||||||
|
|
||||||
|
def save(self, custom: Custom) -> (str, error):
|
||||||
|
"""
|
||||||
|
保存客户到数据库。
|
||||||
|
如果 custom.id 存在,则更新;否则,创建。
|
||||||
|
|
||||||
|
:param custom: 客户对象
|
||||||
|
:return: 客户ID 和 错误对象
|
||||||
|
"""
|
||||||
|
# 0. 生成 custom id
|
||||||
|
custom.id = custom.id if custom.id else uuid.uuid4().hex
|
||||||
|
|
||||||
|
# 1. 转换模型,并保存到 SQL 数据库
|
||||||
|
try:
|
||||||
|
custom_orm = custom.to_rldb_model()
|
||||||
|
self.rldb.upsert(custom_orm)
|
||||||
|
return custom.id, error(ErrorCode.SUCCESS, "")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to save custom {custom.id}: {e}")
|
||||||
|
return "", error(ErrorCode.RLDB_ERROR, f"Failed to save custom data: {str(e)}")
|
||||||
|
|
||||||
|
def delete(self, custom_id: str) -> error:
|
||||||
|
"""
|
||||||
|
从数据库删除客户。
|
||||||
|
|
||||||
|
:param custom_id: 客户ID
|
||||||
|
:return: 错误对象
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
|
||||||
|
if not custom_orm:
|
||||||
|
return error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
|
||||||
|
self.rldb.delete(custom_orm)
|
||||||
|
return error(ErrorCode.SUCCESS, "")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to delete custom {custom_id}: {e}")
|
||||||
|
return error(ErrorCode.RLDB_ERROR, f"Failed to delete custom data: {str(e)}")
|
||||||
|
|
||||||
|
def get(self, custom_id: str) -> (Custom, error):
|
||||||
|
"""
|
||||||
|
从数据库获取单个客户。
|
||||||
|
|
||||||
|
:param custom_id: 客户ID
|
||||||
|
:return: 客户对象 和 错误对象
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
|
||||||
|
if not custom_orm:
|
||||||
|
return None, error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
|
||||||
|
|
||||||
|
custom = Custom.from_rldb_model(custom_orm)
|
||||||
|
return custom, error(ErrorCode.SUCCESS, "")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to get custom {custom_id}: {e}")
|
||||||
|
return None, error(ErrorCode.RLDB_ERROR, f"Failed to retrieve custom data: {str(e)}")
|
||||||
|
|
||||||
|
def list(self, conds: dict = None, limit: int = 10, offset: int = 0) -> (list[Custom], error):
|
||||||
|
"""
|
||||||
|
根据条件从数据库列出客户(支持分页)。
|
||||||
|
|
||||||
|
:param conds: 查询条件字典
|
||||||
|
:param limit: 每页数量
|
||||||
|
:param offset: 偏移量
|
||||||
|
:return: 客户对象列表 和 错误对象
|
||||||
|
"""
|
||||||
|
if conds is None:
|
||||||
|
conds = {}
|
||||||
|
try:
|
||||||
|
custom_orms = self.rldb.query(CustomRLDBModel, limit=limit, offset=offset, **conds)
|
||||||
|
customs = [Custom.from_rldb_model(orm) for orm in custom_orms]
|
||||||
|
return customs, error(ErrorCode.SUCCESS, "")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to list customs with conds {conds}: {e}")
|
||||||
|
return [], error(ErrorCode.RLDB_ERROR, f"Failed to list custom data: {str(e)}")
|
||||||
|
|
||||||
|
# --- Singleton Pattern ---
|
||||||
|
custom_service = None
|
||||||
|
|
||||||
|
def init():
|
||||||
|
"""初始化 CustomService 单例"""
|
||||||
|
global custom_service
|
||||||
|
custom_service = CustomService()
|
||||||
|
|
||||||
|
def get_instance() -> CustomService:
|
||||||
|
"""获取 CustomService 单例"""
|
||||||
|
return custom_service
|
||||||
@@ -7,6 +7,7 @@ class ErrorCode(Enum):
|
|||||||
SUCCESS = 0
|
SUCCESS = 0
|
||||||
MODEL_ERROR = 1000
|
MODEL_ERROR = 1000
|
||||||
RLDB_ERROR = 2100
|
RLDB_ERROR = 2100
|
||||||
|
RLDB_NOT_FOUND = 2101
|
||||||
OBS_ERROR = 3100
|
OBS_ERROR = 3100
|
||||||
OBS_INPUT_ERROR = 3102
|
OBS_INPUT_ERROR = 3102
|
||||||
OBS_SERVICE_ERROR = 3103
|
OBS_SERVICE_ERROR = 3103
|
||||||
|
|||||||
471
src/web/api.py
471
src/web/api.py
@@ -1,19 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
import logging
|
from fastapi import FastAPI, UploadFile, File, APIRouter, Depends
|
||||||
from typing import Any, Optional, Literal
|
from fastapi.concurrency import run_in_threadpool
|
||||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Response, APIRouter, Depends, Request
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from services.people import get_instance as get_people_service
|
|
||||||
from services.user import get_instance as get_user_service
|
|
||||||
from web.auth import require_auth
|
from web.auth import require_auth
|
||||||
from models.people import People
|
from utils import obs
|
||||||
from agents.extract_people_agent import ExtractPeopleAgent
|
from web.schemas import BaseResponse
|
||||||
from utils import obs, ocr
|
from web.custom import router as custom_router
|
||||||
from utils.config import get_instance as get_config
|
from web.people import router as people_router
|
||||||
from utils.error import ErrorCode
|
from web.user import router as user_router
|
||||||
|
from web.recognition import router as recognition_router
|
||||||
|
|
||||||
api = FastAPI(title="Single People Management and Searching", version="0.1")
|
api = FastAPI(title="Single People Management and Searching", version="0.1")
|
||||||
api.add_middleware(
|
api.add_middleware(
|
||||||
@@ -26,18 +22,10 @@ api.add_middleware(
|
|||||||
|
|
||||||
authorized_router = APIRouter(dependencies=[Depends(require_auth)])
|
authorized_router = APIRouter(dependencies=[Depends(require_auth)])
|
||||||
|
|
||||||
class BaseResponse(BaseModel):
|
|
||||||
error_code: int
|
|
||||||
error_info: str
|
|
||||||
data: Optional[Any] = None
|
|
||||||
|
|
||||||
@api.post("/api/ping")
|
@api.post("/api/ping")
|
||||||
async def ping():
|
async def ping():
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
class PostInputRequest(BaseModel):
|
|
||||||
text: str
|
|
||||||
|
|
||||||
@authorized_router.post("/api/upload/image")
|
@authorized_router.post("/api/upload/image")
|
||||||
async def post_upload_image(image: UploadFile = File(...)):
|
async def post_upload_image(image: UploadFile = File(...)):
|
||||||
# 实现上传图片的处理
|
# 实现上传图片的处理
|
||||||
@@ -49,439 +37,24 @@ async def post_upload_image(image: UploadFile = File(...)):
|
|||||||
# 保存文件到对象存储
|
# 保存文件到对象存储
|
||||||
file_path = f"uploads/{unique_filename}"
|
file_path = f"uploads/{unique_filename}"
|
||||||
obs_util = obs.get_instance()
|
obs_util = obs.get_instance()
|
||||||
obs_util.put(file_path, await image.read())
|
await run_in_threadpool(obs_util.put, file_path, await image.read())
|
||||||
|
|
||||||
# 获取对象存储外链
|
# 获取对象存储外链
|
||||||
obs_url = obs_util.get_link(file_path)
|
obs_url = obs_util.get_link(file_path)
|
||||||
return BaseResponse(error_code=0, error_info="success", data=obs_url)
|
return BaseResponse(error_code=0, error_info="success", data=obs_url)
|
||||||
|
|
||||||
@authorized_router.post("/api/recognition/input")
|
|
||||||
async def post_recognition_input(request: PostInputRequest):
|
|
||||||
people = extract_people(request.text)
|
|
||||||
resp = BaseResponse(error_code=0, error_info="success")
|
|
||||||
resp.data = people.to_dict()
|
|
||||||
return resp
|
|
||||||
|
|
||||||
@authorized_router.post("/api/recognition/image")
|
|
||||||
async def post_recognition_image(image: UploadFile = File(...)):
|
|
||||||
# 实现上传图片的处理
|
|
||||||
# 保存上传的图片文件
|
|
||||||
# 生成唯一的文件名
|
|
||||||
file_extension = os.path.splitext(image.filename)[1]
|
|
||||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
|
||||||
|
|
||||||
# 保存文件到对象存储
|
|
||||||
file_path = f"uploads/{unique_filename}"
|
|
||||||
obs_util = obs.get_instance()
|
|
||||||
obs_util.put(file_path, await image.read())
|
|
||||||
|
|
||||||
# 获取对象存储外链
|
|
||||||
obs_url = obs_util.get_link(file_path)
|
|
||||||
logging.info(f"obs_url: {obs_url}")
|
|
||||||
|
|
||||||
# 调用OCR处理图片
|
|
||||||
ocr_util = ocr.get_instance()
|
|
||||||
ocr_result = ocr_util.recognize_image_text(obs_url)
|
|
||||||
logging.info(f"ocr_result: {ocr_result}")
|
|
||||||
|
|
||||||
people = extract_people(ocr_result, obs_url)
|
|
||||||
resp = BaseResponse(error_code=0, error_info="success")
|
|
||||||
resp.data = people.to_dict()
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def extract_people(text: str, cover_link: str = None) -> People:
|
|
||||||
extra_agent = ExtractPeopleAgent()
|
|
||||||
people = extra_agent.extract_people_info(text)
|
|
||||||
people.cover = cover_link
|
|
||||||
logging.info(f"people: {people}")
|
|
||||||
return people
|
|
||||||
|
|
||||||
class PostPeopleRequest(BaseModel):
|
|
||||||
people: dict
|
|
||||||
|
|
||||||
@authorized_router.post("/api/people")
|
|
||||||
async def post_people(request: Request, post_people_request: PostPeopleRequest):
|
|
||||||
logging.debug(f"post_people_request: {post_people_request}")
|
|
||||||
people = People.from_dict(post_people_request.people)
|
|
||||||
people.user_id = getattr(request.state, 'user_id', '')
|
|
||||||
service = get_people_service()
|
|
||||||
people.id, error = service.save(people)
|
|
||||||
if not error.success:
|
|
||||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data=people.id)
|
|
||||||
|
|
||||||
@authorized_router.put("/api/people/{people_id}")
|
|
||||||
async def update_people(request: Request, people_id: str, post_people_request: PostPeopleRequest):
|
|
||||||
logging.debug(f"post_people_request: {post_people_request}")
|
|
||||||
people = People.from_dict(post_people_request.people)
|
|
||||||
people.id = people_id
|
|
||||||
service = get_people_service()
|
|
||||||
res, error = service.get(people_id)
|
|
||||||
if not error.success or not res:
|
|
||||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
|
||||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
|
||||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
|
||||||
people.user_id = res.user_id
|
|
||||||
_, error = service.save(people)
|
|
||||||
if not error.success:
|
|
||||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
|
||||||
@authorized_router.delete("/api/people/{people_id}")
|
|
||||||
async def delete_people(request: Request, people_id: str):
|
|
||||||
service = get_people_service()
|
|
||||||
res, err = service.get(people_id)
|
|
||||||
if not err.success or not res:
|
|
||||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
|
||||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
|
||||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
|
||||||
error = service.delete(people_id)
|
|
||||||
if not error.success:
|
|
||||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
|
||||||
class GetPeopleRequest(BaseModel):
|
|
||||||
query: Optional[str] = None
|
|
||||||
conds: Optional[dict] = None
|
|
||||||
top_k: int = 5
|
|
||||||
|
|
||||||
@authorized_router.get("/api/peoples")
|
|
||||||
async def get_peoples(
|
|
||||||
request: Request,
|
|
||||||
name: Optional[str] = Query(None, description="姓名"),
|
|
||||||
gender: Optional[str] = Query(None, description="性别"),
|
|
||||||
age: Optional[int] = Query(None, description="年龄"),
|
|
||||||
height: Optional[int] = Query(None, description="身高"),
|
|
||||||
marital_status: Optional[str] = Query(None, description="婚姻状态"),
|
|
||||||
limit: int = Query(10, description="分页大小"),
|
|
||||||
offset: int = Query(0, description="分页偏移量"),
|
|
||||||
):
|
|
||||||
|
|
||||||
# 解析查询参数为字典
|
|
||||||
conds = {}
|
|
||||||
conds["user_id"] = getattr(request.state, 'user_id', '')
|
|
||||||
if name:
|
|
||||||
conds["name"] = name
|
|
||||||
if gender:
|
|
||||||
conds["gender"] = gender
|
|
||||||
if age:
|
|
||||||
conds["age"] = age
|
|
||||||
if height:
|
|
||||||
conds["height"] = height
|
|
||||||
if marital_status:
|
|
||||||
conds["marital_status"] = marital_status
|
|
||||||
|
|
||||||
logging.info(f"conds: , limit: {limit}, offset: {offset}")
|
|
||||||
|
|
||||||
results = []
|
|
||||||
service = get_people_service()
|
|
||||||
results, error = service.list(conds, limit=limit, offset=offset)
|
|
||||||
logging.info(f"query results: {results}")
|
|
||||||
if not error.success:
|
|
||||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
|
||||||
peoples = [people.to_dict() for people in results]
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data=peoples)
|
|
||||||
|
|
||||||
|
|
||||||
class RemarkRequest(BaseModel):
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.post("/api/people/{people_id}/remark")
|
|
||||||
async def post_people_remark(request: Request, people_id: str, body: RemarkRequest):
|
|
||||||
service = get_people_service()
|
|
||||||
res, err = service.get(people_id)
|
|
||||||
if not err.success or not res:
|
|
||||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
|
||||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
|
||||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
|
||||||
error = service.save_remark(people_id, body.content)
|
|
||||||
if not error.success:
|
|
||||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.delete("/api/people/{people_id}/remark")
|
|
||||||
async def delete_people_remark(request: Request, people_id: str):
|
|
||||||
service = get_people_service()
|
|
||||||
res, err = service.get(people_id)
|
|
||||||
if not err.success or not res:
|
|
||||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
|
||||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
|
||||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
|
||||||
error = service.delete_remark(people_id)
|
|
||||||
if not error.success:
|
|
||||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.post("/api/people/{people_id}/image")
|
|
||||||
async def post_people_image(request: Request, people_id: str, image: UploadFile = File(...)):
|
|
||||||
|
|
||||||
# 检查 people id 是否存在
|
|
||||||
service = get_people_service()
|
|
||||||
people, err = service.get(people_id)
|
|
||||||
if not err.success:
|
|
||||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
|
||||||
|
|
||||||
if people.user_id != getattr(request.state, 'user_id', ''):
|
|
||||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
|
||||||
|
|
||||||
# 实现上传图片的处理
|
|
||||||
# 保存上传的图片文件
|
|
||||||
# 生成唯一的文件名
|
|
||||||
file_extension = os.path.splitext(image.filename)[1]
|
|
||||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
|
||||||
|
|
||||||
# 保存文件到对象存储
|
|
||||||
file_path = f"peoples/{people_id}/images/{unique_filename}"
|
|
||||||
obs_util = obs.get_instance()
|
|
||||||
obs_util.put(file_path, await image.read())
|
|
||||||
|
|
||||||
# 获取对象存储外链
|
|
||||||
obs_url = obs_util.get_link(file_path)
|
|
||||||
logging.info(f"obs_url: {obs_url}")
|
|
||||||
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data=obs_url)
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.delete("/api/people/{people_id}/image")
|
|
||||||
async def delete_people_image(request: Request, people_id: str, image_url: str):
|
|
||||||
# 检查 people id 是否存在
|
|
||||||
service = get_people_service()
|
|
||||||
people, err = service.get(people_id)
|
|
||||||
if not err.success:
|
|
||||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
|
||||||
|
|
||||||
if people.user_id != getattr(request.state, 'user_id', ''):
|
|
||||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
|
||||||
|
|
||||||
# 检查 image_url 是否是该 people 名下的图片链接
|
|
||||||
obs_util = obs.get_instance()
|
|
||||||
obs_path, err = obs_util.get_obs_path_by_link(image_url)
|
|
||||||
if not err.success:
|
|
||||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
|
||||||
if not obs_path.startswith(f"peoples/{people_id}/images/"):
|
|
||||||
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {people_id} 名下的图片链接")
|
|
||||||
|
|
||||||
# 实现删除图片的处理
|
|
||||||
# 删除对象存储中的文件
|
|
||||||
err = obs_util.delete_by_link(image_url)
|
|
||||||
if not err.success:
|
|
||||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
|
||||||
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
|
||||||
|
|
||||||
class SendCodeRequest(BaseModel):
|
|
||||||
target_type: str
|
|
||||||
target: str
|
|
||||||
scene: Literal['register', 'update']
|
|
||||||
# scene: Literal['register', 'login']
|
|
||||||
|
|
||||||
|
|
||||||
@api.post("/api/user/send_code")
|
|
||||||
async def send_user_code(request: SendCodeRequest):
|
|
||||||
service = get_user_service()
|
|
||||||
err = service.send_code(request.target_type, request.target, request.scene)
|
|
||||||
if not err.success:
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
|
||||||
|
|
||||||
class RegisterRequest(BaseModel):
|
|
||||||
nickname: Optional[str] = None
|
|
||||||
avatar_link: Optional[str] = None
|
|
||||||
email: Optional[str] = None
|
|
||||||
phone: Optional[str] = None
|
|
||||||
password: str
|
|
||||||
code: str
|
|
||||||
|
|
||||||
@api.post("/api/user")
|
|
||||||
async def user_register(request: RegisterRequest):
|
|
||||||
service = get_user_service()
|
|
||||||
from models.user import User
|
|
||||||
u = User(
|
|
||||||
nickname=request.nickname or "",
|
|
||||||
avatar_link=request.avatar_link or "",
|
|
||||||
email=request.email or "",
|
|
||||||
phone=request.phone or "",
|
|
||||||
password_hash=request.password,
|
|
||||||
)
|
|
||||||
uid, err = service.register(u, request.code)
|
|
||||||
if not err.success:
|
|
||||||
logging.error(f"register failed: {err}")
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data=uid)
|
|
||||||
|
|
||||||
|
|
||||||
class LoginRequest(BaseModel):
|
|
||||||
email: Optional[str] = None
|
|
||||||
phone: Optional[str] = None
|
|
||||||
password: str
|
|
||||||
|
|
||||||
@api.post("/api/user/login")
|
|
||||||
async def user_login(request: LoginRequest, response: Response):
|
|
||||||
service = get_user_service()
|
|
||||||
data, err = service.login(request.email, request.phone, request.password)
|
|
||||||
if not err.success:
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
conf = get_config()
|
|
||||||
ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30)
|
|
||||||
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
|
||||||
cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False)
|
|
||||||
cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None)
|
|
||||||
response.set_cookie(
|
|
||||||
key="token",
|
|
||||||
value=data.get('token', ''),
|
|
||||||
max_age=ttl_days * 24 * 3600,
|
|
||||||
httponly=True,
|
|
||||||
secure=cookie_secure,
|
|
||||||
samesite=cookie_samesite,
|
|
||||||
domain=cookie_domain,
|
|
||||||
path="/",
|
|
||||||
)
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')})
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.delete("/api/user/me/login")
|
|
||||||
async def user_logout(response: Response, request: Request):
|
|
||||||
service = get_user_service()
|
|
||||||
err = service.logout(getattr(request.state, 'token', None))
|
|
||||||
if not err.success:
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
conf = get_config()
|
|
||||||
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
|
||||||
response.delete_cookie(key="token", domain=cookie_domain, path="/")
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.delete("/api/user/me")
|
|
||||||
async def user_delete(response: Response, request: Request):
|
|
||||||
service = get_user_service()
|
|
||||||
err = service.delete_user(getattr(request.state, 'user_id', None))
|
|
||||||
if not err.success:
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
conf = get_config()
|
|
||||||
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
|
||||||
response.delete_cookie(key="token", domain=cookie_domain, path="/")
|
|
||||||
return BaseResponse(error_code=0, error_info="success")
|
|
||||||
|
|
||||||
@authorized_router.get("/api/user/me")
|
|
||||||
async def user_me(request: Request):
|
|
||||||
service = get_user_service()
|
|
||||||
user, err = service.get(getattr(request.state, 'user_id', None))
|
|
||||||
if not err.success or not user:
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
data = {
|
|
||||||
'nickname': user.nickname,
|
|
||||||
'avatar_link': user.avatar_link,
|
|
||||||
'phone': user.phone,
|
|
||||||
'email': user.email,
|
|
||||||
}
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data=data)
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateMeRequest(BaseModel):
|
|
||||||
nickname: Optional[str] = None
|
|
||||||
avatar_link: Optional[str] = None
|
|
||||||
phone: Optional[str] = None
|
|
||||||
email: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.put("/api/user/me")
|
|
||||||
async def update_user_me(request: Request, body: UpdateMeRequest):
|
|
||||||
service = get_user_service()
|
|
||||||
user, err = service.update_profile(
|
|
||||||
getattr(request.state, 'user_id', None),
|
|
||||||
nickname=body.nickname,
|
|
||||||
avatar_link=body.avatar_link,
|
|
||||||
phone=body.phone,
|
|
||||||
email=body.email,
|
|
||||||
)
|
|
||||||
if not err.success:
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
data = {
|
|
||||||
'nickname': user.nickname,
|
|
||||||
'avatar_link': user.avatar_link,
|
|
||||||
'phone': user.phone,
|
|
||||||
'email': user.email,
|
|
||||||
}
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data=data)
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.put("/api/user/me/avatar")
|
|
||||||
async def upload_avatar(request: Request, avatar: UploadFile = File(...)):
|
|
||||||
user_id = getattr(request.state, 'user_id', None)
|
|
||||||
if not user_id:
|
|
||||||
raise HTTPException(status_code=401, detail="unauthorized")
|
|
||||||
|
|
||||||
file_extension = os.path.splitext(avatar.filename)[1]
|
|
||||||
timestamp = int(time.time())
|
|
||||||
avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
obs_util = obs.get_instance()
|
|
||||||
obs_util.Put(avatar_path, await avatar.read())
|
|
||||||
avatar_url = obs_util.Link(avatar_path)
|
|
||||||
|
|
||||||
user_service = get_user_service()
|
|
||||||
_, err = user_service.update_profile(user_id, avatar_link=avatar_url)
|
|
||||||
if not err.success:
|
|
||||||
raise HTTPException(status_code=500, detail=err.info)
|
|
||||||
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url})
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"upload avatar failed: {e}")
|
|
||||||
raise HTTPException(status_code=500, detail="upload avatar failed")
|
|
||||||
|
|
||||||
|
|
||||||
class UpdatePhoneRequest(BaseModel):
|
|
||||||
phone: str
|
|
||||||
code: str
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.put("/api/user/me/phone")
|
|
||||||
async def update_user_phone(request: Request, body: UpdatePhoneRequest):
|
|
||||||
service = get_user_service()
|
|
||||||
user, err = service.update_phone_with_code(
|
|
||||||
getattr(request.state, 'user_id', None),
|
|
||||||
body.phone,
|
|
||||||
body.code,
|
|
||||||
)
|
|
||||||
if not err.success:
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
data = {
|
|
||||||
'nickname': user.nickname,
|
|
||||||
'avatar_link': user.avatar_link,
|
|
||||||
'phone': user.phone,
|
|
||||||
'email': user.email,
|
|
||||||
}
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data=data)
|
|
||||||
|
|
||||||
|
|
||||||
class UpdateEmailRequest(BaseModel):
|
|
||||||
email: str
|
|
||||||
code: str
|
|
||||||
|
|
||||||
|
|
||||||
@authorized_router.put("/api/user/me/email")
|
|
||||||
async def update_user_email(request: Request, body: UpdateEmailRequest):
|
|
||||||
service = get_user_service()
|
|
||||||
user, err = service.update_email_with_code(
|
|
||||||
getattr(request.state, 'user_id', None),
|
|
||||||
body.email,
|
|
||||||
body.code,
|
|
||||||
)
|
|
||||||
if not err.success:
|
|
||||||
raise HTTPException(status_code=400, detail=err.info)
|
|
||||||
data = {
|
|
||||||
'nickname': user.nickname,
|
|
||||||
'avatar_link': user.avatar_link,
|
|
||||||
'phone': user.phone,
|
|
||||||
'email': user.email,
|
|
||||||
}
|
|
||||||
return BaseResponse(error_code=0, error_info="success", data=data)
|
|
||||||
|
|
||||||
|
|
||||||
api.include_router(authorized_router)
|
api.include_router(authorized_router)
|
||||||
|
|
||||||
|
# Register custom router
|
||||||
|
api.include_router(custom_router, dependencies=[Depends(require_auth)])
|
||||||
|
|
||||||
|
# Register people router
|
||||||
|
api.include_router(people_router, dependencies=[Depends(require_auth)])
|
||||||
|
|
||||||
|
# Register user router
|
||||||
|
api.include_router(user_router)
|
||||||
|
|
||||||
|
# Register recognition router
|
||||||
|
api.include_router(recognition_router, dependencies=[Depends(require_auth)])
|
||||||
|
|
||||||
|
|||||||
151
src/web/custom.py
Normal file
151
src/web/custom.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from fastapi import APIRouter, Depends, Request, Query, UploadFile, File
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from models.custom import Custom
|
||||||
|
from services.custom import get_instance as get_custom_service
|
||||||
|
from utils.error import ErrorCode
|
||||||
|
from utils import obs
|
||||||
|
from web.schemas import BaseResponse
|
||||||
|
|
||||||
|
router = APIRouter(tags=["custom"])
|
||||||
|
|
||||||
|
class PostCustomRequest(BaseModel):
|
||||||
|
custom: dict
|
||||||
|
|
||||||
|
@router.post("/api/custom")
|
||||||
|
def create_custom(request: Request, post_custom_request: PostCustomRequest):
|
||||||
|
logging.debug(f"post_custom_request: {post_custom_request}")
|
||||||
|
custom = Custom.from_dict(post_custom_request.custom)
|
||||||
|
|
||||||
|
# Validate custom data
|
||||||
|
err = custom.validate()
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
|
||||||
|
custom.user_id = getattr(request.state, 'user_id', '')
|
||||||
|
|
||||||
|
service = get_custom_service()
|
||||||
|
custom.id, error = service.save(custom)
|
||||||
|
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=custom.id)
|
||||||
|
|
||||||
|
@router.put("/api/custom/{custom_id}")
|
||||||
|
def update_custom(request: Request, custom_id: str, post_custom_request: PostCustomRequest):
|
||||||
|
logging.debug(f"post_custom_request: {post_custom_request}")
|
||||||
|
custom = Custom.from_dict(post_custom_request.custom)
|
||||||
|
custom.id = custom_id
|
||||||
|
|
||||||
|
# Validate custom data
|
||||||
|
err = custom.validate()
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
|
||||||
|
service = get_custom_service()
|
||||||
|
# Check permission
|
||||||
|
res, error = service.get(custom_id)
|
||||||
|
if not error.success or not res:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
|
||||||
|
custom.user_id = res.user_id # Ensure user_id is not changed or is set correctly
|
||||||
|
|
||||||
|
_, error = service.save(custom)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
@router.delete("/api/custom/{custom_id}")
|
||||||
|
def delete_custom(request: Request, custom_id: str):
|
||||||
|
service = get_custom_service()
|
||||||
|
res, error = service.get(custom_id)
|
||||||
|
if not error.success or not res:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
error = service.delete(custom_id)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=custom_id)
|
||||||
|
|
||||||
|
@router.get("/api/customs")
|
||||||
|
def get_customs(request: Request, limit: int = Query(10, ge=1, le=1000), offset: int = Query(0, ge=0)):
|
||||||
|
service = get_custom_service()
|
||||||
|
res, error = service.list({'user_id': getattr(request.state, 'user_id', '')}, limit=limit, offset=offset)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
# custom对象转换为字典
|
||||||
|
customs = [custom.to_dict() for custom in res]
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=customs)
|
||||||
|
|
||||||
|
@router.get("/api/custom/{custom_id}")
|
||||||
|
def get_custom(request: Request, custom_id: str):
|
||||||
|
service = get_custom_service()
|
||||||
|
res, error = service.get(custom_id)
|
||||||
|
if not error.success or not res:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=res.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/custom/{custom_id}/image")
|
||||||
|
async def post_custom_image(request: Request, custom_id: str, image: UploadFile = File(...)):
|
||||||
|
# 检查 custom id 是否存在
|
||||||
|
service = get_custom_service()
|
||||||
|
custom, err = service.get(custom_id)
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
|
||||||
|
if custom.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
|
||||||
|
# 实现上传图片的处理
|
||||||
|
# 保存上传的图片文件
|
||||||
|
# 生成唯一的文件名
|
||||||
|
file_extension = os.path.splitext(image.filename)[1]
|
||||||
|
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||||
|
|
||||||
|
# 保存文件到对象存储
|
||||||
|
file_path = f"customs/{custom_id}/images/{unique_filename}"
|
||||||
|
obs_util = obs.get_instance()
|
||||||
|
await run_in_threadpool(obs_util.put, file_path, await image.read())
|
||||||
|
|
||||||
|
# 获取对象存储外链
|
||||||
|
obs_url = obs_util.get_link(file_path)
|
||||||
|
logging.info(f"obs_url: {obs_url}")
|
||||||
|
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=obs_url)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/api/custom/{custom_id}/image")
|
||||||
|
async def delete_custom_image(request: Request, custom_id: str, image_url: str):
|
||||||
|
# 检查 custom id 是否存在
|
||||||
|
service = get_custom_service()
|
||||||
|
custom, err = service.get(custom_id)
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
|
||||||
|
if custom.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
|
||||||
|
# 检查 image_url 是否是该 custom 名下的图片链接
|
||||||
|
obs_util = obs.get_instance()
|
||||||
|
obs_path, err = obs_util.get_obs_path_by_link(image_url)
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
if not obs_path.startswith(f"customs/{custom_id}/images/"):
|
||||||
|
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {custom_id} 名下的图片链接")
|
||||||
|
|
||||||
|
# 实现删除图片的处理
|
||||||
|
# 删除对象存储中的文件
|
||||||
|
err = obs_util.delete_by_link(image_url)
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
192
src/web/people.py
Normal file
192
src/web/people.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, Request, UploadFile, File, Query
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from services.people import get_instance as get_people_service
|
||||||
|
from models.people import People
|
||||||
|
from utils import obs
|
||||||
|
from utils.error import ErrorCode
|
||||||
|
from web.schemas import BaseResponse
|
||||||
|
|
||||||
|
router = APIRouter(tags=["people"])
|
||||||
|
|
||||||
|
|
||||||
|
class PostPeopleRequest(BaseModel):
|
||||||
|
people: dict
|
||||||
|
|
||||||
|
@router.post("/api/people")
|
||||||
|
async def post_people(request: Request, post_people_request: PostPeopleRequest):
|
||||||
|
logging.debug(f"post_people_request: {post_people_request}")
|
||||||
|
people = People.from_dict(post_people_request.people)
|
||||||
|
people.user_id = getattr(request.state, 'user_id', '')
|
||||||
|
service = get_people_service()
|
||||||
|
people.id, error = service.save(people)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=people.id)
|
||||||
|
|
||||||
|
@router.put("/api/people/{people_id}")
|
||||||
|
async def update_people(request: Request, people_id: str, post_people_request: PostPeopleRequest):
|
||||||
|
logging.debug(f"post_people_request: {post_people_request}")
|
||||||
|
people = People.from_dict(post_people_request.people)
|
||||||
|
people.id = people_id
|
||||||
|
service = get_people_service()
|
||||||
|
res, error = service.get(people_id)
|
||||||
|
if not error.success or not res:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
people.user_id = res.user_id
|
||||||
|
_, error = service.save(people)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
@router.delete("/api/people/{people_id}")
|
||||||
|
async def delete_people(request: Request, people_id: str):
|
||||||
|
service = get_people_service()
|
||||||
|
res, err = service.get(people_id)
|
||||||
|
if not err.success or not res:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
error = service.delete(people_id)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
class GetPeopleRequest(BaseModel):
|
||||||
|
query: Optional[str] = None
|
||||||
|
conds: Optional[dict] = None
|
||||||
|
top_k: int = 5
|
||||||
|
|
||||||
|
@router.get("/api/peoples")
|
||||||
|
async def get_peoples(
|
||||||
|
request: Request,
|
||||||
|
name: Optional[str] = Query(None, description="姓名"),
|
||||||
|
gender: Optional[str] = Query(None, description="性别"),
|
||||||
|
age: Optional[int] = Query(None, description="年龄"),
|
||||||
|
height: Optional[int] = Query(None, description="身高"),
|
||||||
|
marital_status: Optional[str] = Query(None, description="婚姻状态"),
|
||||||
|
limit: int = Query(10, description="分页大小"),
|
||||||
|
offset: int = Query(0, description="分页偏移量"),
|
||||||
|
):
|
||||||
|
|
||||||
|
# 解析查询参数为字典
|
||||||
|
conds = {}
|
||||||
|
conds["user_id"] = getattr(request.state, 'user_id', '')
|
||||||
|
if name:
|
||||||
|
conds["name"] = name
|
||||||
|
if gender:
|
||||||
|
conds["gender"] = gender
|
||||||
|
if age:
|
||||||
|
conds["age"] = age
|
||||||
|
if height:
|
||||||
|
conds["height"] = height
|
||||||
|
if marital_status:
|
||||||
|
conds["marital_status"] = marital_status
|
||||||
|
|
||||||
|
logging.info(f"conds: , limit: {limit}, offset: {offset}")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
service = get_people_service()
|
||||||
|
results, error = service.list(conds, limit=limit, offset=offset)
|
||||||
|
logging.info(f"query results: {results}")
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
peoples = [people.to_dict() for people in results]
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=peoples)
|
||||||
|
|
||||||
|
|
||||||
|
class RemarkRequest(BaseModel):
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/people/{people_id}/remark")
|
||||||
|
async def post_people_remark(request: Request, people_id: str, body: RemarkRequest):
|
||||||
|
service = get_people_service()
|
||||||
|
res, err = service.get(people_id)
|
||||||
|
if not err.success or not res:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
error = service.save_remark(people_id, body.content)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/api/people/{people_id}/remark")
|
||||||
|
async def delete_people_remark(request: Request, people_id: str):
|
||||||
|
service = get_people_service()
|
||||||
|
res, err = service.get(people_id)
|
||||||
|
if not err.success or not res:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
error = service.delete_remark(people_id)
|
||||||
|
if not error.success:
|
||||||
|
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/people/{people_id}/image")
|
||||||
|
async def post_people_image(request: Request, people_id: str, image: UploadFile = File(...)):
|
||||||
|
|
||||||
|
# 检查 people id 是否存在
|
||||||
|
service = get_people_service()
|
||||||
|
people, err = service.get(people_id)
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
|
||||||
|
if people.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
|
||||||
|
# 实现上传图片的处理
|
||||||
|
# 保存上传的图片文件
|
||||||
|
# 生成唯一的文件名
|
||||||
|
file_extension = os.path.splitext(image.filename)[1]
|
||||||
|
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||||
|
|
||||||
|
# 保存文件到对象存储
|
||||||
|
file_path = f"peoples/{people_id}/images/{unique_filename}"
|
||||||
|
obs_util = obs.get_instance()
|
||||||
|
await run_in_threadpool(obs_util.put, file_path, await image.read())
|
||||||
|
|
||||||
|
# 获取对象存储外链
|
||||||
|
obs_url = obs_util.get_link(file_path)
|
||||||
|
logging.info(f"obs_url: {obs_url}")
|
||||||
|
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=obs_url)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/api/people/{people_id}/image")
|
||||||
|
async def delete_people_image(request: Request, people_id: str, image_url: str):
|
||||||
|
# 检查 people id 是否存在
|
||||||
|
service = get_people_service()
|
||||||
|
people, err = service.get(people_id)
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
|
||||||
|
if people.user_id != getattr(request.state, 'user_id', ''):
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||||
|
|
||||||
|
# 检查 image_url 是否是该 people 名下的图片链接
|
||||||
|
obs_util = obs.get_instance()
|
||||||
|
obs_path, err = obs_util.get_obs_path_by_link(image_url)
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
if not obs_path.startswith(f"peoples/{people_id}/images/"):
|
||||||
|
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {people_id} 名下的图片链接")
|
||||||
|
|
||||||
|
# 实现删除图片的处理
|
||||||
|
# 删除对象存储中的文件
|
||||||
|
err = obs_util.delete_by_link(image_url)
|
||||||
|
if not err.success:
|
||||||
|
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||||
|
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
91
src/web/recognition.py
Normal file
91
src/web/recognition.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
|
||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, UploadFile, File
|
||||||
|
from fastapi.concurrency import run_in_threadpool
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from models.people import People
|
||||||
|
from models.custom import Custom
|
||||||
|
from agents.extract_people_agent import ExtractPeopleAgent
|
||||||
|
from agents.extract_custom_agent import ExtractCustomAgent
|
||||||
|
from utils import obs, ocr
|
||||||
|
from web.schemas import BaseResponse
|
||||||
|
from utils.error import ErrorCode
|
||||||
|
|
||||||
|
router = APIRouter(tags=["recognition"])
|
||||||
|
|
||||||
|
def extract_people(text: str, cover_link: str = None) -> People:
|
||||||
|
extra_agent = ExtractPeopleAgent()
|
||||||
|
people = extra_agent.extract_people_info(text)
|
||||||
|
if people:
|
||||||
|
people.cover = cover_link
|
||||||
|
logging.info(f"people: {people}")
|
||||||
|
return people
|
||||||
|
|
||||||
|
def extract_custom(text: str, image_link: str = None) -> Custom:
|
||||||
|
extra_agent = ExtractCustomAgent()
|
||||||
|
custom = extra_agent.extract_custom_info(text)
|
||||||
|
if custom:
|
||||||
|
if image_link:
|
||||||
|
custom.images = [image_link]
|
||||||
|
logging.info(f"custom: {custom}")
|
||||||
|
return custom
|
||||||
|
|
||||||
|
class PostInputRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
@router.post("/api/recognition/{model}/input")
|
||||||
|
async def post_recognition_input(model: str, request: PostInputRequest):
|
||||||
|
if model == "people":
|
||||||
|
result = await run_in_threadpool(extract_people, request.text)
|
||||||
|
elif model == "custom":
|
||||||
|
result = await run_in_threadpool(extract_custom, request.text)
|
||||||
|
else:
|
||||||
|
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
|
||||||
|
|
||||||
|
resp = BaseResponse(error_code=0, error_info="success")
|
||||||
|
resp.data = result.to_dict()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
@router.post("/api/recognition/{model}/image")
|
||||||
|
async def post_recognition_image(model: str, image: UploadFile = File(...)):
|
||||||
|
if model not in ["people", "custom"]:
|
||||||
|
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
|
||||||
|
|
||||||
|
# 实现上传图片的处理
|
||||||
|
# 保存上传的图片文件
|
||||||
|
# 生成唯一的文件名
|
||||||
|
file_extension = os.path.splitext(image.filename)[1]
|
||||||
|
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||||
|
|
||||||
|
# 保存文件到对象存储
|
||||||
|
file_path = f"uploads/{model}/{unique_filename}"
|
||||||
|
obs_util = obs.get_instance()
|
||||||
|
await run_in_threadpool(obs_util.put, file_path, await image.read())
|
||||||
|
|
||||||
|
# 获取对象存储外链
|
||||||
|
obs_url = obs_util.get_link(file_path)
|
||||||
|
logging.info(f"obs_url: {obs_url}")
|
||||||
|
|
||||||
|
# 调用OCR处理图片
|
||||||
|
ocr_util = ocr.get_instance()
|
||||||
|
ocr_result = await run_in_threadpool(ocr_util.recognize_image_text, obs_url)
|
||||||
|
logging.info(f"ocr_result: {ocr_result}")
|
||||||
|
|
||||||
|
if model == "people":
|
||||||
|
result = await run_in_threadpool(extract_people, ocr_result, obs_url)
|
||||||
|
elif model == "custom":
|
||||||
|
result = await run_in_threadpool(extract_custom, ocr_result, obs_url)
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
|
||||||
|
|
||||||
|
resp = BaseResponse(error_code=0, error_info="success")
|
||||||
|
resp.data = result.to_dict()
|
||||||
|
return resp
|
||||||
7
src/web/schemas.py
Normal file
7
src/web/schemas.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class BaseResponse(BaseModel):
|
||||||
|
error_code: int
|
||||||
|
error_info: str
|
||||||
|
data: Optional[Any] = None
|
||||||
223
src/web/user.py
Normal file
223
src/web/user.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Literal
|
||||||
|
from fastapi import APIRouter, Depends, Request, HTTPException, Response, UploadFile, File
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from services.user import get_instance as get_user_service
|
||||||
|
from web.auth import require_auth
|
||||||
|
from utils import obs
|
||||||
|
from utils.config import get_instance as get_config
|
||||||
|
from web.schemas import BaseResponse
|
||||||
|
|
||||||
|
router = APIRouter(tags=["user"])
|
||||||
|
|
||||||
|
class SendCodeRequest(BaseModel):
|
||||||
|
target_type: str
|
||||||
|
target: str
|
||||||
|
scene: Literal['register', 'update']
|
||||||
|
# scene: Literal['register', 'login']
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/api/user/send_code")
|
||||||
|
async def send_user_code(request: SendCodeRequest):
|
||||||
|
service = get_user_service()
|
||||||
|
err = service.send_code(request.target_type, request.target, request.scene)
|
||||||
|
if not err.success:
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
nickname: Optional[str] = None
|
||||||
|
avatar_link: Optional[str] = None
|
||||||
|
email: Optional[str] = None
|
||||||
|
phone: Optional[str] = None
|
||||||
|
password: str
|
||||||
|
code: str
|
||||||
|
|
||||||
|
@router.post("/api/user")
|
||||||
|
async def user_register(request: RegisterRequest):
|
||||||
|
service = get_user_service()
|
||||||
|
from models.user import User
|
||||||
|
u = User(
|
||||||
|
nickname=request.nickname or "",
|
||||||
|
avatar_link=request.avatar_link or "",
|
||||||
|
email=request.email or "",
|
||||||
|
phone=request.phone or "",
|
||||||
|
password_hash=request.password,
|
||||||
|
)
|
||||||
|
uid, err = service.register(u, request.code)
|
||||||
|
if not err.success:
|
||||||
|
logging.error(f"register failed: {err}")
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=uid)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
email: Optional[str] = None
|
||||||
|
phone: Optional[str] = None
|
||||||
|
password: str
|
||||||
|
|
||||||
|
@router.post("/api/user/login")
|
||||||
|
async def user_login(request: LoginRequest, response: Response):
|
||||||
|
service = get_user_service()
|
||||||
|
data, err = service.login(request.email, request.phone, request.password)
|
||||||
|
if not err.success:
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
conf = get_config()
|
||||||
|
ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30)
|
||||||
|
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
||||||
|
cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False)
|
||||||
|
cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None)
|
||||||
|
response.set_cookie(
|
||||||
|
key="token",
|
||||||
|
value=data.get('token', ''),
|
||||||
|
max_age=ttl_days * 24 * 3600,
|
||||||
|
httponly=True,
|
||||||
|
secure=cookie_secure,
|
||||||
|
samesite=cookie_samesite,
|
||||||
|
domain=cookie_domain,
|
||||||
|
path="/",
|
||||||
|
)
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')})
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/api/user/me/login", dependencies=[Depends(require_auth)])
|
||||||
|
async def user_logout(response: Response, request: Request):
|
||||||
|
service = get_user_service()
|
||||||
|
err = service.logout(getattr(request.state, 'token', None))
|
||||||
|
if not err.success:
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
conf = get_config()
|
||||||
|
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
||||||
|
response.delete_cookie(key="token", domain=cookie_domain, path="/")
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/api/user/me", dependencies=[Depends(require_auth)])
|
||||||
|
async def user_delete(response: Response, request: Request):
|
||||||
|
service = get_user_service()
|
||||||
|
err = service.delete_user(getattr(request.state, 'user_id', None))
|
||||||
|
if not err.success:
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
conf = get_config()
|
||||||
|
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
||||||
|
response.delete_cookie(key="token", domain=cookie_domain, path="/")
|
||||||
|
return BaseResponse(error_code=0, error_info="success")
|
||||||
|
|
||||||
|
@router.get("/api/user/me", dependencies=[Depends(require_auth)])
|
||||||
|
async def user_me(request: Request):
|
||||||
|
service = get_user_service()
|
||||||
|
user, err = service.get(getattr(request.state, 'user_id', None))
|
||||||
|
if not err.success or not user:
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
data = {
|
||||||
|
'nickname': user.nickname,
|
||||||
|
'avatar_link': user.avatar_link,
|
||||||
|
'phone': user.phone,
|
||||||
|
'email': user.email,
|
||||||
|
}
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=data)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateMeRequest(BaseModel):
|
||||||
|
nickname: Optional[str] = None
|
||||||
|
avatar_link: Optional[str] = None
|
||||||
|
phone: Optional[str] = None
|
||||||
|
email: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/api/user/me", dependencies=[Depends(require_auth)])
|
||||||
|
async def update_user_me(request: Request, body: UpdateMeRequest):
|
||||||
|
service = get_user_service()
|
||||||
|
user, err = service.update_profile(
|
||||||
|
getattr(request.state, 'user_id', None),
|
||||||
|
nickname=body.nickname,
|
||||||
|
avatar_link=body.avatar_link,
|
||||||
|
phone=body.phone,
|
||||||
|
email=body.email,
|
||||||
|
)
|
||||||
|
if not err.success:
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
data = {
|
||||||
|
'nickname': user.nickname,
|
||||||
|
'avatar_link': user.avatar_link,
|
||||||
|
'phone': user.phone,
|
||||||
|
'email': user.email,
|
||||||
|
}
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/api/user/me/avatar", dependencies=[Depends(require_auth)])
|
||||||
|
async def upload_avatar(request: Request, avatar: UploadFile = File(...)):
|
||||||
|
user_id = getattr(request.state, 'user_id', None)
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="unauthorized")
|
||||||
|
|
||||||
|
file_extension = os.path.splitext(avatar.filename)[1]
|
||||||
|
timestamp = int(time.time())
|
||||||
|
avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
obs_util = obs.get_instance()
|
||||||
|
obs_util.Put(avatar_path, await avatar.read())
|
||||||
|
avatar_url = obs_util.Link(avatar_path)
|
||||||
|
|
||||||
|
user_service = get_user_service()
|
||||||
|
_, err = user_service.update_profile(user_id, avatar_link=avatar_url)
|
||||||
|
if not err.success:
|
||||||
|
raise HTTPException(status_code=500, detail=err.info)
|
||||||
|
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url})
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"upload avatar failed: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="upload avatar failed")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdatePhoneRequest(BaseModel):
|
||||||
|
phone: str
|
||||||
|
code: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/api/user/me/phone", dependencies=[Depends(require_auth)])
|
||||||
|
async def update_user_phone(request: Request, body: UpdatePhoneRequest):
|
||||||
|
service = get_user_service()
|
||||||
|
user, err = service.update_phone_with_code(
|
||||||
|
getattr(request.state, 'user_id', None),
|
||||||
|
body.phone,
|
||||||
|
body.code,
|
||||||
|
)
|
||||||
|
if not err.success:
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
data = {
|
||||||
|
'nickname': user.nickname,
|
||||||
|
'avatar_link': user.avatar_link,
|
||||||
|
'phone': user.phone,
|
||||||
|
'email': user.email,
|
||||||
|
}
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=data)
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateEmailRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
code: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/api/user/me/email", dependencies=[Depends(require_auth)])
|
||||||
|
async def update_user_email(request: Request, body: UpdateEmailRequest):
|
||||||
|
service = get_user_service()
|
||||||
|
user, err = service.update_email_with_code(
|
||||||
|
getattr(request.state, 'user_id', None),
|
||||||
|
body.email,
|
||||||
|
body.code,
|
||||||
|
)
|
||||||
|
if not err.success:
|
||||||
|
raise HTTPException(status_code=400, detail=err.info)
|
||||||
|
data = {
|
||||||
|
'nickname': user.nickname,
|
||||||
|
'avatar_link': user.avatar_link,
|
||||||
|
'phone': user.phone,
|
||||||
|
'email': user.email,
|
||||||
|
}
|
||||||
|
return BaseResponse(error_code=0, error_info="success", data=data)
|
||||||
Reference in New Issue
Block a user