1 Commits

Author SHA1 Message Date
8174c4cfe5 Release v0.1 2025-11-12 23:54:02 +08:00
21 changed files with 170 additions and 1930 deletions

View File

@@ -1,85 +0,0 @@
import datetime
import json
import logging
from langchain.prompts import ChatPromptTemplate
from .base_agent import BaseAgent
from models.custom import Custom
class ExtractCustomAgent(BaseAgent):
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
super().__init__(api_url, api_key, model_name)
self.prompt = ChatPromptTemplate.from_messages([
(
"system",
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')}"
"你是一个专业的客户信息录入助手,善于从一段文字描述中,精确获取客户的以下信息:\n"
"姓名 name\n"
"性别 gender (男/女/未知)\n"
"出生年份 birth (整数年份,如 1990若文本只提供了年龄请根据当前日期计算出出生年份)\n"
"手机号 phone\n"
"邮箱 email\n"
"身高(cm) height (整数)\n"
"体重(kg) weight (整数)\n"
"学历 degree\n"
"毕业院校 academy\n"
"职业 occupation\n"
"年收入(万) income (整数)\n"
"资产(万) assets (整数)\n"
"流动资产(万) current_assets (整数)\n"
"房产情况 house (必须为以下之一: '有房无贷', '有房有贷', '无自有房', 若未提及则不填)\n"
"车辆情况 car (必须为以下之一: '有车无贷', '有车有贷', '无自有车', 若未提及则不填)\n"
"户口城市 registered_city\n"
"居住城市 live_city\n"
"籍贯 native_place\n"
"原生家庭情况 original_family\n"
"是否独生子女 is_single_child (true/false)\n"
"择偶要求 match_requirement\n"
"\n"
"以上信息需要严格按照 JSON 格式输出,字段名与条目中英文保持一致。\n"
"若未识别到某项,则不返回该字段,不要自行填写“未知”、“未填写”等。\n"
"\n"
"除了上述基本信息,还有一个字段:\n"
"其他介绍 introductions\n"
"其余的信息需要按照字典的方式进行提炼和总结,都放在 introductions 字段中key 使用提炼好的中文。\n"
),
("human", "{input}")
])
def extract_custom_info(self, text: str) -> Custom:
"""从文本中提取客户信息"""
prompt = self.prompt.format_prompt(input=text)
response = self.llm.invoke(prompt)
logging.info(f"llm response: {response.content}")
try:
custom_dict = json.loads(response.content)
# 类型安全转换防止LLM返回字符串类型的数字
int_fields = ['birth', 'height', 'weight', 'income', 'assets', 'scores', 'current_assets']
for field in int_fields:
if field in custom_dict and isinstance(custom_dict[field], str):
try:
# 尝试提取数字,简单处理
import re
num = re.findall(r'\d+', custom_dict[field])
if num:
custom_dict[field] = int(num[0])
else:
del custom_dict[field] # 无法转换则移除
except:
del custom_dict[field]
custom = Custom.from_dict(custom_dict)
err = custom.validate()
if not err.success:
logging.warning(f"Validation warning: {err.info}")
# 即使校验失败(如某些必填项缺失),也尽可能返回已提取的对象,
# 让上层业务逻辑决定是否接受或需要补充
return custom
except json.JSONDecodeError:
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
return None
except Exception as e:
logging.error(f"Failed to process custom info: {e}")
return None

View File

@@ -21,7 +21,7 @@ class ExtractPeopleAgent(BaseAgent):
"身高(cm) height\n" "身高(cm) height\n"
"婚姻状况 marital_status\n" "婚姻状况 marital_status\n"
"择偶要求 match_requirement\n" "择偶要求 match_requirement\n"
"以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致; 若未识别到以上的某项,则不返回该字段,不要自行填写“未知”,“未填写”等类似字眼\n" "以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n"
"其中,'年龄 age''身高(cm) height' 必须是一个整数,不能是一个字符串;\n" "其中,'年龄 age''身高(cm) height' 必须是一个整数,不能是一个字符串;\n"
"并且,'性别 gender' 根据识别结果,必须从 男,女,未知 三选一填写。\n" "并且,'性别 gender' 根据识别结果,必须从 男,女,未知 三选一填写。\n"
"除了上述基本信息,还有一个字段\n" "除了上述基本信息,还有一个字段\n"

View File

@@ -1,36 +1,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# created by mmmy on 2025-09-27 # created by mmmy on 2025-09-27
import os import os
import sys
# Add src directory to sys.path to ensure modules can be imported correctly when running with uvicorn
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import argparse import argparse
import uvicorn import uvicorn
from services import people as people_service from services import people as people_service
from services import user as user_service from utils import config, logger, obs, ocr, rldb
from services import custom as custom_service
from utils import config, logger, obs, ocr, rldb, sms, mailer
from web.api import api from web.api import api
def initialize_app(config_path):
"""Initialize application components with the given config path."""
config.init(config_path)
conf = config.get_instance()
logger.init()
rldb.init()
ocr.init()
obs.init()
mailer.init(conf.get('mailer', 'type', fallback='real'))
sms.init(conf.get('sms', 'type', fallback='real'))
people_service.init()
user_service.init()
custom_service.init()
# 主函数 # 主函数
def main(): def main():
main_path = os.path.dirname(os.path.abspath(__file__)) main_path = os.path.dirname(os.path.abspath(__file__))
@@ -38,19 +15,21 @@ def main():
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径') parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
args = parser.parse_args() args = parser.parse_args()
initialize_app(args.config) config.init(args.config)
logger.init()
rldb.init()
ocr.init()
obs.init()
people_service.init()
conf = config.get_instance() conf = config.get_instance()
host = conf.get('web_service', 'server_host', fallback='0.0.0.0') host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
port = conf.getint('web_service', 'server_port', fallback=8099) port = conf.getint('web_service', 'server_port', fallback=8099)
uvicorn.run("src.main:api", host=host, port=port, reload=True) # Modified to string import for reload support in main too, though api object also works uvicorn.run(api, host=host, port=port)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
else:
# Support for running via 'uvicorn src.main:api'
# Use environment variable for config path or default
main_path = os.path.dirname(os.path.abspath(__file__))
default_config_path = os.path.join(main_path, '../configuration/test_conf.ini')
config_path = os.environ.get('IFU_CONFIG_PATH', default_config_path)
initialize_app(config_path)

View File

@@ -1,291 +0,0 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-11-27
import json
import logging
from typing import Dict, List
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, func, Boolean
from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class CustomRLDBModel(RLDBBaseModel):
"""
客户数据的数据库模型 (SQLAlchemy Model) - 更新版
"""
__tablename__ = 'customs'
id = Column(String(36), primary_key=True)
user_id = Column(String(36), index=True, nullable=False)
# 基本信息
name = Column(String(255), index=True, nullable=False)
gender = Column(String(10), nullable=False)
birth = Column(Integer, nullable=False) # 出生年份
phone = Column(String(50), index=True)
email = Column(String(255), index=True)
# 外貌信息
height = Column(Integer)
weight = Column(Integer)
images = Column(Text) # JSON string for list[str]
scores = Column(Integer)
# 学历职业
degree = Column(String(255))
academy = Column(String(255))
occupation = Column(String(255))
income = Column(Integer) # 单位:万
assets = Column(Integer) # 单位:万
current_assets = Column(Integer) # 单位:万
house = Column(String(50))
car = Column(String(50))
# 户口家庭
registered_city = Column(String(255))
live_city = Column(String(255))
native_place = Column(String(255))
original_family = Column(Text)
is_single_child = Column(Boolean, default=False)
match_requirement = Column(Text)
introductions = Column(Text) # JSON string for Dict[str, str]
# 客户信息
custom_level = Column(String(255))
comments = Column(Text) # JSON string for Dict[str, str]
is_public = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class Custom:
"""
客户数据的业务逻辑模型 (Business Logic Model) - 更新版
"""
id: str
user_id: str
# 基本信息
name: str
gender: str
birth: int
phone: str
email: str
# 外貌信息
height: int
weight: int
images: List[str]
scores: int
# 学历职业
degree: str
academy: str
occupation: str
income: int
assets: int
current_assets: int
house: str
car: str
# 户口家庭
registered_city: str
live_city: str
native_place: str
original_family: str
is_single_child: bool
match_requirement: str
introductions: Dict[str, str]
# 客户信息
custom_level: str
comments: Dict[str, str]
is_public: bool
created_at: datetime = None
def __init__(self, **kwargs):
# 初始化所有属性
self.id = kwargs.get('id', '')
self.user_id = kwargs.get('user_id', '')
self.name = kwargs.get('name', '')
self.gender = kwargs.get('gender', '未知')
self.birth = kwargs.get('birth', 0)
self.phone = kwargs.get('phone', '')
self.email = kwargs.get('email', '')
self.height = kwargs.get('height', 0)
self.weight = kwargs.get('weight', 0)
self.images = kwargs.get('images', [])
self.scores = kwargs.get('scores', 0)
self.degree = kwargs.get('degree', '')
self.academy = kwargs.get('academy', '')
self.occupation = kwargs.get('occupation', '')
self.income = kwargs.get('income', 0)
self.assets = kwargs.get('assets', 0)
self.current_assets = kwargs.get('current_assets', 0)
self.house = kwargs.get('house', '')
self.car = kwargs.get('car', '')
self.registered_city = kwargs.get('registered_city', '')
self.live_city = kwargs.get('live_city', '')
self.native_place = kwargs.get('native_place', '')
self.original_family = kwargs.get('original_family', '')
self.is_single_child = kwargs.get('is_single_child', False)
self.match_requirement = kwargs.get('match_requirement', '')
self.introductions = kwargs.get('introductions', {})
self.custom_level = kwargs.get('custom_level', '')
self.comments = kwargs.get('comments', {})
self.is_public = kwargs.get('is_public', False)
self.created_at = kwargs.get('created_at')
def __str__(self) -> str:
# 返回对象的字符串表示
attributes = ", ".join(f"{k}={v}" for k, v in self.to_dict().items())
return f"Custom({attributes})"
@classmethod
def from_dict(cls, data: dict):
# 从字典创建对象实例
if 'created_at' in data and data['created_at'] is not None:
data['created_at'] = datetime.fromtimestamp(data['created_at'])
# 移除ORM特有的时间戳避免初始化错误
data.pop('updated_at', None)
data.pop('deleted_at', None)
return cls(**data)
def to_dict(self) -> dict:
# 将对象转换为字典
return {
'id': self.id,
'user_id': self.user_id,
'name': self.name,
'gender': self.gender,
'birth': self.birth,
'phone': self.phone,
'email': self.email,
'height': self.height,
'weight': self.weight,
'images': self.images,
'scores': self.scores,
'degree': self.degree,
'academy': self.academy,
'occupation': self.occupation,
'income': self.income,
'assets': self.assets,
'current_assets': self.current_assets,
'house': self.house,
'car': self.car,
'registered_city': self.registered_city,
'live_city': self.live_city,
'native_place': self.native_place,
'original_family': self.original_family,
'is_single_child': self.is_single_child,
'match_requirement': self.match_requirement,
'introductions': self.introductions,
'custom_level': self.custom_level,
'comments': self.comments,
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
}
@classmethod
def from_rldb_model(cls, data: CustomRLDBModel):
# 从数据库模型转换
return cls(
id=data.id,
user_id=data.user_id,
name=data.name,
gender=data.gender,
birth=data.birth,
phone=data.phone,
email=data.email,
height=data.height,
weight=data.weight,
images=json.loads(data.images) if data.images else [],
scores=data.scores,
degree=data.degree,
academy=data.academy,
occupation=data.occupation,
income=data.income,
assets=data.assets,
house=data.house,
car=data.car,
registered_city=data.registered_city,
live_city=data.live_city,
native_place=data.native_place,
original_family=data.original_family,
is_single_child=data.is_single_child,
match_requirement=data.match_requirement,
introductions=json.loads(data.introductions) if data.introductions else {},
custom_level=data.custom_level,
comments=json.loads(data.comments) if data.comments else {},
is_public=data.is_public,
created_at=data.created_at,
)
def to_rldb_model(self) -> CustomRLDBModel:
# 转换为数据库模型
return CustomRLDBModel(
id=self.id,
user_id=self.user_id,
name=self.name,
gender=self.gender,
birth=self.birth,
phone=self.phone,
email=self.email,
height=self.height,
weight=self.weight,
images=json.dumps(self.images, ensure_ascii=False),
scores=self.scores,
degree=self.degree,
academy=self.academy,
occupation=self.occupation,
income=self.income,
assets=self.assets,
current_assets=self.current_assets,
house=self.house,
car=self.car,
registered_city=self.registered_city,
live_city=self.live_city,
native_place=self.native_place,
original_family=self.original_family,
is_single_child=self.is_single_child,
match_requirement=self.match_requirement,
introductions=json.dumps(self.introductions, ensure_ascii=False),
custom_level=self.custom_level,
comments=json.dumps(self.comments, ensure_ascii=False),
is_public=self.is_public,
)
def validate(self) -> error:
# 数据校验逻辑
if not self.name:
return error(ErrorCode.INVALID_PARAMS, "Name cannot be empty.")
if not self.gender:
return error(ErrorCode.INVALID_PARAMS, "Gender cannot be empty.")
if self.gender not in ['', '']:
return error(ErrorCode.INVALID_PARAMS, "Gender must be '' or ''.")
current_year = datetime.now().year
min_birth_year = 1950
max_birth_year = current_year - 18
if not isinstance(self.birth, int):
return error(ErrorCode.INVALID_PARAMS, "Birth year must be an integer.")
if self.birth < min_birth_year or self.birth > max_birth_year:
return error(ErrorCode.INVALID_PARAMS, f"Birth year must be between {min_birth_year} and {max_birth_year}.")
valid_houses = ["", "有房无贷", "有房有贷", "无自有房"]
if self.house not in valid_houses:
return error(ErrorCode.INVALID_PARAMS, f"House must be one of {valid_houses}")
valid_cars = ["", "有车无贷", "有车有贷", "无自有车"]
if self.car not in valid_cars:
return error(ErrorCode.INVALID_PARAMS, f"Car must be one of {valid_cars}")
# ... 可根据需要添加更多校验 ...
return error(ErrorCode.SUCCESS, "")

View File

@@ -4,7 +4,6 @@
import json import json
import logging import logging
from typing import Dict from typing import Dict
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, func from sqlalchemy import Column, Integer, String, Text, DateTime, func
from utils.rldb import RLDBBaseModel from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error from utils.error import ErrorCode, error
@@ -12,7 +11,6 @@ from utils.error import ErrorCode, error
class PeopleRLDBModel(RLDBBaseModel): class PeopleRLDBModel(RLDBBaseModel):
__tablename__ = 'peoples' __tablename__ = 'peoples'
id = Column(String(36), primary_key=True) id = Column(String(36), primary_key=True)
user_id = Column(String(36), index=True)
name = Column(String(255), index=True) name = Column(String(255), index=True)
contact = Column(String(255), index=True) contact = Column(String(255), index=True)
gender = Column(String(10)) gender = Column(String(10))
@@ -28,42 +26,9 @@ class PeopleRLDBModel(RLDBBaseModel):
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class Comment:
# 评论内容
content: str
# 评论人
author: str
# 创建时间
created_at: datetime
# 更新时间
updated_at: datetime
def __init__(self, **kwargs):
self.content = kwargs.get('content', '')
self.author = kwargs.get('author', '')
self.created_at = kwargs.get('created_at', datetime.now())
self.updated_at = kwargs.get('updated_at', datetime.now())
def to_dict(self) -> dict:
return {
'content': self.content,
'author': self.author,
'created_at': int(self.created_at.timestamp()),
'updated_at': int(self.updated_at.timestamp()),
}
@classmethod
def from_dict(cls, data: dict):
data['created_at'] = datetime.fromtimestamp(data['created_at'])
data['updated_at'] = datetime.fromtimestamp(data['updated_at'])
return cls(**data)
class People: class People:
# 数据库 ID # 数据库 ID
id: str id: str
# 所属用户 ID
user_id: str
# 姓名 # 姓名
name: str name: str
# 联系人 # 联系人
@@ -81,16 +46,13 @@ class People:
# 个人介绍 # 个人介绍
introduction: Dict[str, str] introduction: Dict[str, str]
# 总结评价 # 总结评价
comments: Dict[str, "Comment"] comments: Dict[str, str]
# 封面 # 封面
cover: str = None cover: str = None
# 创建时间
created_at: datetime = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
# 初始化所有属性从kwargs中获取值如果不存在则设置默认值 # 初始化所有属性从kwargs中获取值如果不存在则设置默认值
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else '' self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.user_id = kwargs.get('user_id', '') if kwargs.get('user_id', '') is not None else ''
self.name = kwargs.get('name', '') if kwargs.get('name', '') is not None else '' self.name = kwargs.get('name', '') if kwargs.get('name', '') is not None else ''
self.contact = kwargs.get('contact', '') if kwargs.get('contact', '') is not None else '' self.contact = kwargs.get('contact', '') if kwargs.get('contact', '') is not None else ''
self.gender = kwargs.get('gender', '') if kwargs.get('gender', '') is not None else '' self.gender = kwargs.get('gender', '') if kwargs.get('gender', '') is not None else ''
@@ -101,17 +63,19 @@ class People:
self.introduction = kwargs.get('introduction', {}) if kwargs.get('introduction', {}) is not None else {} self.introduction = kwargs.get('introduction', {}) if kwargs.get('introduction', {}) is not None else {}
self.comments = kwargs.get('comments', {}) if kwargs.get('comments', {}) is not None else {} self.comments = kwargs.get('comments', {}) if kwargs.get('comments', {}) is not None else {}
self.cover = kwargs.get('cover', None) if kwargs.get('cover', None) is not None else None self.cover = kwargs.get('cover', None) if kwargs.get('cover', None) is not None else None
self.created_at = kwargs.get('created_at', None)
def __str__(self) -> str: def __str__(self) -> str:
# 返回对象的字符串表示,包含所有属性 # 返回对象的字符串表示,包含所有属性
return (f"People(id={self.id}, name={self.name}, contact={self.contact}, gender={self.gender}, " return (f"People(id={self.id}, name={self.name}, contact={self.contact}, gender={self.gender}, "
f"age={self.age}, height={self.height}, marital_status={self.marital_status}, " f"age={self.age}, height={self.height}, marital_status={self.marital_status}, "
f"match_requirement={self.match_requirement}, introduction={self.introduction}, " f"match_requirement={self.match_requirement}, introduction={self.introduction}, "
f"comments={self.comments}, cover={self.cover}, created_at={self.created_at})") f"comments={self.comments}, cover={self.cover})")
@classmethod @classmethod
def from_dict(cls, data: dict): def from_dict(cls, data: dict):
if 'created_at' in data:
# 移除 created_at 字段,避免类型错误
del data['created_at']
if 'updated_at' in data: if 'updated_at' in data:
# 移除 updated_at 字段,避免类型错误 # 移除 updated_at 字段,避免类型错误
del data['updated_at'] del data['updated_at']
@@ -125,7 +89,6 @@ class People:
# 将关系数据库模型转换为对象 # 将关系数据库模型转换为对象
return cls( return cls(
id=data.id, id=data.id,
user_id=data.user_id,
name=data.name, name=data.name,
contact=data.contact, contact=data.contact,
gender=data.gender, gender=data.gender,
@@ -134,16 +97,14 @@ class People:
marital_status=data.marital_status, marital_status=data.marital_status,
match_requirement=data.match_requirement, match_requirement=data.match_requirement,
introduction=json.loads(data.introduction) if data.introduction else {}, introduction=json.loads(data.introduction) if data.introduction else {},
comments={k: Comment.from_dict(v) for k, v in json.loads(data.comments).items()} if data.comments else {}, comments=json.loads(data.comments) if data.comments else {},
cover=data.cover, cover=data.cover,
created_at=data.created_at,
) )
def to_dict(self) -> dict: def to_dict(self) -> dict:
# 将对象转换为字典格式 # 将对象转换为字典格式
return { return {
'id': self.id, 'id': self.id,
'user_id': self.user_id,
'name': self.name, 'name': self.name,
'contact': self.contact, 'contact': self.contact,
'gender': self.gender, 'gender': self.gender,
@@ -152,16 +113,14 @@ class People:
'marital_status': self.marital_status, 'marital_status': self.marital_status,
'match_requirement': self.match_requirement, 'match_requirement': self.match_requirement,
'introduction': self.introduction, 'introduction': self.introduction,
'comments': {k: v.to_dict() for k, v in self.comments.items()}, 'comments': self.comments,
'cover': self.cover, 'cover': self.cover,
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
} }
def to_rldb_model(self) -> PeopleRLDBModel: def to_rldb_model(self) -> PeopleRLDBModel:
# 将对象转换为关系数据库模型 # 将对象转换为关系数据库模型
return PeopleRLDBModel( return PeopleRLDBModel(
id=self.id, id=self.id,
user_id=self.user_id,
name=self.name, name=self.name,
contact=self.contact, contact=self.contact,
gender=self.gender, gender=self.gender,
@@ -170,22 +129,22 @@ class People:
marital_status=self.marital_status, marital_status=self.marital_status,
match_requirement=self.match_requirement, match_requirement=self.match_requirement,
introduction=json.dumps(self.introduction, ensure_ascii=False), introduction=json.dumps(self.introduction, ensure_ascii=False),
comments=json.dumps({k: v.to_dict() for k, v in self.comments.items()}, ensure_ascii=False), comments=json.dumps(self.comments, ensure_ascii=False),
cover=self.cover, cover=self.cover,
) )
def validate(self) -> error: def validate(self) -> error:
err = error(ErrorCode.SUCCESS, "") err = error(ErrorCode.SUCCESS, "")
if not self.name: if not self.name:
logging.error("Name is required, use default") logging.error("Name is required")
self.name = "" err = error(ErrorCode.MODEL_ERROR, "Name is required")
if not self.gender in ['', '', '未知']: if not self.gender in ['', '', '未知']:
logging.error("Gender must be '', '', or '未知', use default") logging.error("Gender must be '', '', or '未知'")
self.gender = "未知" err = error(ErrorCode.MODEL_ERROR, "Gender must be '', '', or '未知'")
if not isinstance(self.age, int) or self.age < 0: if not isinstance(self.age, int) or self.age <= 0:
logging.error("Age must be an integer and greater than 0, use default") logging.error("Age must be an integer and greater than 0")
self.age = 0 err = error(ErrorCode.MODEL_ERROR, "Age must be an integer and greater than 0")
if not isinstance(self.height, int) or self.height < 0: if not isinstance(self.height, int) or self.height <= 0:
logging.error("Height must be an integer and greater than 0, use default") logging.error("Height must be an integer and greater than 0")
self.height = 0 err = error(ErrorCode.MODEL_ERROR, "Height must be an integer and greater than 0")
return err return err

View File

@@ -1,184 +0,0 @@
from typing import Optional
import json
from datetime import datetime, timedelta
from sqlalchemy import Column, String, Text, DateTime, Integer, Boolean, func, UniqueConstraint
from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class UserRLDBModel(RLDBBaseModel):
__tablename__ = 'users'
id = Column(String(36), primary_key=True)
nickname = Column(String(255))
avatar_link = Column(String(255))
email = Column(String(127), unique=True, index=True)
phone = Column(String(32), unique=True, index=True)
password_hash = Column(String(255))
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class VerificationCodeRLDBModel(RLDBBaseModel):
__tablename__ = 'verification_codes'
id = Column(String(36), primary_key=True)
target_type = Column(String(16))
target = Column(String(255), index=True)
code = Column(String(16))
scene = Column(String(32))
expires_at = Column(DateTime(timezone=True))
used_at = Column(DateTime(timezone=True), nullable=True)
class UserTokenRLDBModel(RLDBBaseModel):
__tablename__ = 'user_tokens'
id = Column(String(36), primary_key=True)
user_id = Column(String(36), index=True)
token = Column(Text)
expired_at = Column(DateTime(timezone=True))
revoked = Column(Boolean, default=False)
class User:
id: str
nickname: str
avatar_link: str
email: str
phone: str
password_hash: str
created_at: datetime = None
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.nickname = kwargs.get('nickname', '') if kwargs.get('nickname', '') is not None else ''
self.avatar_link = kwargs.get('avatar_link', '') if kwargs.get('avatar_link', '') is not None else ''
self.email = kwargs.get('email', '') if kwargs.get('email', '') is not None else ''
self.phone = kwargs.get('phone', '') if kwargs.get('phone', '') is not None else ''
self.password_hash = kwargs.get('password_hash', '') if kwargs.get('password_hash', '') is not None else ''
self.created_at = kwargs.get('created_at', None)
def __str__(self) -> str:
return (f"User(id={self.id}, nickname={self.nickname}, avatar_link={self.avatar_link}, "
f"email={self.email}, phone={self.phone}, created_at={self.created_at})")
@classmethod
def from_dict(cls, data: dict):
if 'updated_at' in data:
del data['updated_at']
if 'deleted_at' in data:
del data['deleted_at']
return cls(**data)
@classmethod
def from_rldb_model(cls, data: UserRLDBModel):
return cls(
id=data.id,
nickname=data.nickname,
avatar_link=data.avatar_link,
email=data.email,
phone=data.phone,
password_hash=data.password_hash,
created_at=data.created_at,
)
def to_dict(self) -> dict:
return {
'id': self.id,
'nickname': self.nickname,
'avatar_link': self.avatar_link,
'email': self.email,
'phone': self.phone,
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
}
def to_rldb_model(self) -> UserRLDBModel:
return UserRLDBModel(
id=self.id,
nickname=self.nickname,
avatar_link=self.avatar_link,
email=self.email,
phone=self.phone,
password_hash=self.password_hash,
)
def validate(self) -> error:
err = error(ErrorCode.SUCCESS, "")
if not self.email and not self.phone:
return error(ErrorCode.MODEL_ERROR, "email or phone required")
return err
class VerificationCode:
id: str
target_type: str
target: str
code: str
scene: str
expires_at: datetime
used_at: Optional[datetime] = None
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.target_type = kwargs.get('target_type', '')
self.target = kwargs.get('target', '')
self.code = kwargs.get('code', '')
self.scene = kwargs.get('scene', '')
self.expires_at = kwargs.get('expires_at')
self.used_at = kwargs.get('used_at', None)
@classmethod
def from_rldb_model(cls, data: VerificationCodeRLDBModel):
return cls(
id=data.id,
target_type=data.target_type,
target=data.target,
code=data.code,
scene=data.scene,
expires_at=data.expires_at,
used_at=data.used_at,
)
def to_rldb_model(self) -> VerificationCodeRLDBModel:
return VerificationCodeRLDBModel(
id=self.id,
target_type=self.target_type,
target=self.target,
code=self.code,
scene=self.scene,
expires_at=self.expires_at,
used_at=self.used_at,
)
class UserToken:
id: str
user_id: str
token: str
expired_at: datetime
revoked: bool
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.user_id = kwargs.get('user_id', '')
self.token = kwargs.get('token', '')
self.expired_at = kwargs.get('expired_at')
self.revoked = kwargs.get('revoked', False)
@classmethod
def from_rldb_model(cls, data: UserTokenRLDBModel):
return cls(
id=data.id,
user_id=data.user_id,
token=data.token,
expired_at=data.expired_at,
revoked=data.revoked,
)
def to_rldb_model(self) -> UserTokenRLDBModel:
return UserTokenRLDBModel(
id=self.id,
user_id=self.user_id,
token=self.token,
expired_at=self.expired_at,
revoked=self.revoked,
)

View File

@@ -1,98 +0,0 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-11-27
import logging
import uuid
from models.custom import Custom, CustomRLDBModel
from utils.error import ErrorCode, error
from utils import rldb
class CustomService:
def __init__(self):
self.rldb = rldb.get_instance()
def save(self, custom: Custom) -> (str, error):
"""
保存客户到数据库。
如果 custom.id 存在,则更新;否则,创建。
:param custom: 客户对象
:return: 客户ID 和 错误对象
"""
# 0. 生成 custom id
custom.id = custom.id if custom.id else uuid.uuid4().hex
# 1. 转换模型,并保存到 SQL 数据库
try:
custom_orm = custom.to_rldb_model()
self.rldb.upsert(custom_orm)
return custom.id, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to save custom {custom.id}: {e}")
return "", error(ErrorCode.RLDB_ERROR, f"Failed to save custom data: {str(e)}")
def delete(self, custom_id: str) -> error:
"""
从数据库删除客户。
:param custom_id: 客户ID
:return: 错误对象
"""
try:
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
if not custom_orm:
return error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
self.rldb.delete(custom_orm)
return error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to delete custom {custom_id}: {e}")
return error(ErrorCode.RLDB_ERROR, f"Failed to delete custom data: {str(e)}")
def get(self, custom_id: str) -> (Custom, error):
"""
从数据库获取单个客户。
:param custom_id: 客户ID
:return: 客户对象 和 错误对象
"""
try:
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
if not custom_orm:
return None, error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
custom = Custom.from_rldb_model(custom_orm)
return custom, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to get custom {custom_id}: {e}")
return None, error(ErrorCode.RLDB_ERROR, f"Failed to retrieve custom data: {str(e)}")
def list(self, conds: dict = None, limit: int = 10, offset: int = 0) -> (list[Custom], error):
"""
根据条件从数据库列出客户(支持分页)。
:param conds: 查询条件字典
:param limit: 每页数量
:param offset: 偏移量
:return: 客户对象列表 和 错误对象
"""
if conds is None:
conds = {}
try:
custom_orms = self.rldb.query(CustomRLDBModel, limit=limit, offset=offset, **conds)
customs = [Custom.from_rldb_model(orm) for orm in custom_orms]
return customs, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to list customs with conds {conds}: {e}")
return [], error(ErrorCode.RLDB_ERROR, f"Failed to list custom data: {str(e)}")
# --- Singleton Pattern ---
custom_service = None
def init():
"""初始化 CustomService 单例"""
global custom_service
custom_service = CustomService()
def get_instance() -> CustomService:
"""获取 CustomService 单例"""
return custom_service

View File

@@ -1,10 +1,8 @@
import logging
import uuid import uuid
from models.people import People, PeopleRLDBModel, Comment from models.people import People, PeopleRLDBModel
from datetime import datetime
from utils.error import ErrorCode, error from utils.error import ErrorCode, error
from utils import rldb from utils import rldb
@@ -68,50 +66,6 @@ class PeopleService:
return peoples, error(ErrorCode.SUCCESS, "") return peoples, error(ErrorCode.SUCCESS, "")
def save_remark(self, people_id: str, content: str) -> error:
"""
为人物添加或更新备注
:param people_id: 人物ID
:param content: 备注内容
:return: 错误对象
"""
people: People
err: error
people, err = self.get(people_id)
logging.info(f"get people before save remark: {people}")
if not err.success:
return err
remark = people.comments.get("remark", None)
if remark is not None:
remark.content = content
remark.updated_at = datetime.now()
else:
people.comments["remark"] = Comment(content=content)
logging.info(f"save remark for people {people}")
_, err = self.save(people)
return err
def delete_remark(self, people_id: str) -> error:
"""
删除人物备注
:param people_id: 人物ID
:return: 错误对象
"""
people: People
err: error
people, err = self.get(people_id)
if not err.success:
return err
if "remark" in people.comments:
del people.comments["remark"]
_, err = self.save(people)
return err
return error(ErrorCode.SUCCESS, "")
people_service = None people_service = None

View File

@@ -1,220 +0,0 @@
import uuid
import hmac
import base64
import os
from datetime import datetime, timedelta
from typing import Optional
from utils.error import ErrorCode, error
from utils import rldb, mailer, sms, config
from models.user import (
User,
UserRLDBModel,
VerificationCode,
VerificationCodeRLDBModel,
UserToken,
UserTokenRLDBModel,
)
class UserService:
def __init__(self):
self.rldb = rldb.get_instance()
self.mailer = mailer.get_instance()
self.sms = sms.get_instance()
self.conf = config.get_instance()
def _hash_password(self, password: str, salt: Optional[str] = None) -> str:
salt = salt if salt else base64.urlsafe_b64encode(os.urandom(16)).decode('utf-8')
digest = hmac.new(salt.encode('utf-8'), password.encode('utf-8'), 'sha256').digest()
return f"{salt}:{base64.urlsafe_b64encode(digest).decode('utf-8')}"
def _verify_password(self, password: str, password_hash: str) -> bool:
parts = password_hash.split(':')
if len(parts) != 2:
return False
salt = parts[0]
return self._hash_password(password, salt) == password_hash
def send_code(self, target_type: str, target: str, scene: str) -> error:
scens = {
"register": "注册",
"update": "信息更新",
# "login": "登录",
}
if scene not in scens:
return error(ErrorCode.MODEL_ERROR, f'scene {scene} not supported')
scene_name = scens.get(scene, scene)
code = f"{uuid.uuid4().int % 1000000:06d}"
expires = datetime.now() + timedelta(minutes=10)
vc = VerificationCode(
id=uuid.uuid4().hex,
target_type=target_type,
target=target,
code=code,
scene=scene,
expires_at=expires,
)
self.rldb.upsert(vc.to_rldb_model())
content = f"IF.U服务{scene_name}验证码: {code}, 10分钟内有效"
sent = True
if target_type == 'email':
sent = self.mailer.send(target, f'IF.U服务{scene_name}验证码', content) if self.mailer else False
elif target_type == 'phone':
sent = self.sms.send(target, content) if self.sms else False
if not sent:
return error(ErrorCode.RLDB_ERROR, 'send code failed')
return error(ErrorCode.SUCCESS, '')
def _get_user_by_identifier(self, email: Optional[str], phone: Optional[str]) -> Optional[User]:
if email:
users = self.rldb.query(UserRLDBModel, email=email, limit=1)
if users:
return User.from_rldb_model(users[0])
if phone:
users = self.rldb.query(UserRLDBModel, phone=phone, limit=1)
if users:
return User.from_rldb_model(users[0])
return None
def register(self, user: User, code: str) -> (str, error):
if not user.email and not user.phone:
return '', error(ErrorCode.MODEL_ERROR, 'email or phone required')
existed = self._get_user_by_identifier(user.email, user.phone)
if existed:
return '', error(ErrorCode.MODEL_ERROR, 'user existed')
target_type = 'phone' if user.phone else 'email'
target = user.phone if user.phone else user.email
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type=target_type,
target=target,
scene='register',
limit=1,
)
if not vc_list:
return '', error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return '', error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
user.id = uuid.uuid4().hex
hashed = self._hash_password(user.password_hash)
user.password_hash = hashed
self.rldb.upsert(user.to_rldb_model())
return user.id, error(ErrorCode.SUCCESS, '')
def login(self, email: Optional[str], phone: Optional[str], password: str) -> (dict, error):
u = self._get_user_by_identifier(email, phone)
if not u:
return {}, error(ErrorCode.MODEL_ERROR, 'user not found')
if not self._verify_password(password, u.password_hash):
return {}, error(ErrorCode.MODEL_ERROR, 'invalid password')
ttl_days = self.conf.getint('auth', 'token_ttl_days', fallback=30)
expired_at = datetime.now() + timedelta(days=ttl_days)
token_raw = f"{u.id}.{uuid.uuid4().hex}.{int(expired_at.timestamp())}"
secret = self.conf.get('auth', 'jwt_secret', fallback='dev-secret')
signature = hmac.new(secret.encode('utf-8'), token_raw.encode('utf-8'), 'sha256').digest()
token = base64.urlsafe_b64encode(token_raw.encode('utf-8')).decode('utf-8') + '.' + base64.urlsafe_b64encode(signature).decode('utf-8')
ut = UserToken(id=uuid.uuid4().hex, user_id=u.id, token=token, expired_at=expired_at, revoked=False)
self.rldb.upsert(ut.to_rldb_model())
return {'token': token, 'expired_at': int(expired_at.timestamp())}, error(ErrorCode.SUCCESS, '')
def logout(self, token: str) -> error:
tokens = self.rldb.query(UserTokenRLDBModel, token=token, limit=1)
if not tokens:
return error(ErrorCode.MODEL_ERROR, 'token not found')
t = tokens[0]
t.revoked = True
self.rldb.upsert(t)
return error(ErrorCode.SUCCESS, '')
def delete_user(self, user_id: str) -> error:
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return error(ErrorCode.MODEL_ERROR, 'user not found')
self.rldb.delete(u)
return error(ErrorCode.SUCCESS, '')
def get(self, user_id: str) -> (User, error):
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return None, error(ErrorCode.MODEL_ERROR, 'user not found')
return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '')
def update_profile(self, user_id: str, nickname: str = None, avatar_link: str = None, phone: str = None, email: str = None) -> (User, error):
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return None, error(ErrorCode.MODEL_ERROR, 'user not found')
has_email = bool(u.email)
has_phone = bool(u.phone)
if nickname is not None:
u.nickname = nickname
if avatar_link is not None:
u.avatar_link = avatar_link
if email is not None:
new_email = email
if has_email:
if not has_phone:
return None, error(ErrorCode.MODEL_ERROR, 'email update requires phone exists')
conflicts = self.rldb.query(UserRLDBModel, email=new_email, limit=1)
if conflicts and conflicts[0].id != user_id:
return None, error(ErrorCode.MODEL_ERROR, 'email existed')
u.email = new_email
if phone is not None:
new_phone = phone
if has_phone:
if not has_email:
return None, error(ErrorCode.MODEL_ERROR, 'phone update requires email exists')
conflicts = self.rldb.query(UserRLDBModel, phone=new_phone, limit=1)
if conflicts and conflicts[0].id != user_id:
return None, error(ErrorCode.MODEL_ERROR, 'phone existed')
u.phone = new_phone
self.rldb.upsert(u)
return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '')
def update_phone_with_code(self, user_id: str, new_phone: str, code: str) -> (User, error):
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type='phone',
target=new_phone,
scene='update',
limit=1,
)
if not vc_list:
return None, error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return None, error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
return self.update_profile(user_id, phone=new_phone)
def update_email_with_code(self, user_id: str, new_email: str, code: str) -> (User, error):
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type='email',
target=new_email,
scene='update',
limit=1,
)
if not vc_list:
return None, error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return None, error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
return self.update_profile(user_id, email=new_email)
user_service = None
def init():
global user_service
user_service = UserService()
def get_instance() -> UserService:
return user_service

View File

@@ -7,10 +7,6 @@ class ErrorCode(Enum):
SUCCESS = 0 SUCCESS = 0
MODEL_ERROR = 1000 MODEL_ERROR = 1000
RLDB_ERROR = 2100 RLDB_ERROR = 2100
RLDB_NOT_FOUND = 2101
OBS_ERROR = 3100
OBS_INPUT_ERROR = 3102
OBS_SERVICE_ERROR = 3103
class error(Protocol): class error(Protocol):
_error_code: int = 0 _error_code: int = 0
@@ -19,6 +15,7 @@ class error(Protocol):
def __init__(self, error_code: ErrorCode, error_info: str): def __init__(self, error_code: ErrorCode, error_info: str):
self._error_code = int(error_code.value) self._error_code = int(error_code.value)
self._error_info = error_info self._error_info = error_info
logging.info(f"errorcode: {type(self._error_code)}")
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.__class__.__name__}({self._error_code}, {self._error_info})" return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"

View File

@@ -1,60 +0,0 @@
import logging
import smtplib
from email.mime.text import MIMEText
from typing import Protocol
from .config import get_instance as get_config
class Mailer(Protocol):
def send(self, to_email: str, subject: str, content: str) -> bool:
...
class FakeMailer:
def __init__(self) -> None:
conf = get_config()
self.fake_message = conf.get('fake_mailer', 'message', fallback="FakeEmail")
def send(self, to_email: str, subject: str, content: str) -> bool:
logging.info(f"{self.fake_message}: to_email={to_email}, subject={subject}, content={content}")
return True
class RealMailer:
def __init__(self):
conf = get_config()
self.smtp_host = conf.get('real_mailer', 'smtp_host', fallback=None)
self.smtp_port = conf.getint('real_mailer', 'smtp_port', fallback=587)
self.smtp_user = conf.get('real_mailer', 'smtp_user', fallback=None)
self.smtp_pass = conf.get('real_mailer', 'smtp_pass', fallback=None)
self.from_email = conf.get('real_mailer', 'from_email', fallback=self.smtp_user)
def send(self, to_email: str, subject: str, content: str) -> bool:
if not self.smtp_host or not self.smtp_user or not self.smtp_pass:
return False
msg = MIMEText(content, 'plain', 'utf-8')
msg['Subject'] = subject
msg['From'] = self.from_email
msg['To'] = to_email
try:
server = smtplib.SMTP(self.smtp_host, self.smtp_port)
server.starttls()
server.login(self.smtp_user, self.smtp_pass)
server.sendmail(self.from_email, [to_email], msg.as_string())
server.quit()
return True
except Exception:
return False
_mailer: Mailer = None
def init(type: str = 'real'):
global _mailer
if type == 'real':
_mailer = RealMailer()
else:
_mailer = FakeMailer()
def get_instance() -> Mailer:
return _mailer

View File

@@ -4,13 +4,11 @@ import logging
from typing import Protocol from typing import Protocol
import qiniu import qiniu
import requests import requests
from .error import ErrorCode, error
from .config import get_instance as get_config from .config import get_instance as get_config
class OBS(Protocol): class OBS(Protocol):
def put(self, obs_path: str, content: bytes) -> str: def Put(self, obs_path: str, content: bytes) -> str:
""" """
上传文件到OBS 上传文件到OBS
@@ -23,7 +21,7 @@ class OBS(Protocol):
""" """
... ...
def get(self, obs_path: str) -> bytes: def Get(self, obs_path: str) -> bytes:
""" """
从OBS下载文件 从OBS下载文件
@@ -35,7 +33,7 @@ class OBS(Protocol):
""" """
... ...
def list(self, obs_path: str) -> list: def List(self, obs_path: str) -> list:
""" """
列出OBS目录下的所有文件 列出OBS目录下的所有文件
@@ -47,7 +45,7 @@ class OBS(Protocol):
""" """
... ...
def delete(self, obs_path: str) -> error: def Del(self, obs_path: str) -> bool:
""" """
删除OBS文件 删除OBS文件
@@ -59,7 +57,7 @@ class OBS(Protocol):
""" """
... ...
def get_link(self, obs_path: str) -> str: def Link(self, obs_path: str) -> str:
""" """
获取OBS文件链接 获取OBS文件链接
@@ -70,31 +68,6 @@ class OBS(Protocol):
str: OBS文件链接 str: OBS文件链接
""" """
... ...
def delete_by_link(self, obs_link: str) -> error:
"""
根据OBS文件链接删除文件
Args:
obs_link (str): OBS文件链接
Returns:
bool: 是否删除成功
"""
...
def get_obs_path_by_link(self, obs_link: str) -> (str, error):
"""
从OBS文件链接获取OBS路径
Args:
obs_link (str): OBS文件链接
Returns:
str: OBS文件路径
error: 错误信息
"""
...
class Koodo: class Koodo:
@@ -109,7 +82,7 @@ class Koodo:
self.bucket = qiniu.BucketManager(self.auth) self.bucket = qiniu.BucketManager(self.auth)
pass pass
def put(self, obs_path: str, content: bytes) -> str: def Put(self, obs_path: str, content: bytes) -> str:
""" """
上传文件到OBS 上传文件到OBS
@@ -130,7 +103,7 @@ class Koodo:
logging.info(f"文件 {obs_path} 上传成功, OBS路径: {full_path}") logging.info(f"文件 {obs_path} 上传成功, OBS路径: {full_path}")
return f"{self.outer_domain}/{full_path}" return f"{self.outer_domain}/{full_path}"
def get(self, obs_path: str) -> bytes: def Get(self, obs_path: str) -> bytes:
""" """
从OBS下载文件 从OBS下载文件
@@ -148,7 +121,7 @@ class Koodo:
return None return None
return resp.content return resp.content
def list(self, prefix: str = "") -> list[str]: def List(self, prefix: str = "") -> list[str]:
""" """
列出OBS目录下的所有文件 列出OBS目录下的所有文件
@@ -170,7 +143,7 @@ class Koodo:
# logging.debug(f"info: {info}") # logging.debug(f"info: {info}")
return keys return keys
def delete(self, obs_path: str) -> error: def Del(self, obs_path: str) -> bool:
""" """
删除OBS文件 删除OBS文件
@@ -178,17 +151,17 @@ class Koodo:
obs_path (str): OBS文件路径 obs_path (str): OBS文件路径
Returns: Returns:
error: 删除结果 bool: 是否删除成功
""" """
ret, info = self.bucket.delete(self.bucket_name, f"{self.prefix_path}{obs_path}") ret, info = self.bucket.delete(self.bucket_name, f"{self.prefix_path}{obs_path}")
logging.debug(f"文件 {self.prefix_path}{obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}") logging.debug(f"文件 {obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
if ret is None or info.status_code != 200: if ret is None or info.status_code != 200:
logging.error(f"文件 {obs_path} 删除 OBS 失败, 错误信息: {info.text_body}") logging.error(f"文件 {obs_path} 删除 OBS 失败, 错误信息: {info.text_body}")
return error(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {self.prefix_path}{obs_path} 删除 OBS 失败, 错误信息: {info.text_body}") return False
logging.info(f"文件 {obs_path} 删除 OBS 成功") logging.info(f"文件 {obs_path} 删除 OBS 成功")
return error(error_code=ErrorCode.SUCCESS, error_info="success") return True
def get_link(self, obs_path: str) -> str: def Link(self, obs_path: str) -> str:
""" """
获取OBS文件链接 获取OBS文件链接
@@ -200,38 +173,6 @@ class Koodo:
""" """
return f"{self.outer_domain}/{self.prefix_path}{obs_path}" return f"{self.outer_domain}/{self.prefix_path}{obs_path}"
def delete_by_link(self, obs_link: str) -> error:
"""
根据OBS文件链接删除文件
Args:
obs_link (str): OBS文件链接
Returns:
error: 删除结果
"""
obs_path, err = self.get_obs_path_by_link(obs_link)
if not err.success:
return err
return self.delete(obs_path)
def get_obs_path_by_link(self, obs_link: str) -> (str, error):
"""
从OBS文件链接获取OBS路径
Args:
obs_link (str): OBS文件链接
Returns:
str: OBS文件路径
error: 错误信息
"""
if not obs_link.startswith(f"{self.outer_domain}/{self.prefix_path}"):
logging.error(f"文件 {obs_link} 不是 OBS 文件链接")
return "", error(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {obs_link} 不是 OBS 文件链接")
obs_path = obs_link[len(self.outer_domain) + len(self.prefix_path) + 1:]
return obs_path, error(error_code=ErrorCode.SUCCESS, error_info="success")
_obs_instance: OBS = None _obs_instance: OBS = None
@@ -272,8 +213,8 @@ if __name__ == "__main__":
# print(f"文件 {obs_path} 链接: {link}") # print(f"文件 {obs_path} 链接: {link}")
# 列出OBS目录下的所有文件 # 列出OBS目录下的所有文件
keys = obs.list("") keys = obs.List("")
print(f"OBS 目录下的所有文件: {keys}") print(f"OBS 目录下的所有文件: {keys}")
for key in keys: for key in keys:
link = obs.delete(key) link = obs.Del(key)
print(f"文件 {key} 删除 OBS 成功: {link}") print(f"文件 {key} 删除 OBS 成功: {link}")

View File

@@ -1,5 +1,4 @@
from re import S
from typing import Protocol from typing import Protocol
import uuid import uuid
from sqlalchemy import Column, DateTime, String, create_engine, func from sqlalchemy import Column, DateTime, String, create_engine, func

View File

@@ -1,51 +0,0 @@
import logging
from typing import Protocol
import requests
from .config import get_instance as get_config
class SMS(Protocol):
def send(self, phone: str, content: str) -> bool:
...
class FakeSMS:
def __init__(self) -> None:
conf = get_config()
self.fake_message = conf.get('fake_sms', 'message', fallback="FakeSMS")
def send(self, phone: str, content: str) -> bool:
logging.info(f"{self.fake_message}: phone={phone}, content={content}")
return True
class RealSMS:
def __init__(self):
conf = get_config()
self.webhook_url = conf.get('real_sms', 'webhook_url', fallback=None)
self.webhook_token = conf.get('real_sms', 'webhook_token', fallback=None)
def send(self, phone: str, content: str) -> bool:
if not self.webhook_url:
return False
try:
headers = {'Authorization': f'Bearer {self.webhook_token}'} if self.webhook_token else {}
data = {'phone': phone, 'content': content}
resp = requests.post(self.webhook_url, json=data, headers=headers, timeout=5)
return resp.status_code >= 200 and resp.status_code < 300
except Exception:
return False
_sms: SMS = None
def init(type: str = 'real'):
global _sms
if type == 'real':
_sms = RealSMS()
else:
_sms = FakeSMS()
def get_instance() -> SMS:
return _sms

View File

@@ -1,60 +1,152 @@
import os import os
import uuid import uuid
from fastapi import FastAPI, UploadFile, File, APIRouter, Depends import logging
from fastapi.concurrency import run_in_threadpool from typing import Any, Optional
from fastapi import FastAPI, UploadFile, File, Query
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from web.auth import require_auth from services.people import get_instance as get_people_service
from utils import obs from models.people import People
from web.schemas import BaseResponse from agents.extract_people_agent import ExtractPeopleAgent
from web.custom import router as custom_router from utils import obs, ocr
from web.people import router as people_router
from web.user import router as user_router
from web.recognition import router as recognition_router
api = FastAPI(title="Single People Management and Searching", version="0.1") api = FastAPI(title="Single People Management and Searching", version="0.1")
api.add_middleware( api.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["https://localhost:5173", "https://ifu.mamamiyear.site"], allow_origins=["*"],
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
authorized_router = APIRouter(dependencies=[Depends(require_auth)]) class BaseResponse(BaseModel):
error_code: int
error_info: str
data: Optional[Any] = None
@api.post("/api/ping") @api.post("/ping")
async def ping(): async def ping():
return BaseResponse(error_code=0, error_info="success") return BaseResponse(error_code=0, error_info="success")
@authorized_router.post("/api/upload/image") class PostInputRequest(BaseModel):
async def post_upload_image(image: UploadFile = File(...)): text: str
@api.post("/recognition/input")
async def post_input(request: PostInputRequest):
people = extract_people(request.text)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
@api.post("/recognition/image")
async def post_input_image(image: UploadFile = File(...)):
# 实现上传图片的处理 # 实现上传图片的处理
# 保存上传的图片文件 # 保存上传的图片文件
# 生成唯一的文件名 # 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1] file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}" unique_filename = f"{uuid.uuid4()}{file_extension}"
# 确保uploads目录存在
os.makedirs("uploads", exist_ok=True)
# 保存文件到对象存储 # 保存文件到对象存储
file_path = f"uploads/{unique_filename}" file_path = f"uploads/{unique_filename}"
obs_util = obs.get_instance() obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read()) obs_util.Put(file_path, await image.read())
# 获取对象存储外链 # 获取对象存储外链
obs_url = obs_util.get_link(file_path) obs_url = obs_util.Link(file_path)
return BaseResponse(error_code=0, error_info="success", data=obs_url) logging.info(f"obs_url: {obs_url}")
# 调用OCR处理图片
ocr_util = ocr.get_instance()
ocr_result = ocr_util.recognize_image_text(obs_url)
logging.info(f"ocr_result: {ocr_result}")
people = extract_people(ocr_result, obs_url)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
def extract_people(text: str, cover_link: str = None) -> People:
extra_agent = ExtractPeopleAgent()
people = extra_agent.extract_people_info(text)
people.cover = cover_link
logging.info(f"people: {people}")
return people
api.include_router(authorized_router) class PostPeopleRequest(BaseModel):
people: dict
# Register custom router @api.post("/people")
api.include_router(custom_router, dependencies=[Depends(require_auth)]) async def post_people(post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
service = get_people_service()
people.id, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=people.id)
# Register people router @api.put("/people/{people_id}")
api.include_router(people_router, dependencies=[Depends(require_auth)]) async def update_people(people_id: str, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.id = people_id
service = get_people_service()
res, error = service.get(people_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
_, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
# Register user router @api.delete("/people/{people_id}")
api.include_router(user_router) async def delete_people(people_id: str):
service = get_people_service()
error = service.delete(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
# Register recognition router class GetPeopleRequest(BaseModel):
api.include_router(recognition_router, dependencies=[Depends(require_auth)]) query: Optional[str] = None
conds: Optional[dict] = None
top_k: int = 5
@api.get("/peoples")
async def get_peoples(
name: Optional[str] = Query(None, description="姓名"),
gender: Optional[str] = Query(None, description="性别"),
age: Optional[int] = Query(None, description="年龄"),
height: Optional[int] = Query(None, description="身高"),
marital_status: Optional[str] = Query(None, description="婚姻状态"),
limit: int = Query(10, description="分页大小"),
offset: int = Query(0, description="分页偏移量"),
):
# 解析查询参数为字典
conds = {}
if name:
conds["name"] = name
if gender:
conds["gender"] = gender
if age:
conds["age"] = age
if height:
conds["height"] = height
if marital_status:
conds["marital_status"] = marital_status
logging.info(f"conds: , limit: {limit}, offset: {offset}")
results = []
service = get_people_service()
results, error = service.list(conds, limit=limit, offset=offset)
logging.info(f"query results: {results}")
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
peoples = [people.to_dict() for people in results]
return BaseResponse(error_code=0, error_info="success", data=peoples)

View File

@@ -1,28 +0,0 @@
from typing import Optional
from fastapi import Cookie, HTTPException, Request
from utils import rldb as rldb_util
from models.user import User, UserTokenRLDBModel, UserRLDBModel
from datetime import datetime
def require_auth(request: Request, token: Optional[str] = Cookie(None)):
if not token:
raise HTTPException(status_code=401, detail="unauthorized")
db = rldb_util.get_instance()
tokens = db.query(UserTokenRLDBModel, token=token, limit=1)
if not tokens:
raise HTTPException(status_code=401, detail="unauthorized")
t = tokens[0]
if getattr(t, 'revoked', False):
raise HTTPException(status_code=401, detail="unauthorized")
if getattr(t, 'expired_at', None) and t.expired_at < datetime.now():
raise HTTPException(status_code=401, detail="unauthorized")
user_orm = db.get(UserRLDBModel, t.user_id)
if not user_orm:
raise HTTPException(status_code=401, detail="unauthorized")
user = User.from_rldb_model(user_orm)
request.state.user_id = user.id
request.state.user_nickname = user.nickname
request.state.user_email = user.email
request.state.user_phone = user.phone
request.state.token = token

View File

@@ -1,151 +0,0 @@
import logging
import os
import uuid
from fastapi import APIRouter, Depends, Request, Query, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from models.custom import Custom
from services.custom import get_instance as get_custom_service
from utils.error import ErrorCode
from utils import obs
from web.schemas import BaseResponse
router = APIRouter(tags=["custom"])
class PostCustomRequest(BaseModel):
custom: dict
@router.post("/api/custom")
def create_custom(request: Request, post_custom_request: PostCustomRequest):
logging.debug(f"post_custom_request: {post_custom_request}")
custom = Custom.from_dict(post_custom_request.custom)
# Validate custom data
err = custom.validate()
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
custom.user_id = getattr(request.state, 'user_id', '')
service = get_custom_service()
custom.id, error = service.save(custom)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=custom.id)
@router.put("/api/custom/{custom_id}")
def update_custom(request: Request, custom_id: str, post_custom_request: PostCustomRequest):
logging.debug(f"post_custom_request: {post_custom_request}")
custom = Custom.from_dict(post_custom_request.custom)
custom.id = custom_id
# Validate custom data
err = custom.validate()
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
service = get_custom_service()
# Check permission
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
custom.user_id = res.user_id # Ensure user_id is not changed or is set correctly
_, error = service.save(custom)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/custom/{custom_id}")
def delete_custom(request: Request, custom_id: str):
service = get_custom_service()
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete(custom_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=custom_id)
@router.get("/api/customs")
def get_customs(request: Request, limit: int = Query(10, ge=1, le=1000), offset: int = Query(0, ge=0)):
service = get_custom_service()
res, error = service.list({'user_id': getattr(request.state, 'user_id', '')}, limit=limit, offset=offset)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
# custom对象转换为字典
customs = [custom.to_dict() for custom in res]
return BaseResponse(error_code=0, error_info="success", data=customs)
@router.get("/api/custom/{custom_id}")
def get_custom(request: Request, custom_id: str):
service = get_custom_service()
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
return BaseResponse(error_code=0, error_info="success", data=res.to_dict())
@router.post("/api/custom/{custom_id}/image")
async def post_custom_image(request: Request, custom_id: str, image: UploadFile = File(...)):
# 检查 custom id 是否存在
service = get_custom_service()
custom, err = service.get(custom_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if custom.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"customs/{custom_id}/images/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
return BaseResponse(error_code=0, error_info="success", data=obs_url)
@router.delete("/api/custom/{custom_id}/image")
async def delete_custom_image(request: Request, custom_id: str, image_url: str):
# 检查 custom id 是否存在
service = get_custom_service()
custom, err = service.get(custom_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if custom.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 检查 image_url 是否是该 custom 名下的图片链接
obs_util = obs.get_instance()
obs_path, err = obs_util.get_obs_path_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if not obs_path.startswith(f"customs/{custom_id}/images/"):
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {custom_id} 名下的图片链接")
# 实现删除图片的处理
# 删除对象存储中的文件
err = obs_util.delete_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
return BaseResponse(error_code=0, error_info="success")

View File

@@ -1,192 +0,0 @@
import os
import uuid
import logging
from typing import Optional
from fastapi import APIRouter, Request, UploadFile, File, Query
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from services.people import get_instance as get_people_service
from models.people import People
from utils import obs
from utils.error import ErrorCode
from web.schemas import BaseResponse
router = APIRouter(tags=["people"])
class PostPeopleRequest(BaseModel):
people: dict
@router.post("/api/people")
async def post_people(request: Request, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.user_id = getattr(request.state, 'user_id', '')
service = get_people_service()
people.id, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=people.id)
@router.put("/api/people/{people_id}")
async def update_people(request: Request, people_id: str, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.id = people_id
service = get_people_service()
res, error = service.get(people_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
people.user_id = res.user_id
_, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/people/{people_id}")
async def delete_people(request: Request, people_id: str):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
class GetPeopleRequest(BaseModel):
query: Optional[str] = None
conds: Optional[dict] = None
top_k: int = 5
@router.get("/api/peoples")
async def get_peoples(
request: Request,
name: Optional[str] = Query(None, description="姓名"),
gender: Optional[str] = Query(None, description="性别"),
age: Optional[int] = Query(None, description="年龄"),
height: Optional[int] = Query(None, description="身高"),
marital_status: Optional[str] = Query(None, description="婚姻状态"),
limit: int = Query(10, description="分页大小"),
offset: int = Query(0, description="分页偏移量"),
):
# 解析查询参数为字典
conds = {}
conds["user_id"] = getattr(request.state, 'user_id', '')
if name:
conds["name"] = name
if gender:
conds["gender"] = gender
if age:
conds["age"] = age
if height:
conds["height"] = height
if marital_status:
conds["marital_status"] = marital_status
logging.info(f"conds: , limit: {limit}, offset: {offset}")
results = []
service = get_people_service()
results, error = service.list(conds, limit=limit, offset=offset)
logging.info(f"query results: {results}")
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
peoples = [people.to_dict() for people in results]
return BaseResponse(error_code=0, error_info="success", data=peoples)
class RemarkRequest(BaseModel):
content: str
@router.post("/api/people/{people_id}/remark")
async def post_people_remark(request: Request, people_id: str, body: RemarkRequest):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.save_remark(people_id, body.content)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/people/{people_id}/remark")
async def delete_people_remark(request: Request, people_id: str):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete_remark(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.post("/api/people/{people_id}/image")
async def post_people_image(request: Request, people_id: str, image: UploadFile = File(...)):
# 检查 people id 是否存在
service = get_people_service()
people, err = service.get(people_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if people.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"peoples/{people_id}/images/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
return BaseResponse(error_code=0, error_info="success", data=obs_url)
@router.delete("/api/people/{people_id}/image")
async def delete_people_image(request: Request, people_id: str, image_url: str):
# 检查 people id 是否存在
service = get_people_service()
people, err = service.get(people_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if people.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 检查 image_url 是否是该 people 名下的图片链接
obs_util = obs.get_instance()
obs_path, err = obs_util.get_obs_path_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if not obs_path.startswith(f"peoples/{people_id}/images/"):
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {people_id} 名下的图片链接")
# 实现删除图片的处理
# 删除对象存储中的文件
err = obs_util.delete_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
return BaseResponse(error_code=0, error_info="success")

View File

@@ -1,91 +0,0 @@
import os
import uuid
import logging
from typing import Optional
from fastapi import APIRouter, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from models.people import People
from models.custom import Custom
from agents.extract_people_agent import ExtractPeopleAgent
from agents.extract_custom_agent import ExtractCustomAgent
from utils import obs, ocr
from web.schemas import BaseResponse
from utils.error import ErrorCode
router = APIRouter(tags=["recognition"])
def extract_people(text: str, cover_link: str = None) -> People:
extra_agent = ExtractPeopleAgent()
people = extra_agent.extract_people_info(text)
if people:
people.cover = cover_link
logging.info(f"people: {people}")
return people
def extract_custom(text: str, image_link: str = None) -> Custom:
extra_agent = ExtractCustomAgent()
custom = extra_agent.extract_custom_info(text)
if custom:
if image_link:
custom.images = [image_link]
logging.info(f"custom: {custom}")
return custom
class PostInputRequest(BaseModel):
text: str
@router.post("/api/recognition/{model}/input")
async def post_recognition_input(model: str, request: PostInputRequest):
if model == "people":
result = await run_in_threadpool(extract_people, request.text)
elif model == "custom":
result = await run_in_threadpool(extract_custom, request.text)
else:
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
if result is None:
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
resp = BaseResponse(error_code=0, error_info="success")
resp.data = result.to_dict()
return resp
@router.post("/api/recognition/{model}/image")
async def post_recognition_image(model: str, image: UploadFile = File(...)):
if model not in ["people", "custom"]:
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"uploads/{model}/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
# 调用OCR处理图片
ocr_util = ocr.get_instance()
ocr_result = await run_in_threadpool(ocr_util.recognize_image_text, obs_url)
logging.info(f"ocr_result: {ocr_result}")
if model == "people":
result = await run_in_threadpool(extract_people, ocr_result, obs_url)
elif model == "custom":
result = await run_in_threadpool(extract_custom, ocr_result, obs_url)
if result is None:
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
resp = BaseResponse(error_code=0, error_info="success")
resp.data = result.to_dict()
return resp

View File

@@ -1,7 +0,0 @@
from typing import Any, Optional
from pydantic import BaseModel
class BaseResponse(BaseModel):
error_code: int
error_info: str
data: Optional[Any] = None

View File

@@ -1,223 +0,0 @@
import os
import time
import logging
from typing import Optional, Literal
from fastapi import APIRouter, Depends, Request, HTTPException, Response, UploadFile, File
from pydantic import BaseModel
from services.user import get_instance as get_user_service
from web.auth import require_auth
from utils import obs
from utils.config import get_instance as get_config
from web.schemas import BaseResponse
router = APIRouter(tags=["user"])
class SendCodeRequest(BaseModel):
target_type: str
target: str
scene: Literal['register', 'update']
# scene: Literal['register', 'login']
@router.post("/api/user/send_code")
async def send_user_code(request: SendCodeRequest):
service = get_user_service()
err = service.send_code(request.target_type, request.target, request.scene)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success")
class RegisterRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
email: Optional[str] = None
phone: Optional[str] = None
password: str
code: str
@router.post("/api/user")
async def user_register(request: RegisterRequest):
service = get_user_service()
from models.user import User
u = User(
nickname=request.nickname or "",
avatar_link=request.avatar_link or "",
email=request.email or "",
phone=request.phone or "",
password_hash=request.password,
)
uid, err = service.register(u, request.code)
if not err.success:
logging.error(f"register failed: {err}")
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data=uid)
class LoginRequest(BaseModel):
email: Optional[str] = None
phone: Optional[str] = None
password: str
@router.post("/api/user/login")
async def user_login(request: LoginRequest, response: Response):
service = get_user_service()
data, err = service.login(request.email, request.phone, request.password)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30)
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False)
cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None)
response.set_cookie(
key="token",
value=data.get('token', ''),
max_age=ttl_days * 24 * 3600,
httponly=True,
secure=cookie_secure,
samesite=cookie_samesite,
domain=cookie_domain,
path="/",
)
return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')})
@router.delete("/api/user/me/login", dependencies=[Depends(require_auth)])
async def user_logout(response: Response, request: Request):
service = get_user_service()
err = service.logout(getattr(request.state, 'token', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/user/me", dependencies=[Depends(require_auth)])
async def user_delete(response: Response, request: Request):
service = get_user_service()
err = service.delete_user(getattr(request.state, 'user_id', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@router.get("/api/user/me", dependencies=[Depends(require_auth)])
async def user_me(request: Request):
service = get_user_service()
user, err = service.get(getattr(request.state, 'user_id', None))
if not err.success or not user:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateMeRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
phone: Optional[str] = None
email: Optional[str] = None
@router.put("/api/user/me", dependencies=[Depends(require_auth)])
async def update_user_me(request: Request, body: UpdateMeRequest):
service = get_user_service()
user, err = service.update_profile(
getattr(request.state, 'user_id', None),
nickname=body.nickname,
avatar_link=body.avatar_link,
phone=body.phone,
email=body.email,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
@router.put("/api/user/me/avatar", dependencies=[Depends(require_auth)])
async def upload_avatar(request: Request, avatar: UploadFile = File(...)):
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="unauthorized")
file_extension = os.path.splitext(avatar.filename)[1]
timestamp = int(time.time())
avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}"
try:
obs_util = obs.get_instance()
obs_util.Put(avatar_path, await avatar.read())
avatar_url = obs_util.Link(avatar_path)
user_service = get_user_service()
_, err = user_service.update_profile(user_id, avatar_link=avatar_url)
if not err.success:
raise HTTPException(status_code=500, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url})
except Exception as e:
logging.error(f"upload avatar failed: {e}")
raise HTTPException(status_code=500, detail="upload avatar failed")
class UpdatePhoneRequest(BaseModel):
phone: str
code: str
@router.put("/api/user/me/phone", dependencies=[Depends(require_auth)])
async def update_user_phone(request: Request, body: UpdatePhoneRequest):
service = get_user_service()
user, err = service.update_phone_with_code(
getattr(request.state, 'user_id', None),
body.phone,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateEmailRequest(BaseModel):
email: str
code: str
@router.put("/api/user/me/email", dependencies=[Depends(require_auth)])
async def update_user_email(request: Request, body: UpdateEmailRequest):
service = get_user_service()
user, err = service.update_email_with_code(
getattr(request.state, 'user_id', None),
body.email,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)