4 Commits

Author SHA1 Message Date
12757afda1 feat: support custom management
- add custom model and rldb model
- add service for custom to operate rldb
- add apis to CURD custom and image upload and delete
- support to recognize custom from text or image
- refactor web servcie start mode and api group
- - group the apis
- - support uvicorn start service in terminal
- - refactor recognization api for both people and custom
2025-12-18 23:50:52 +08:00
25fb6ba9ce feat: support upload image api
- support upload and delete image of people
- support uploads any image and get link after login
2025-11-25 20:46:50 +08:00
3840080074 feat: the people resource belong a user 2025-11-23 22:36:17 +08:00
af8fa03e59 feat: support multi tenant 2025-11-22 09:53:56 +08:00
19 changed files with 1831 additions and 172 deletions

View File

@@ -0,0 +1,85 @@
import datetime
import json
import logging
from langchain.prompts import ChatPromptTemplate
from .base_agent import BaseAgent
from models.custom import Custom
class ExtractCustomAgent(BaseAgent):
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
super().__init__(api_url, api_key, model_name)
self.prompt = ChatPromptTemplate.from_messages([
(
"system",
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')}"
"你是一个专业的客户信息录入助手,善于从一段文字描述中,精确获取客户的以下信息:\n"
"姓名 name\n"
"性别 gender (男/女/未知)\n"
"出生年份 birth (整数年份,如 1990若文本只提供了年龄请根据当前日期计算出出生年份)\n"
"手机号 phone\n"
"邮箱 email\n"
"身高(cm) height (整数)\n"
"体重(kg) weight (整数)\n"
"学历 degree\n"
"毕业院校 academy\n"
"职业 occupation\n"
"年收入(万) income (整数)\n"
"资产(万) assets (整数)\n"
"流动资产(万) current_assets (整数)\n"
"房产情况 house (必须为以下之一: '有房无贷', '有房有贷', '无自有房', 若未提及则不填)\n"
"车辆情况 car (必须为以下之一: '有车无贷', '有车有贷', '无自有车', 若未提及则不填)\n"
"户口城市 registered_city\n"
"居住城市 live_city\n"
"籍贯 native_place\n"
"原生家庭情况 original_family\n"
"是否独生子女 is_single_child (true/false)\n"
"择偶要求 match_requirement\n"
"\n"
"以上信息需要严格按照 JSON 格式输出,字段名与条目中英文保持一致。\n"
"若未识别到某项,则不返回该字段,不要自行填写“未知”、“未填写”等。\n"
"\n"
"除了上述基本信息,还有一个字段:\n"
"其他介绍 introductions\n"
"其余的信息需要按照字典的方式进行提炼和总结,都放在 introductions 字段中key 使用提炼好的中文。\n"
),
("human", "{input}")
])
def extract_custom_info(self, text: str) -> Custom:
"""从文本中提取客户信息"""
prompt = self.prompt.format_prompt(input=text)
response = self.llm.invoke(prompt)
logging.info(f"llm response: {response.content}")
try:
custom_dict = json.loads(response.content)
# 类型安全转换防止LLM返回字符串类型的数字
int_fields = ['birth', 'height', 'weight', 'income', 'assets', 'scores', 'current_assets']
for field in int_fields:
if field in custom_dict and isinstance(custom_dict[field], str):
try:
# 尝试提取数字,简单处理
import re
num = re.findall(r'\d+', custom_dict[field])
if num:
custom_dict[field] = int(num[0])
else:
del custom_dict[field] # 无法转换则移除
except:
del custom_dict[field]
custom = Custom.from_dict(custom_dict)
err = custom.validate()
if not err.success:
logging.warning(f"Validation warning: {err.info}")
# 即使校验失败(如某些必填项缺失),也尽可能返回已提取的对象,
# 让上层业务逻辑决定是否接受或需要补充
return custom
except json.JSONDecodeError:
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
return None
except Exception as e:
logging.error(f"Failed to process custom info: {e}")
return None

View File

@@ -1,13 +1,36 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-09-27
import os
import sys
# Add src directory to sys.path to ensure modules can be imported correctly when running with uvicorn
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import argparse
import uvicorn
from services import people as people_service
from utils import config, logger, obs, ocr, rldb
from services import user as user_service
from services import custom as custom_service
from utils import config, logger, obs, ocr, rldb, sms, mailer
from web.api import api
def initialize_app(config_path):
"""Initialize application components with the given config path."""
config.init(config_path)
conf = config.get_instance()
logger.init()
rldb.init()
ocr.init()
obs.init()
mailer.init(conf.get('mailer', 'type', fallback='real'))
sms.init(conf.get('sms', 'type', fallback='real'))
people_service.init()
user_service.init()
custom_service.init()
# 主函数
def main():
main_path = os.path.dirname(os.path.abspath(__file__))
@@ -15,21 +38,19 @@ def main():
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
args = parser.parse_args()
config.init(args.config)
logger.init()
rldb.init()
ocr.init()
obs.init()
people_service.init()
initialize_app(args.config)
conf = config.get_instance()
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
port = conf.getint('web_service', 'server_port', fallback=8099)
uvicorn.run(api, host=host, port=port)
uvicorn.run("src.main:api", host=host, port=port, reload=True) # Modified to string import for reload support in main too, though api object also works
if __name__ == "__main__":
main()
main()
else:
# Support for running via 'uvicorn src.main:api'
# Use environment variable for config path or default
main_path = os.path.dirname(os.path.abspath(__file__))
default_config_path = os.path.join(main_path, '../configuration/test_conf.ini')
config_path = os.environ.get('IFU_CONFIG_PATH', default_config_path)
initialize_app(config_path)

291
src/models/custom.py Normal file
View File

@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-11-27
import json
import logging
from typing import Dict, List
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, func, Boolean
from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class CustomRLDBModel(RLDBBaseModel):
"""
客户数据的数据库模型 (SQLAlchemy Model) - 更新版
"""
__tablename__ = 'customs'
id = Column(String(36), primary_key=True)
user_id = Column(String(36), index=True, nullable=False)
# 基本信息
name = Column(String(255), index=True, nullable=False)
gender = Column(String(10), nullable=False)
birth = Column(Integer, nullable=False) # 出生年份
phone = Column(String(50), index=True)
email = Column(String(255), index=True)
# 外貌信息
height = Column(Integer)
weight = Column(Integer)
images = Column(Text) # JSON string for list[str]
scores = Column(Integer)
# 学历职业
degree = Column(String(255))
academy = Column(String(255))
occupation = Column(String(255))
income = Column(Integer) # 单位:万
assets = Column(Integer) # 单位:万
current_assets = Column(Integer) # 单位:万
house = Column(String(50))
car = Column(String(50))
# 户口家庭
registered_city = Column(String(255))
live_city = Column(String(255))
native_place = Column(String(255))
original_family = Column(Text)
is_single_child = Column(Boolean, default=False)
match_requirement = Column(Text)
introductions = Column(Text) # JSON string for Dict[str, str]
# 客户信息
custom_level = Column(String(255))
comments = Column(Text) # JSON string for Dict[str, str]
is_public = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class Custom:
"""
客户数据的业务逻辑模型 (Business Logic Model) - 更新版
"""
id: str
user_id: str
# 基本信息
name: str
gender: str
birth: int
phone: str
email: str
# 外貌信息
height: int
weight: int
images: List[str]
scores: int
# 学历职业
degree: str
academy: str
occupation: str
income: int
assets: int
current_assets: int
house: str
car: str
# 户口家庭
registered_city: str
live_city: str
native_place: str
original_family: str
is_single_child: bool
match_requirement: str
introductions: Dict[str, str]
# 客户信息
custom_level: str
comments: Dict[str, str]
is_public: bool
created_at: datetime = None
def __init__(self, **kwargs):
# 初始化所有属性
self.id = kwargs.get('id', '')
self.user_id = kwargs.get('user_id', '')
self.name = kwargs.get('name', '')
self.gender = kwargs.get('gender', '未知')
self.birth = kwargs.get('birth', 0)
self.phone = kwargs.get('phone', '')
self.email = kwargs.get('email', '')
self.height = kwargs.get('height', 0)
self.weight = kwargs.get('weight', 0)
self.images = kwargs.get('images', [])
self.scores = kwargs.get('scores', 0)
self.degree = kwargs.get('degree', '')
self.academy = kwargs.get('academy', '')
self.occupation = kwargs.get('occupation', '')
self.income = kwargs.get('income', 0)
self.assets = kwargs.get('assets', 0)
self.current_assets = kwargs.get('current_assets', 0)
self.house = kwargs.get('house', '')
self.car = kwargs.get('car', '')
self.registered_city = kwargs.get('registered_city', '')
self.live_city = kwargs.get('live_city', '')
self.native_place = kwargs.get('native_place', '')
self.original_family = kwargs.get('original_family', '')
self.is_single_child = kwargs.get('is_single_child', False)
self.match_requirement = kwargs.get('match_requirement', '')
self.introductions = kwargs.get('introductions', {})
self.custom_level = kwargs.get('custom_level', '')
self.comments = kwargs.get('comments', {})
self.is_public = kwargs.get('is_public', False)
self.created_at = kwargs.get('created_at')
def __str__(self) -> str:
# 返回对象的字符串表示
attributes = ", ".join(f"{k}={v}" for k, v in self.to_dict().items())
return f"Custom({attributes})"
@classmethod
def from_dict(cls, data: dict):
# 从字典创建对象实例
if 'created_at' in data and data['created_at'] is not None:
data['created_at'] = datetime.fromtimestamp(data['created_at'])
# 移除ORM特有的时间戳避免初始化错误
data.pop('updated_at', None)
data.pop('deleted_at', None)
return cls(**data)
def to_dict(self) -> dict:
# 将对象转换为字典
return {
'id': self.id,
'user_id': self.user_id,
'name': self.name,
'gender': self.gender,
'birth': self.birth,
'phone': self.phone,
'email': self.email,
'height': self.height,
'weight': self.weight,
'images': self.images,
'scores': self.scores,
'degree': self.degree,
'academy': self.academy,
'occupation': self.occupation,
'income': self.income,
'assets': self.assets,
'current_assets': self.current_assets,
'house': self.house,
'car': self.car,
'registered_city': self.registered_city,
'live_city': self.live_city,
'native_place': self.native_place,
'original_family': self.original_family,
'is_single_child': self.is_single_child,
'match_requirement': self.match_requirement,
'introductions': self.introductions,
'custom_level': self.custom_level,
'comments': self.comments,
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
}
@classmethod
def from_rldb_model(cls, data: CustomRLDBModel):
# 从数据库模型转换
return cls(
id=data.id,
user_id=data.user_id,
name=data.name,
gender=data.gender,
birth=data.birth,
phone=data.phone,
email=data.email,
height=data.height,
weight=data.weight,
images=json.loads(data.images) if data.images else [],
scores=data.scores,
degree=data.degree,
academy=data.academy,
occupation=data.occupation,
income=data.income,
assets=data.assets,
house=data.house,
car=data.car,
registered_city=data.registered_city,
live_city=data.live_city,
native_place=data.native_place,
original_family=data.original_family,
is_single_child=data.is_single_child,
match_requirement=data.match_requirement,
introductions=json.loads(data.introductions) if data.introductions else {},
custom_level=data.custom_level,
comments=json.loads(data.comments) if data.comments else {},
is_public=data.is_public,
created_at=data.created_at,
)
def to_rldb_model(self) -> CustomRLDBModel:
# 转换为数据库模型
return CustomRLDBModel(
id=self.id,
user_id=self.user_id,
name=self.name,
gender=self.gender,
birth=self.birth,
phone=self.phone,
email=self.email,
height=self.height,
weight=self.weight,
images=json.dumps(self.images, ensure_ascii=False),
scores=self.scores,
degree=self.degree,
academy=self.academy,
occupation=self.occupation,
income=self.income,
assets=self.assets,
current_assets=self.current_assets,
house=self.house,
car=self.car,
registered_city=self.registered_city,
live_city=self.live_city,
native_place=self.native_place,
original_family=self.original_family,
is_single_child=self.is_single_child,
match_requirement=self.match_requirement,
introductions=json.dumps(self.introductions, ensure_ascii=False),
custom_level=self.custom_level,
comments=json.dumps(self.comments, ensure_ascii=False),
is_public=self.is_public,
)
def validate(self) -> error:
# 数据校验逻辑
if not self.name:
return error(ErrorCode.INVALID_PARAMS, "Name cannot be empty.")
if not self.gender:
return error(ErrorCode.INVALID_PARAMS, "Gender cannot be empty.")
if self.gender not in ['', '']:
return error(ErrorCode.INVALID_PARAMS, "Gender must be '' or ''.")
current_year = datetime.now().year
min_birth_year = 1950
max_birth_year = current_year - 18
if not isinstance(self.birth, int):
return error(ErrorCode.INVALID_PARAMS, "Birth year must be an integer.")
if self.birth < min_birth_year or self.birth > max_birth_year:
return error(ErrorCode.INVALID_PARAMS, f"Birth year must be between {min_birth_year} and {max_birth_year}.")
valid_houses = ["", "有房无贷", "有房有贷", "无自有房"]
if self.house not in valid_houses:
return error(ErrorCode.INVALID_PARAMS, f"House must be one of {valid_houses}")
valid_cars = ["", "有车无贷", "有车有贷", "无自有车"]
if self.car not in valid_cars:
return error(ErrorCode.INVALID_PARAMS, f"Car must be one of {valid_cars}")
# ... 可根据需要添加更多校验 ...
return error(ErrorCode.SUCCESS, "")

View File

@@ -12,6 +12,7 @@ from utils.error import ErrorCode, error
class PeopleRLDBModel(RLDBBaseModel):
__tablename__ = 'peoples'
id = Column(String(36), primary_key=True)
user_id = Column(String(36), index=True)
name = Column(String(255), index=True)
contact = Column(String(255), index=True)
gender = Column(String(10))
@@ -61,6 +62,8 @@ class Comment:
class People:
# 数据库 ID
id: str
# 所属用户 ID
user_id: str
# 姓名
name: str
# 联系人
@@ -87,6 +90,7 @@ class People:
def __init__(self, **kwargs):
# 初始化所有属性从kwargs中获取值如果不存在则设置默认值
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.user_id = kwargs.get('user_id', '') if kwargs.get('user_id', '') is not None else ''
self.name = kwargs.get('name', '') if kwargs.get('name', '') is not None else ''
self.contact = kwargs.get('contact', '') if kwargs.get('contact', '') is not None else ''
self.gender = kwargs.get('gender', '') if kwargs.get('gender', '') is not None else ''
@@ -121,6 +125,7 @@ class People:
# 将关系数据库模型转换为对象
return cls(
id=data.id,
user_id=data.user_id,
name=data.name,
contact=data.contact,
gender=data.gender,
@@ -138,6 +143,7 @@ class People:
# 将对象转换为字典格式
return {
'id': self.id,
'user_id': self.user_id,
'name': self.name,
'contact': self.contact,
'gender': self.gender,
@@ -155,6 +161,7 @@ class People:
# 将对象转换为关系数据库模型
return PeopleRLDBModel(
id=self.id,
user_id=self.user_id,
name=self.name,
contact=self.contact,
gender=self.gender,

184
src/models/user.py Normal file
View File

@@ -0,0 +1,184 @@
from typing import Optional
import json
from datetime import datetime, timedelta
from sqlalchemy import Column, String, Text, DateTime, Integer, Boolean, func, UniqueConstraint
from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class UserRLDBModel(RLDBBaseModel):
__tablename__ = 'users'
id = Column(String(36), primary_key=True)
nickname = Column(String(255))
avatar_link = Column(String(255))
email = Column(String(127), unique=True, index=True)
phone = Column(String(32), unique=True, index=True)
password_hash = Column(String(255))
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class VerificationCodeRLDBModel(RLDBBaseModel):
__tablename__ = 'verification_codes'
id = Column(String(36), primary_key=True)
target_type = Column(String(16))
target = Column(String(255), index=True)
code = Column(String(16))
scene = Column(String(32))
expires_at = Column(DateTime(timezone=True))
used_at = Column(DateTime(timezone=True), nullable=True)
class UserTokenRLDBModel(RLDBBaseModel):
__tablename__ = 'user_tokens'
id = Column(String(36), primary_key=True)
user_id = Column(String(36), index=True)
token = Column(Text)
expired_at = Column(DateTime(timezone=True))
revoked = Column(Boolean, default=False)
class User:
id: str
nickname: str
avatar_link: str
email: str
phone: str
password_hash: str
created_at: datetime = None
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.nickname = kwargs.get('nickname', '') if kwargs.get('nickname', '') is not None else ''
self.avatar_link = kwargs.get('avatar_link', '') if kwargs.get('avatar_link', '') is not None else ''
self.email = kwargs.get('email', '') if kwargs.get('email', '') is not None else ''
self.phone = kwargs.get('phone', '') if kwargs.get('phone', '') is not None else ''
self.password_hash = kwargs.get('password_hash', '') if kwargs.get('password_hash', '') is not None else ''
self.created_at = kwargs.get('created_at', None)
def __str__(self) -> str:
return (f"User(id={self.id}, nickname={self.nickname}, avatar_link={self.avatar_link}, "
f"email={self.email}, phone={self.phone}, created_at={self.created_at})")
@classmethod
def from_dict(cls, data: dict):
if 'updated_at' in data:
del data['updated_at']
if 'deleted_at' in data:
del data['deleted_at']
return cls(**data)
@classmethod
def from_rldb_model(cls, data: UserRLDBModel):
return cls(
id=data.id,
nickname=data.nickname,
avatar_link=data.avatar_link,
email=data.email,
phone=data.phone,
password_hash=data.password_hash,
created_at=data.created_at,
)
def to_dict(self) -> dict:
return {
'id': self.id,
'nickname': self.nickname,
'avatar_link': self.avatar_link,
'email': self.email,
'phone': self.phone,
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
}
def to_rldb_model(self) -> UserRLDBModel:
return UserRLDBModel(
id=self.id,
nickname=self.nickname,
avatar_link=self.avatar_link,
email=self.email,
phone=self.phone,
password_hash=self.password_hash,
)
def validate(self) -> error:
err = error(ErrorCode.SUCCESS, "")
if not self.email and not self.phone:
return error(ErrorCode.MODEL_ERROR, "email or phone required")
return err
class VerificationCode:
id: str
target_type: str
target: str
code: str
scene: str
expires_at: datetime
used_at: Optional[datetime] = None
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.target_type = kwargs.get('target_type', '')
self.target = kwargs.get('target', '')
self.code = kwargs.get('code', '')
self.scene = kwargs.get('scene', '')
self.expires_at = kwargs.get('expires_at')
self.used_at = kwargs.get('used_at', None)
@classmethod
def from_rldb_model(cls, data: VerificationCodeRLDBModel):
return cls(
id=data.id,
target_type=data.target_type,
target=data.target,
code=data.code,
scene=data.scene,
expires_at=data.expires_at,
used_at=data.used_at,
)
def to_rldb_model(self) -> VerificationCodeRLDBModel:
return VerificationCodeRLDBModel(
id=self.id,
target_type=self.target_type,
target=self.target,
code=self.code,
scene=self.scene,
expires_at=self.expires_at,
used_at=self.used_at,
)
class UserToken:
id: str
user_id: str
token: str
expired_at: datetime
revoked: bool
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.user_id = kwargs.get('user_id', '')
self.token = kwargs.get('token', '')
self.expired_at = kwargs.get('expired_at')
self.revoked = kwargs.get('revoked', False)
@classmethod
def from_rldb_model(cls, data: UserTokenRLDBModel):
return cls(
id=data.id,
user_id=data.user_id,
token=data.token,
expired_at=data.expired_at,
revoked=data.revoked,
)
def to_rldb_model(self) -> UserTokenRLDBModel:
return UserTokenRLDBModel(
id=self.id,
user_id=self.user_id,
token=self.token,
expired_at=self.expired_at,
revoked=self.revoked,
)

98
src/services/custom.py Normal file
View File

@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-11-27
import logging
import uuid
from models.custom import Custom, CustomRLDBModel
from utils.error import ErrorCode, error
from utils import rldb
class CustomService:
def __init__(self):
self.rldb = rldb.get_instance()
def save(self, custom: Custom) -> (str, error):
"""
保存客户到数据库。
如果 custom.id 存在,则更新;否则,创建。
:param custom: 客户对象
:return: 客户ID 和 错误对象
"""
# 0. 生成 custom id
custom.id = custom.id if custom.id else uuid.uuid4().hex
# 1. 转换模型,并保存到 SQL 数据库
try:
custom_orm = custom.to_rldb_model()
self.rldb.upsert(custom_orm)
return custom.id, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to save custom {custom.id}: {e}")
return "", error(ErrorCode.RLDB_ERROR, f"Failed to save custom data: {str(e)}")
def delete(self, custom_id: str) -> error:
"""
从数据库删除客户。
:param custom_id: 客户ID
:return: 错误对象
"""
try:
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
if not custom_orm:
return error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
self.rldb.delete(custom_orm)
return error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to delete custom {custom_id}: {e}")
return error(ErrorCode.RLDB_ERROR, f"Failed to delete custom data: {str(e)}")
def get(self, custom_id: str) -> (Custom, error):
"""
从数据库获取单个客户。
:param custom_id: 客户ID
:return: 客户对象 和 错误对象
"""
try:
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
if not custom_orm:
return None, error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
custom = Custom.from_rldb_model(custom_orm)
return custom, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to get custom {custom_id}: {e}")
return None, error(ErrorCode.RLDB_ERROR, f"Failed to retrieve custom data: {str(e)}")
def list(self, conds: dict = None, limit: int = 10, offset: int = 0) -> (list[Custom], error):
"""
根据条件从数据库列出客户(支持分页)。
:param conds: 查询条件字典
:param limit: 每页数量
:param offset: 偏移量
:return: 客户对象列表 和 错误对象
"""
if conds is None:
conds = {}
try:
custom_orms = self.rldb.query(CustomRLDBModel, limit=limit, offset=offset, **conds)
customs = [Custom.from_rldb_model(orm) for orm in custom_orms]
return customs, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to list customs with conds {conds}: {e}")
return [], error(ErrorCode.RLDB_ERROR, f"Failed to list custom data: {str(e)}")
# --- Singleton Pattern ---
custom_service = None
def init():
"""初始化 CustomService 单例"""
global custom_service
custom_service = CustomService()
def get_instance() -> CustomService:
"""获取 CustomService 单例"""
return custom_service

220
src/services/user.py Normal file
View File

@@ -0,0 +1,220 @@
import uuid
import hmac
import base64
import os
from datetime import datetime, timedelta
from typing import Optional
from utils.error import ErrorCode, error
from utils import rldb, mailer, sms, config
from models.user import (
User,
UserRLDBModel,
VerificationCode,
VerificationCodeRLDBModel,
UserToken,
UserTokenRLDBModel,
)
class UserService:
def __init__(self):
self.rldb = rldb.get_instance()
self.mailer = mailer.get_instance()
self.sms = sms.get_instance()
self.conf = config.get_instance()
def _hash_password(self, password: str, salt: Optional[str] = None) -> str:
salt = salt if salt else base64.urlsafe_b64encode(os.urandom(16)).decode('utf-8')
digest = hmac.new(salt.encode('utf-8'), password.encode('utf-8'), 'sha256').digest()
return f"{salt}:{base64.urlsafe_b64encode(digest).decode('utf-8')}"
def _verify_password(self, password: str, password_hash: str) -> bool:
parts = password_hash.split(':')
if len(parts) != 2:
return False
salt = parts[0]
return self._hash_password(password, salt) == password_hash
def send_code(self, target_type: str, target: str, scene: str) -> error:
scens = {
"register": "注册",
"update": "信息更新",
# "login": "登录",
}
if scene not in scens:
return error(ErrorCode.MODEL_ERROR, f'scene {scene} not supported')
scene_name = scens.get(scene, scene)
code = f"{uuid.uuid4().int % 1000000:06d}"
expires = datetime.now() + timedelta(minutes=10)
vc = VerificationCode(
id=uuid.uuid4().hex,
target_type=target_type,
target=target,
code=code,
scene=scene,
expires_at=expires,
)
self.rldb.upsert(vc.to_rldb_model())
content = f"IF.U服务{scene_name}验证码: {code}, 10分钟内有效"
sent = True
if target_type == 'email':
sent = self.mailer.send(target, f'IF.U服务{scene_name}验证码', content) if self.mailer else False
elif target_type == 'phone':
sent = self.sms.send(target, content) if self.sms else False
if not sent:
return error(ErrorCode.RLDB_ERROR, 'send code failed')
return error(ErrorCode.SUCCESS, '')
def _get_user_by_identifier(self, email: Optional[str], phone: Optional[str]) -> Optional[User]:
if email:
users = self.rldb.query(UserRLDBModel, email=email, limit=1)
if users:
return User.from_rldb_model(users[0])
if phone:
users = self.rldb.query(UserRLDBModel, phone=phone, limit=1)
if users:
return User.from_rldb_model(users[0])
return None
def register(self, user: User, code: str) -> (str, error):
if not user.email and not user.phone:
return '', error(ErrorCode.MODEL_ERROR, 'email or phone required')
existed = self._get_user_by_identifier(user.email, user.phone)
if existed:
return '', error(ErrorCode.MODEL_ERROR, 'user existed')
target_type = 'phone' if user.phone else 'email'
target = user.phone if user.phone else user.email
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type=target_type,
target=target,
scene='register',
limit=1,
)
if not vc_list:
return '', error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return '', error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
user.id = uuid.uuid4().hex
hashed = self._hash_password(user.password_hash)
user.password_hash = hashed
self.rldb.upsert(user.to_rldb_model())
return user.id, error(ErrorCode.SUCCESS, '')
def login(self, email: Optional[str], phone: Optional[str], password: str) -> (dict, error):
u = self._get_user_by_identifier(email, phone)
if not u:
return {}, error(ErrorCode.MODEL_ERROR, 'user not found')
if not self._verify_password(password, u.password_hash):
return {}, error(ErrorCode.MODEL_ERROR, 'invalid password')
ttl_days = self.conf.getint('auth', 'token_ttl_days', fallback=30)
expired_at = datetime.now() + timedelta(days=ttl_days)
token_raw = f"{u.id}.{uuid.uuid4().hex}.{int(expired_at.timestamp())}"
secret = self.conf.get('auth', 'jwt_secret', fallback='dev-secret')
signature = hmac.new(secret.encode('utf-8'), token_raw.encode('utf-8'), 'sha256').digest()
token = base64.urlsafe_b64encode(token_raw.encode('utf-8')).decode('utf-8') + '.' + base64.urlsafe_b64encode(signature).decode('utf-8')
ut = UserToken(id=uuid.uuid4().hex, user_id=u.id, token=token, expired_at=expired_at, revoked=False)
self.rldb.upsert(ut.to_rldb_model())
return {'token': token, 'expired_at': int(expired_at.timestamp())}, error(ErrorCode.SUCCESS, '')
def logout(self, token: str) -> error:
tokens = self.rldb.query(UserTokenRLDBModel, token=token, limit=1)
if not tokens:
return error(ErrorCode.MODEL_ERROR, 'token not found')
t = tokens[0]
t.revoked = True
self.rldb.upsert(t)
return error(ErrorCode.SUCCESS, '')
def delete_user(self, user_id: str) -> error:
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return error(ErrorCode.MODEL_ERROR, 'user not found')
self.rldb.delete(u)
return error(ErrorCode.SUCCESS, '')
def get(self, user_id: str) -> (User, error):
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return None, error(ErrorCode.MODEL_ERROR, 'user not found')
return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '')
def update_profile(self, user_id: str, nickname: str = None, avatar_link: str = None, phone: str = None, email: str = None) -> (User, error):
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return None, error(ErrorCode.MODEL_ERROR, 'user not found')
has_email = bool(u.email)
has_phone = bool(u.phone)
if nickname is not None:
u.nickname = nickname
if avatar_link is not None:
u.avatar_link = avatar_link
if email is not None:
new_email = email
if has_email:
if not has_phone:
return None, error(ErrorCode.MODEL_ERROR, 'email update requires phone exists')
conflicts = self.rldb.query(UserRLDBModel, email=new_email, limit=1)
if conflicts and conflicts[0].id != user_id:
return None, error(ErrorCode.MODEL_ERROR, 'email existed')
u.email = new_email
if phone is not None:
new_phone = phone
if has_phone:
if not has_email:
return None, error(ErrorCode.MODEL_ERROR, 'phone update requires email exists')
conflicts = self.rldb.query(UserRLDBModel, phone=new_phone, limit=1)
if conflicts and conflicts[0].id != user_id:
return None, error(ErrorCode.MODEL_ERROR, 'phone existed')
u.phone = new_phone
self.rldb.upsert(u)
return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '')
def update_phone_with_code(self, user_id: str, new_phone: str, code: str) -> (User, error):
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type='phone',
target=new_phone,
scene='update',
limit=1,
)
if not vc_list:
return None, error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return None, error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
return self.update_profile(user_id, phone=new_phone)
def update_email_with_code(self, user_id: str, new_email: str, code: str) -> (User, error):
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type='email',
target=new_email,
scene='update',
limit=1,
)
if not vc_list:
return None, error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return None, error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
return self.update_profile(user_id, email=new_email)
user_service = None
def init():
global user_service
user_service = UserService()
def get_instance() -> UserService:
return user_service

View File

@@ -7,6 +7,10 @@ class ErrorCode(Enum):
SUCCESS = 0
MODEL_ERROR = 1000
RLDB_ERROR = 2100
RLDB_NOT_FOUND = 2101
OBS_ERROR = 3100
OBS_INPUT_ERROR = 3102
OBS_SERVICE_ERROR = 3103
class error(Protocol):
_error_code: int = 0
@@ -15,7 +19,6 @@ class error(Protocol):
def __init__(self, error_code: ErrorCode, error_info: str):
self._error_code = int(error_code.value)
self._error_info = error_info
logging.info(f"errorcode: {type(self._error_code)}")
def __str__(self) -> str:
return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"

60
src/utils/mailer.py Normal file
View File

@@ -0,0 +1,60 @@
import logging
import smtplib
from email.mime.text import MIMEText
from typing import Protocol
from .config import get_instance as get_config
class Mailer(Protocol):
def send(self, to_email: str, subject: str, content: str) -> bool:
...
class FakeMailer:
def __init__(self) -> None:
conf = get_config()
self.fake_message = conf.get('fake_mailer', 'message', fallback="FakeEmail")
def send(self, to_email: str, subject: str, content: str) -> bool:
logging.info(f"{self.fake_message}: to_email={to_email}, subject={subject}, content={content}")
return True
class RealMailer:
def __init__(self):
conf = get_config()
self.smtp_host = conf.get('real_mailer', 'smtp_host', fallback=None)
self.smtp_port = conf.getint('real_mailer', 'smtp_port', fallback=587)
self.smtp_user = conf.get('real_mailer', 'smtp_user', fallback=None)
self.smtp_pass = conf.get('real_mailer', 'smtp_pass', fallback=None)
self.from_email = conf.get('real_mailer', 'from_email', fallback=self.smtp_user)
def send(self, to_email: str, subject: str, content: str) -> bool:
if not self.smtp_host or not self.smtp_user or not self.smtp_pass:
return False
msg = MIMEText(content, 'plain', 'utf-8')
msg['Subject'] = subject
msg['From'] = self.from_email
msg['To'] = to_email
try:
server = smtplib.SMTP(self.smtp_host, self.smtp_port)
server.starttls()
server.login(self.smtp_user, self.smtp_pass)
server.sendmail(self.from_email, [to_email], msg.as_string())
server.quit()
return True
except Exception:
return False
_mailer: Mailer = None
def init(type: str = 'real'):
global _mailer
if type == 'real':
_mailer = RealMailer()
else:
_mailer = FakeMailer()
def get_instance() -> Mailer:
return _mailer

View File

@@ -4,11 +4,13 @@ import logging
from typing import Protocol
import qiniu
import requests
from .error import ErrorCode, error
from .config import get_instance as get_config
class OBS(Protocol):
def Put(self, obs_path: str, content: bytes) -> str:
def put(self, obs_path: str, content: bytes) -> str:
"""
上传文件到OBS
@@ -21,7 +23,7 @@ class OBS(Protocol):
"""
...
def Get(self, obs_path: str) -> bytes:
def get(self, obs_path: str) -> bytes:
"""
从OBS下载文件
@@ -33,7 +35,7 @@ class OBS(Protocol):
"""
...
def List(self, obs_path: str) -> list:
def list(self, obs_path: str) -> list:
"""
列出OBS目录下的所有文件
@@ -45,7 +47,7 @@ class OBS(Protocol):
"""
...
def Del(self, obs_path: str) -> bool:
def delete(self, obs_path: str) -> error:
"""
删除OBS文件
@@ -57,7 +59,7 @@ class OBS(Protocol):
"""
...
def Link(self, obs_path: str) -> str:
def get_link(self, obs_path: str) -> str:
"""
获取OBS文件链接
@@ -68,6 +70,31 @@ class OBS(Protocol):
str: OBS文件链接
"""
...
def delete_by_link(self, obs_link: str) -> error:
"""
根据OBS文件链接删除文件
Args:
obs_link (str): OBS文件链接
Returns:
bool: 是否删除成功
"""
...
def get_obs_path_by_link(self, obs_link: str) -> (str, error):
"""
从OBS文件链接获取OBS路径
Args:
obs_link (str): OBS文件链接
Returns:
str: OBS文件路径
error: 错误信息
"""
...
class Koodo:
@@ -82,7 +109,7 @@ class Koodo:
self.bucket = qiniu.BucketManager(self.auth)
pass
def Put(self, obs_path: str, content: bytes) -> str:
def put(self, obs_path: str, content: bytes) -> str:
"""
上传文件到OBS
@@ -103,7 +130,7 @@ class Koodo:
logging.info(f"文件 {obs_path} 上传成功, OBS路径: {full_path}")
return f"{self.outer_domain}/{full_path}"
def Get(self, obs_path: str) -> bytes:
def get(self, obs_path: str) -> bytes:
"""
从OBS下载文件
@@ -121,7 +148,7 @@ class Koodo:
return None
return resp.content
def List(self, prefix: str = "") -> list[str]:
def list(self, prefix: str = "") -> list[str]:
"""
列出OBS目录下的所有文件
@@ -143,7 +170,7 @@ class Koodo:
# logging.debug(f"info: {info}")
return keys
def Del(self, obs_path: str) -> bool:
def delete(self, obs_path: str) -> error:
"""
删除OBS文件
@@ -151,17 +178,17 @@ class Koodo:
obs_path (str): OBS文件路径
Returns:
bool: 是否删除成功
error: 删除结果
"""
ret, info = self.bucket.delete(self.bucket_name, f"{self.prefix_path}{obs_path}")
logging.debug(f"文件 {obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
logging.debug(f"文件 {self.prefix_path}{obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
if ret is None or info.status_code != 200:
logging.error(f"文件 {obs_path} 删除 OBS 失败, 错误信息: {info.text_body}")
return False
return error(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {self.prefix_path}{obs_path} 删除 OBS 失败, 错误信息: {info.text_body}")
logging.info(f"文件 {obs_path} 删除 OBS 成功")
return True
return error(error_code=ErrorCode.SUCCESS, error_info="success")
def Link(self, obs_path: str) -> str:
def get_link(self, obs_path: str) -> str:
"""
获取OBS文件链接
@@ -173,6 +200,38 @@ class Koodo:
"""
return f"{self.outer_domain}/{self.prefix_path}{obs_path}"
def delete_by_link(self, obs_link: str) -> error:
"""
根据OBS文件链接删除文件
Args:
obs_link (str): OBS文件链接
Returns:
error: 删除结果
"""
obs_path, err = self.get_obs_path_by_link(obs_link)
if not err.success:
return err
return self.delete(obs_path)
def get_obs_path_by_link(self, obs_link: str) -> (str, error):
"""
从OBS文件链接获取OBS路径
Args:
obs_link (str): OBS文件链接
Returns:
str: OBS文件路径
error: 错误信息
"""
if not obs_link.startswith(f"{self.outer_domain}/{self.prefix_path}"):
logging.error(f"文件 {obs_link} 不是 OBS 文件链接")
return "", error(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {obs_link} 不是 OBS 文件链接")
obs_path = obs_link[len(self.outer_domain) + len(self.prefix_path) + 1:]
return obs_path, error(error_code=ErrorCode.SUCCESS, error_info="success")
_obs_instance: OBS = None
@@ -213,8 +272,8 @@ if __name__ == "__main__":
# print(f"文件 {obs_path} 链接: {link}")
# 列出OBS目录下的所有文件
keys = obs.List("")
keys = obs.list("")
print(f"OBS 目录下的所有文件: {keys}")
for key in keys:
link = obs.Del(key)
link = obs.delete(key)
print(f"文件 {key} 删除 OBS 成功: {link}")

View File

@@ -1,4 +1,5 @@
from re import S
from typing import Protocol
import uuid
from sqlalchemy import Column, DateTime, String, create_engine, func

51
src/utils/sms.py Normal file
View File

@@ -0,0 +1,51 @@
import logging
from typing import Protocol
import requests
from .config import get_instance as get_config
class SMS(Protocol):
def send(self, phone: str, content: str) -> bool:
...
class FakeSMS:
def __init__(self) -> None:
conf = get_config()
self.fake_message = conf.get('fake_sms', 'message', fallback="FakeSMS")
def send(self, phone: str, content: str) -> bool:
logging.info(f"{self.fake_message}: phone={phone}, content={content}")
return True
class RealSMS:
def __init__(self):
conf = get_config()
self.webhook_url = conf.get('real_sms', 'webhook_url', fallback=None)
self.webhook_token = conf.get('real_sms', 'webhook_token', fallback=None)
def send(self, phone: str, content: str) -> bool:
if not self.webhook_url:
return False
try:
headers = {'Authorization': f'Bearer {self.webhook_token}'} if self.webhook_token else {}
data = {'phone': phone, 'content': content}
resp = requests.post(self.webhook_url, json=data, headers=headers, timeout=5)
return resp.status_code >= 200 and resp.status_code < 300
except Exception:
return False
_sms: SMS = None
def init(type: str = 'real'):
global _sms
if type == 'real':
_sms = RealSMS()
else:
_sms = FakeSMS()
def get_instance() -> SMS:
return _sms

View File

@@ -1,173 +1,60 @@
import os
import uuid
import logging
from typing import Any, Optional
from fastapi import FastAPI, UploadFile, File, Query
from pydantic import BaseModel
from fastapi import FastAPI, UploadFile, File, APIRouter, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from services.people import get_instance as get_people_service
from models.people import People
from agents.extract_people_agent import ExtractPeopleAgent
from utils import obs, ocr
from web.auth import require_auth
from utils import obs
from web.schemas import BaseResponse
from web.custom import router as custom_router
from web.people import router as people_router
from web.user import router as user_router
from web.recognition import router as recognition_router
api = FastAPI(title="Single People Management and Searching", version="0.1")
api.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=["https://localhost:5173", "https://ifu.mamamiyear.site"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class BaseResponse(BaseModel):
error_code: int
error_info: str
data: Optional[Any] = None
authorized_router = APIRouter(dependencies=[Depends(require_auth)])
@api.post("/ping")
@api.post("/api/ping")
async def ping():
return BaseResponse(error_code=0, error_info="success")
class PostInputRequest(BaseModel):
text: str
@api.post("/recognition/input")
async def post_input(request: PostInputRequest):
people = extract_people(request.text)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
@api.post("/recognition/image")
async def post_input_image(image: UploadFile = File(...)):
@authorized_router.post("/api/upload/image")
async def post_upload_image(image: UploadFile = File(...)):
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 确保uploads目录存在
os.makedirs("uploads", exist_ok=True)
# 保存文件到对象存储
file_path = f"uploads/{unique_filename}"
obs_util = obs.get_instance()
obs_util.Put(file_path, await image.read())
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.Link(file_path)
logging.info(f"obs_url: {obs_url}")
# 调用OCR处理图片
ocr_util = ocr.get_instance()
ocr_result = ocr_util.recognize_image_text(obs_url)
logging.info(f"ocr_result: {ocr_result}")
people = extract_people(ocr_result, obs_url)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
def extract_people(text: str, cover_link: str = None) -> People:
extra_agent = ExtractPeopleAgent()
people = extra_agent.extract_people_info(text)
people.cover = cover_link
logging.info(f"people: {people}")
return people
class PostPeopleRequest(BaseModel):
people: dict
@api.post("/people")
async def post_people(post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
service = get_people_service()
people.id, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=people.id)
@api.put("/people/{people_id}")
async def update_people(people_id: str, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.id = people_id
service = get_people_service()
res, error = service.get(people_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
_, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@api.delete("/people/{people_id}")
async def delete_people(people_id: str):
service = get_people_service()
error = service.delete(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
class GetPeopleRequest(BaseModel):
query: Optional[str] = None
conds: Optional[dict] = None
top_k: int = 5
@api.get("/peoples")
async def get_peoples(
name: Optional[str] = Query(None, description="姓名"),
gender: Optional[str] = Query(None, description="性别"),
age: Optional[int] = Query(None, description="年龄"),
height: Optional[int] = Query(None, description="身高"),
marital_status: Optional[str] = Query(None, description="婚姻状态"),
limit: int = Query(10, description="分页大小"),
offset: int = Query(0, description="分页偏移量"),
):
# 解析查询参数为字典
conds = {}
if name:
conds["name"] = name
if gender:
conds["gender"] = gender
if age:
conds["age"] = age
if height:
conds["height"] = height
if marital_status:
conds["marital_status"] = marital_status
logging.info(f"conds: , limit: {limit}, offset: {offset}")
results = []
service = get_people_service()
results, error = service.list(conds, limit=limit, offset=offset)
logging.info(f"query results: {results}")
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
peoples = [people.to_dict() for people in results]
return BaseResponse(error_code=0, error_info="success", data=peoples)
obs_url = obs_util.get_link(file_path)
return BaseResponse(error_code=0, error_info="success", data=obs_url)
class RemarkRequest(BaseModel):
content: str
api.include_router(authorized_router)
# Register custom router
api.include_router(custom_router, dependencies=[Depends(require_auth)])
@api.post("/people/{people_id}/remark")
async def post_remark(people_id: str, request: RemarkRequest):
service = get_people_service()
error = service.save_remark(people_id, request.content)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
# Register people router
api.include_router(people_router, dependencies=[Depends(require_auth)])
# Register user router
api.include_router(user_router)
# Register recognition router
api.include_router(recognition_router, dependencies=[Depends(require_auth)])
@api.delete("/people/{people_id}/remark")
async def delete_remark(people_id: str):
service = get_people_service()
error = service.delete_remark(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")

28
src/web/auth.py Normal file
View File

@@ -0,0 +1,28 @@
from typing import Optional
from fastapi import Cookie, HTTPException, Request
from utils import rldb as rldb_util
from models.user import User, UserTokenRLDBModel, UserRLDBModel
from datetime import datetime
def require_auth(request: Request, token: Optional[str] = Cookie(None)):
if not token:
raise HTTPException(status_code=401, detail="unauthorized")
db = rldb_util.get_instance()
tokens = db.query(UserTokenRLDBModel, token=token, limit=1)
if not tokens:
raise HTTPException(status_code=401, detail="unauthorized")
t = tokens[0]
if getattr(t, 'revoked', False):
raise HTTPException(status_code=401, detail="unauthorized")
if getattr(t, 'expired_at', None) and t.expired_at < datetime.now():
raise HTTPException(status_code=401, detail="unauthorized")
user_orm = db.get(UserRLDBModel, t.user_id)
if not user_orm:
raise HTTPException(status_code=401, detail="unauthorized")
user = User.from_rldb_model(user_orm)
request.state.user_id = user.id
request.state.user_nickname = user.nickname
request.state.user_email = user.email
request.state.user_phone = user.phone
request.state.token = token

151
src/web/custom.py Normal file
View File

@@ -0,0 +1,151 @@
import logging
import os
import uuid
from fastapi import APIRouter, Depends, Request, Query, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from models.custom import Custom
from services.custom import get_instance as get_custom_service
from utils.error import ErrorCode
from utils import obs
from web.schemas import BaseResponse
router = APIRouter(tags=["custom"])
class PostCustomRequest(BaseModel):
custom: dict
@router.post("/api/custom")
def create_custom(request: Request, post_custom_request: PostCustomRequest):
logging.debug(f"post_custom_request: {post_custom_request}")
custom = Custom.from_dict(post_custom_request.custom)
# Validate custom data
err = custom.validate()
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
custom.user_id = getattr(request.state, 'user_id', '')
service = get_custom_service()
custom.id, error = service.save(custom)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=custom.id)
@router.put("/api/custom/{custom_id}")
def update_custom(request: Request, custom_id: str, post_custom_request: PostCustomRequest):
logging.debug(f"post_custom_request: {post_custom_request}")
custom = Custom.from_dict(post_custom_request.custom)
custom.id = custom_id
# Validate custom data
err = custom.validate()
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
service = get_custom_service()
# Check permission
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
custom.user_id = res.user_id # Ensure user_id is not changed or is set correctly
_, error = service.save(custom)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/custom/{custom_id}")
def delete_custom(request: Request, custom_id: str):
service = get_custom_service()
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete(custom_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=custom_id)
@router.get("/api/customs")
def get_customs(request: Request, limit: int = Query(10, ge=1, le=1000), offset: int = Query(0, ge=0)):
service = get_custom_service()
res, error = service.list({'user_id': getattr(request.state, 'user_id', '')}, limit=limit, offset=offset)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
# custom对象转换为字典
customs = [custom.to_dict() for custom in res]
return BaseResponse(error_code=0, error_info="success", data=customs)
@router.get("/api/custom/{custom_id}")
def get_custom(request: Request, custom_id: str):
service = get_custom_service()
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
return BaseResponse(error_code=0, error_info="success", data=res.to_dict())
@router.post("/api/custom/{custom_id}/image")
async def post_custom_image(request: Request, custom_id: str, image: UploadFile = File(...)):
# 检查 custom id 是否存在
service = get_custom_service()
custom, err = service.get(custom_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if custom.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"customs/{custom_id}/images/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
return BaseResponse(error_code=0, error_info="success", data=obs_url)
@router.delete("/api/custom/{custom_id}/image")
async def delete_custom_image(request: Request, custom_id: str, image_url: str):
# 检查 custom id 是否存在
service = get_custom_service()
custom, err = service.get(custom_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if custom.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 检查 image_url 是否是该 custom 名下的图片链接
obs_util = obs.get_instance()
obs_path, err = obs_util.get_obs_path_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if not obs_path.startswith(f"customs/{custom_id}/images/"):
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {custom_id} 名下的图片链接")
# 实现删除图片的处理
# 删除对象存储中的文件
err = obs_util.delete_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
return BaseResponse(error_code=0, error_info="success")

192
src/web/people.py Normal file
View File

@@ -0,0 +1,192 @@
import os
import uuid
import logging
from typing import Optional
from fastapi import APIRouter, Request, UploadFile, File, Query
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from services.people import get_instance as get_people_service
from models.people import People
from utils import obs
from utils.error import ErrorCode
from web.schemas import BaseResponse
router = APIRouter(tags=["people"])
class PostPeopleRequest(BaseModel):
people: dict
@router.post("/api/people")
async def post_people(request: Request, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.user_id = getattr(request.state, 'user_id', '')
service = get_people_service()
people.id, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=people.id)
@router.put("/api/people/{people_id}")
async def update_people(request: Request, people_id: str, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.id = people_id
service = get_people_service()
res, error = service.get(people_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
people.user_id = res.user_id
_, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/people/{people_id}")
async def delete_people(request: Request, people_id: str):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
class GetPeopleRequest(BaseModel):
query: Optional[str] = None
conds: Optional[dict] = None
top_k: int = 5
@router.get("/api/peoples")
async def get_peoples(
request: Request,
name: Optional[str] = Query(None, description="姓名"),
gender: Optional[str] = Query(None, description="性别"),
age: Optional[int] = Query(None, description="年龄"),
height: Optional[int] = Query(None, description="身高"),
marital_status: Optional[str] = Query(None, description="婚姻状态"),
limit: int = Query(10, description="分页大小"),
offset: int = Query(0, description="分页偏移量"),
):
# 解析查询参数为字典
conds = {}
conds["user_id"] = getattr(request.state, 'user_id', '')
if name:
conds["name"] = name
if gender:
conds["gender"] = gender
if age:
conds["age"] = age
if height:
conds["height"] = height
if marital_status:
conds["marital_status"] = marital_status
logging.info(f"conds: , limit: {limit}, offset: {offset}")
results = []
service = get_people_service()
results, error = service.list(conds, limit=limit, offset=offset)
logging.info(f"query results: {results}")
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
peoples = [people.to_dict() for people in results]
return BaseResponse(error_code=0, error_info="success", data=peoples)
class RemarkRequest(BaseModel):
content: str
@router.post("/api/people/{people_id}/remark")
async def post_people_remark(request: Request, people_id: str, body: RemarkRequest):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.save_remark(people_id, body.content)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/people/{people_id}/remark")
async def delete_people_remark(request: Request, people_id: str):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete_remark(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.post("/api/people/{people_id}/image")
async def post_people_image(request: Request, people_id: str, image: UploadFile = File(...)):
# 检查 people id 是否存在
service = get_people_service()
people, err = service.get(people_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if people.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"peoples/{people_id}/images/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
return BaseResponse(error_code=0, error_info="success", data=obs_url)
@router.delete("/api/people/{people_id}/image")
async def delete_people_image(request: Request, people_id: str, image_url: str):
# 检查 people id 是否存在
service = get_people_service()
people, err = service.get(people_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if people.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 检查 image_url 是否是该 people 名下的图片链接
obs_util = obs.get_instance()
obs_path, err = obs_util.get_obs_path_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if not obs_path.startswith(f"peoples/{people_id}/images/"):
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {people_id} 名下的图片链接")
# 实现删除图片的处理
# 删除对象存储中的文件
err = obs_util.delete_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
return BaseResponse(error_code=0, error_info="success")

91
src/web/recognition.py Normal file
View File

@@ -0,0 +1,91 @@
import os
import uuid
import logging
from typing import Optional
from fastapi import APIRouter, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from models.people import People
from models.custom import Custom
from agents.extract_people_agent import ExtractPeopleAgent
from agents.extract_custom_agent import ExtractCustomAgent
from utils import obs, ocr
from web.schemas import BaseResponse
from utils.error import ErrorCode
router = APIRouter(tags=["recognition"])
def extract_people(text: str, cover_link: str = None) -> People:
extra_agent = ExtractPeopleAgent()
people = extra_agent.extract_people_info(text)
if people:
people.cover = cover_link
logging.info(f"people: {people}")
return people
def extract_custom(text: str, image_link: str = None) -> Custom:
extra_agent = ExtractCustomAgent()
custom = extra_agent.extract_custom_info(text)
if custom:
if image_link:
custom.images = [image_link]
logging.info(f"custom: {custom}")
return custom
class PostInputRequest(BaseModel):
text: str
@router.post("/api/recognition/{model}/input")
async def post_recognition_input(model: str, request: PostInputRequest):
if model == "people":
result = await run_in_threadpool(extract_people, request.text)
elif model == "custom":
result = await run_in_threadpool(extract_custom, request.text)
else:
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
if result is None:
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
resp = BaseResponse(error_code=0, error_info="success")
resp.data = result.to_dict()
return resp
@router.post("/api/recognition/{model}/image")
async def post_recognition_image(model: str, image: UploadFile = File(...)):
if model not in ["people", "custom"]:
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"uploads/{model}/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
# 调用OCR处理图片
ocr_util = ocr.get_instance()
ocr_result = await run_in_threadpool(ocr_util.recognize_image_text, obs_url)
logging.info(f"ocr_result: {ocr_result}")
if model == "people":
result = await run_in_threadpool(extract_people, ocr_result, obs_url)
elif model == "custom":
result = await run_in_threadpool(extract_custom, ocr_result, obs_url)
if result is None:
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
resp = BaseResponse(error_code=0, error_info="success")
resp.data = result.to_dict()
return resp

7
src/web/schemas.py Normal file
View File

@@ -0,0 +1,7 @@
from typing import Any, Optional
from pydantic import BaseModel
class BaseResponse(BaseModel):
error_code: int
error_info: str
data: Optional[Any] = None

223
src/web/user.py Normal file
View File

@@ -0,0 +1,223 @@
import os
import time
import logging
from typing import Optional, Literal
from fastapi import APIRouter, Depends, Request, HTTPException, Response, UploadFile, File
from pydantic import BaseModel
from services.user import get_instance as get_user_service
from web.auth import require_auth
from utils import obs
from utils.config import get_instance as get_config
from web.schemas import BaseResponse
router = APIRouter(tags=["user"])
class SendCodeRequest(BaseModel):
target_type: str
target: str
scene: Literal['register', 'update']
# scene: Literal['register', 'login']
@router.post("/api/user/send_code")
async def send_user_code(request: SendCodeRequest):
service = get_user_service()
err = service.send_code(request.target_type, request.target, request.scene)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success")
class RegisterRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
email: Optional[str] = None
phone: Optional[str] = None
password: str
code: str
@router.post("/api/user")
async def user_register(request: RegisterRequest):
service = get_user_service()
from models.user import User
u = User(
nickname=request.nickname or "",
avatar_link=request.avatar_link or "",
email=request.email or "",
phone=request.phone or "",
password_hash=request.password,
)
uid, err = service.register(u, request.code)
if not err.success:
logging.error(f"register failed: {err}")
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data=uid)
class LoginRequest(BaseModel):
email: Optional[str] = None
phone: Optional[str] = None
password: str
@router.post("/api/user/login")
async def user_login(request: LoginRequest, response: Response):
service = get_user_service()
data, err = service.login(request.email, request.phone, request.password)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30)
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False)
cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None)
response.set_cookie(
key="token",
value=data.get('token', ''),
max_age=ttl_days * 24 * 3600,
httponly=True,
secure=cookie_secure,
samesite=cookie_samesite,
domain=cookie_domain,
path="/",
)
return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')})
@router.delete("/api/user/me/login", dependencies=[Depends(require_auth)])
async def user_logout(response: Response, request: Request):
service = get_user_service()
err = service.logout(getattr(request.state, 'token', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/user/me", dependencies=[Depends(require_auth)])
async def user_delete(response: Response, request: Request):
service = get_user_service()
err = service.delete_user(getattr(request.state, 'user_id', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@router.get("/api/user/me", dependencies=[Depends(require_auth)])
async def user_me(request: Request):
service = get_user_service()
user, err = service.get(getattr(request.state, 'user_id', None))
if not err.success or not user:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateMeRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
phone: Optional[str] = None
email: Optional[str] = None
@router.put("/api/user/me", dependencies=[Depends(require_auth)])
async def update_user_me(request: Request, body: UpdateMeRequest):
service = get_user_service()
user, err = service.update_profile(
getattr(request.state, 'user_id', None),
nickname=body.nickname,
avatar_link=body.avatar_link,
phone=body.phone,
email=body.email,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
@router.put("/api/user/me/avatar", dependencies=[Depends(require_auth)])
async def upload_avatar(request: Request, avatar: UploadFile = File(...)):
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="unauthorized")
file_extension = os.path.splitext(avatar.filename)[1]
timestamp = int(time.time())
avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}"
try:
obs_util = obs.get_instance()
obs_util.Put(avatar_path, await avatar.read())
avatar_url = obs_util.Link(avatar_path)
user_service = get_user_service()
_, err = user_service.update_profile(user_id, avatar_link=avatar_url)
if not err.success:
raise HTTPException(status_code=500, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url})
except Exception as e:
logging.error(f"upload avatar failed: {e}")
raise HTTPException(status_code=500, detail="upload avatar failed")
class UpdatePhoneRequest(BaseModel):
phone: str
code: str
@router.put("/api/user/me/phone", dependencies=[Depends(require_auth)])
async def update_user_phone(request: Request, body: UpdatePhoneRequest):
service = get_user_service()
user, err = service.update_phone_with_code(
getattr(request.state, 'user_id', None),
body.phone,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateEmailRequest(BaseModel):
email: str
code: str
@router.put("/api/user/me/email", dependencies=[Depends(require_auth)])
async def update_user_email(request: Request, body: UpdateEmailRequest):
service = get_user_service()
user, err = service.update_email_with_code(
getattr(request.state, 'user_id', None),
body.email,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)