feat: support custom management

- add custom model and rldb model
- add service for custom to operate rldb
- add apis to CURD custom and image upload and delete
- support to recognize custom from text or image
- refactor web servcie start mode and api group
- - group the apis
- - support uvicorn start service in terminal
- - refactor recognization api for both people and custom
This commit is contained in:
2025-11-27 00:46:53 +08:00
parent 25fb6ba9ce
commit 12757afda1
11 changed files with 1190 additions and 461 deletions

View File

@@ -0,0 +1,85 @@
import datetime
import json
import logging
from langchain.prompts import ChatPromptTemplate
from .base_agent import BaseAgent
from models.custom import Custom
class ExtractCustomAgent(BaseAgent):
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
super().__init__(api_url, api_key, model_name)
self.prompt = ChatPromptTemplate.from_messages([
(
"system",
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')}"
"你是一个专业的客户信息录入助手,善于从一段文字描述中,精确获取客户的以下信息:\n"
"姓名 name\n"
"性别 gender (男/女/未知)\n"
"出生年份 birth (整数年份,如 1990若文本只提供了年龄请根据当前日期计算出出生年份)\n"
"手机号 phone\n"
"邮箱 email\n"
"身高(cm) height (整数)\n"
"体重(kg) weight (整数)\n"
"学历 degree\n"
"毕业院校 academy\n"
"职业 occupation\n"
"年收入(万) income (整数)\n"
"资产(万) assets (整数)\n"
"流动资产(万) current_assets (整数)\n"
"房产情况 house (必须为以下之一: '有房无贷', '有房有贷', '无自有房', 若未提及则不填)\n"
"车辆情况 car (必须为以下之一: '有车无贷', '有车有贷', '无自有车', 若未提及则不填)\n"
"户口城市 registered_city\n"
"居住城市 live_city\n"
"籍贯 native_place\n"
"原生家庭情况 original_family\n"
"是否独生子女 is_single_child (true/false)\n"
"择偶要求 match_requirement\n"
"\n"
"以上信息需要严格按照 JSON 格式输出,字段名与条目中英文保持一致。\n"
"若未识别到某项,则不返回该字段,不要自行填写“未知”、“未填写”等。\n"
"\n"
"除了上述基本信息,还有一个字段:\n"
"其他介绍 introductions\n"
"其余的信息需要按照字典的方式进行提炼和总结,都放在 introductions 字段中key 使用提炼好的中文。\n"
),
("human", "{input}")
])
def extract_custom_info(self, text: str) -> Custom:
"""从文本中提取客户信息"""
prompt = self.prompt.format_prompt(input=text)
response = self.llm.invoke(prompt)
logging.info(f"llm response: {response.content}")
try:
custom_dict = json.loads(response.content)
# 类型安全转换防止LLM返回字符串类型的数字
int_fields = ['birth', 'height', 'weight', 'income', 'assets', 'scores', 'current_assets']
for field in int_fields:
if field in custom_dict and isinstance(custom_dict[field], str):
try:
# 尝试提取数字,简单处理
import re
num = re.findall(r'\d+', custom_dict[field])
if num:
custom_dict[field] = int(num[0])
else:
del custom_dict[field] # 无法转换则移除
except:
del custom_dict[field]
custom = Custom.from_dict(custom_dict)
err = custom.validate()
if not err.success:
logging.warning(f"Validation warning: {err.info}")
# 即使校验失败(如某些必填项缺失),也尽可能返回已提取的对象,
# 让上层业务逻辑决定是否接受或需要补充
return custom
except json.JSONDecodeError:
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
return None
except Exception as e:
logging.error(f"Failed to process custom info: {e}")
return None

View File

@@ -1,22 +1,23 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-09-27
import os
import sys
# Add src directory to sys.path to ensure modules can be imported correctly when running with uvicorn
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import argparse
import uvicorn
from services import people as people_service
from services import user as user_service
from services import custom as custom_service
from utils import config, logger, obs, ocr, rldb, sms, mailer
from web.api import api
# 主函数
def main():
main_path = os.path.dirname(os.path.abspath(__file__))
parser = argparse.ArgumentParser(description='IF.u 服务')
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
args = parser.parse_args()
config.init(args.config)
def initialize_app(config_path):
"""Initialize application components with the given config path."""
config.init(config_path)
conf = config.get_instance()
logger.init()
@@ -28,12 +29,28 @@ def main():
people_service.init()
user_service.init()
custom_service.init()
# 主函数
def main():
main_path = os.path.dirname(os.path.abspath(__file__))
parser = argparse.ArgumentParser(description='IF.u 服务')
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
args = parser.parse_args()
initialize_app(args.config)
conf = config.get_instance()
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
port = conf.getint('web_service', 'server_port', fallback=8099)
uvicorn.run(api, host=host, port=port)
uvicorn.run("src.main:api", host=host, port=port, reload=True) # Modified to string import for reload support in main too, though api object also works
if __name__ == "__main__":
main()
main()
else:
# Support for running via 'uvicorn src.main:api'
# Use environment variable for config path or default
main_path = os.path.dirname(os.path.abspath(__file__))
default_config_path = os.path.join(main_path, '../configuration/test_conf.ini')
config_path = os.environ.get('IFU_CONFIG_PATH', default_config_path)
initialize_app(config_path)

291
src/models/custom.py Normal file
View File

@@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-11-27
import json
import logging
from typing import Dict, List
from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime, func, Boolean
from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class CustomRLDBModel(RLDBBaseModel):
"""
客户数据的数据库模型 (SQLAlchemy Model) - 更新版
"""
__tablename__ = 'customs'
id = Column(String(36), primary_key=True)
user_id = Column(String(36), index=True, nullable=False)
# 基本信息
name = Column(String(255), index=True, nullable=False)
gender = Column(String(10), nullable=False)
birth = Column(Integer, nullable=False) # 出生年份
phone = Column(String(50), index=True)
email = Column(String(255), index=True)
# 外貌信息
height = Column(Integer)
weight = Column(Integer)
images = Column(Text) # JSON string for list[str]
scores = Column(Integer)
# 学历职业
degree = Column(String(255))
academy = Column(String(255))
occupation = Column(String(255))
income = Column(Integer) # 单位:万
assets = Column(Integer) # 单位:万
current_assets = Column(Integer) # 单位:万
house = Column(String(50))
car = Column(String(50))
# 户口家庭
registered_city = Column(String(255))
live_city = Column(String(255))
native_place = Column(String(255))
original_family = Column(Text)
is_single_child = Column(Boolean, default=False)
match_requirement = Column(Text)
introductions = Column(Text) # JSON string for Dict[str, str]
# 客户信息
custom_level = Column(String(255))
comments = Column(Text) # JSON string for Dict[str, str]
is_public = Column(Boolean, default=False)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class Custom:
"""
客户数据的业务逻辑模型 (Business Logic Model) - 更新版
"""
id: str
user_id: str
# 基本信息
name: str
gender: str
birth: int
phone: str
email: str
# 外貌信息
height: int
weight: int
images: List[str]
scores: int
# 学历职业
degree: str
academy: str
occupation: str
income: int
assets: int
current_assets: int
house: str
car: str
# 户口家庭
registered_city: str
live_city: str
native_place: str
original_family: str
is_single_child: bool
match_requirement: str
introductions: Dict[str, str]
# 客户信息
custom_level: str
comments: Dict[str, str]
is_public: bool
created_at: datetime = None
def __init__(self, **kwargs):
# 初始化所有属性
self.id = kwargs.get('id', '')
self.user_id = kwargs.get('user_id', '')
self.name = kwargs.get('name', '')
self.gender = kwargs.get('gender', '未知')
self.birth = kwargs.get('birth', 0)
self.phone = kwargs.get('phone', '')
self.email = kwargs.get('email', '')
self.height = kwargs.get('height', 0)
self.weight = kwargs.get('weight', 0)
self.images = kwargs.get('images', [])
self.scores = kwargs.get('scores', 0)
self.degree = kwargs.get('degree', '')
self.academy = kwargs.get('academy', '')
self.occupation = kwargs.get('occupation', '')
self.income = kwargs.get('income', 0)
self.assets = kwargs.get('assets', 0)
self.current_assets = kwargs.get('current_assets', 0)
self.house = kwargs.get('house', '')
self.car = kwargs.get('car', '')
self.registered_city = kwargs.get('registered_city', '')
self.live_city = kwargs.get('live_city', '')
self.native_place = kwargs.get('native_place', '')
self.original_family = kwargs.get('original_family', '')
self.is_single_child = kwargs.get('is_single_child', False)
self.match_requirement = kwargs.get('match_requirement', '')
self.introductions = kwargs.get('introductions', {})
self.custom_level = kwargs.get('custom_level', '')
self.comments = kwargs.get('comments', {})
self.is_public = kwargs.get('is_public', False)
self.created_at = kwargs.get('created_at')
def __str__(self) -> str:
# 返回对象的字符串表示
attributes = ", ".join(f"{k}={v}" for k, v in self.to_dict().items())
return f"Custom({attributes})"
@classmethod
def from_dict(cls, data: dict):
# 从字典创建对象实例
if 'created_at' in data and data['created_at'] is not None:
data['created_at'] = datetime.fromtimestamp(data['created_at'])
# 移除ORM特有的时间戳避免初始化错误
data.pop('updated_at', None)
data.pop('deleted_at', None)
return cls(**data)
def to_dict(self) -> dict:
# 将对象转换为字典
return {
'id': self.id,
'user_id': self.user_id,
'name': self.name,
'gender': self.gender,
'birth': self.birth,
'phone': self.phone,
'email': self.email,
'height': self.height,
'weight': self.weight,
'images': self.images,
'scores': self.scores,
'degree': self.degree,
'academy': self.academy,
'occupation': self.occupation,
'income': self.income,
'assets': self.assets,
'current_assets': self.current_assets,
'house': self.house,
'car': self.car,
'registered_city': self.registered_city,
'live_city': self.live_city,
'native_place': self.native_place,
'original_family': self.original_family,
'is_single_child': self.is_single_child,
'match_requirement': self.match_requirement,
'introductions': self.introductions,
'custom_level': self.custom_level,
'comments': self.comments,
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
}
@classmethod
def from_rldb_model(cls, data: CustomRLDBModel):
# 从数据库模型转换
return cls(
id=data.id,
user_id=data.user_id,
name=data.name,
gender=data.gender,
birth=data.birth,
phone=data.phone,
email=data.email,
height=data.height,
weight=data.weight,
images=json.loads(data.images) if data.images else [],
scores=data.scores,
degree=data.degree,
academy=data.academy,
occupation=data.occupation,
income=data.income,
assets=data.assets,
house=data.house,
car=data.car,
registered_city=data.registered_city,
live_city=data.live_city,
native_place=data.native_place,
original_family=data.original_family,
is_single_child=data.is_single_child,
match_requirement=data.match_requirement,
introductions=json.loads(data.introductions) if data.introductions else {},
custom_level=data.custom_level,
comments=json.loads(data.comments) if data.comments else {},
is_public=data.is_public,
created_at=data.created_at,
)
def to_rldb_model(self) -> CustomRLDBModel:
# 转换为数据库模型
return CustomRLDBModel(
id=self.id,
user_id=self.user_id,
name=self.name,
gender=self.gender,
birth=self.birth,
phone=self.phone,
email=self.email,
height=self.height,
weight=self.weight,
images=json.dumps(self.images, ensure_ascii=False),
scores=self.scores,
degree=self.degree,
academy=self.academy,
occupation=self.occupation,
income=self.income,
assets=self.assets,
current_assets=self.current_assets,
house=self.house,
car=self.car,
registered_city=self.registered_city,
live_city=self.live_city,
native_place=self.native_place,
original_family=self.original_family,
is_single_child=self.is_single_child,
match_requirement=self.match_requirement,
introductions=json.dumps(self.introductions, ensure_ascii=False),
custom_level=self.custom_level,
comments=json.dumps(self.comments, ensure_ascii=False),
is_public=self.is_public,
)
def validate(self) -> error:
# 数据校验逻辑
if not self.name:
return error(ErrorCode.INVALID_PARAMS, "Name cannot be empty.")
if not self.gender:
return error(ErrorCode.INVALID_PARAMS, "Gender cannot be empty.")
if self.gender not in ['', '']:
return error(ErrorCode.INVALID_PARAMS, "Gender must be '' or ''.")
current_year = datetime.now().year
min_birth_year = 1950
max_birth_year = current_year - 18
if not isinstance(self.birth, int):
return error(ErrorCode.INVALID_PARAMS, "Birth year must be an integer.")
if self.birth < min_birth_year or self.birth > max_birth_year:
return error(ErrorCode.INVALID_PARAMS, f"Birth year must be between {min_birth_year} and {max_birth_year}.")
valid_houses = ["", "有房无贷", "有房有贷", "无自有房"]
if self.house not in valid_houses:
return error(ErrorCode.INVALID_PARAMS, f"House must be one of {valid_houses}")
valid_cars = ["", "有车无贷", "有车有贷", "无自有车"]
if self.car not in valid_cars:
return error(ErrorCode.INVALID_PARAMS, f"Car must be one of {valid_cars}")
# ... 可根据需要添加更多校验 ...
return error(ErrorCode.SUCCESS, "")

98
src/services/custom.py Normal file
View File

@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# created by mmmy on 2025-11-27
import logging
import uuid
from models.custom import Custom, CustomRLDBModel
from utils.error import ErrorCode, error
from utils import rldb
class CustomService:
def __init__(self):
self.rldb = rldb.get_instance()
def save(self, custom: Custom) -> (str, error):
"""
保存客户到数据库。
如果 custom.id 存在,则更新;否则,创建。
:param custom: 客户对象
:return: 客户ID 和 错误对象
"""
# 0. 生成 custom id
custom.id = custom.id if custom.id else uuid.uuid4().hex
# 1. 转换模型,并保存到 SQL 数据库
try:
custom_orm = custom.to_rldb_model()
self.rldb.upsert(custom_orm)
return custom.id, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to save custom {custom.id}: {e}")
return "", error(ErrorCode.RLDB_ERROR, f"Failed to save custom data: {str(e)}")
def delete(self, custom_id: str) -> error:
"""
从数据库删除客户。
:param custom_id: 客户ID
:return: 错误对象
"""
try:
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
if not custom_orm:
return error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
self.rldb.delete(custom_orm)
return error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to delete custom {custom_id}: {e}")
return error(ErrorCode.RLDB_ERROR, f"Failed to delete custom data: {str(e)}")
def get(self, custom_id: str) -> (Custom, error):
"""
从数据库获取单个客户。
:param custom_id: 客户ID
:return: 客户对象 和 错误对象
"""
try:
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
if not custom_orm:
return None, error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
custom = Custom.from_rldb_model(custom_orm)
return custom, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to get custom {custom_id}: {e}")
return None, error(ErrorCode.RLDB_ERROR, f"Failed to retrieve custom data: {str(e)}")
def list(self, conds: dict = None, limit: int = 10, offset: int = 0) -> (list[Custom], error):
"""
根据条件从数据库列出客户(支持分页)。
:param conds: 查询条件字典
:param limit: 每页数量
:param offset: 偏移量
:return: 客户对象列表 和 错误对象
"""
if conds is None:
conds = {}
try:
custom_orms = self.rldb.query(CustomRLDBModel, limit=limit, offset=offset, **conds)
customs = [Custom.from_rldb_model(orm) for orm in custom_orms]
return customs, error(ErrorCode.SUCCESS, "")
except Exception as e:
logging.error(f"Failed to list customs with conds {conds}: {e}")
return [], error(ErrorCode.RLDB_ERROR, f"Failed to list custom data: {str(e)}")
# --- Singleton Pattern ---
custom_service = None
def init():
"""初始化 CustomService 单例"""
global custom_service
custom_service = CustomService()
def get_instance() -> CustomService:
"""获取 CustomService 单例"""
return custom_service

View File

@@ -7,6 +7,7 @@ class ErrorCode(Enum):
SUCCESS = 0
MODEL_ERROR = 1000
RLDB_ERROR = 2100
RLDB_NOT_FOUND = 2101
OBS_ERROR = 3100
OBS_INPUT_ERROR = 3102
OBS_SERVICE_ERROR = 3103

View File

@@ -1,19 +1,15 @@
import os
import time
import uuid
import logging
from typing import Any, Optional, Literal
from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Response, APIRouter, Depends, Request
from pydantic import BaseModel
from fastapi import FastAPI, UploadFile, File, APIRouter, Depends
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from services.people import get_instance as get_people_service
from services.user import get_instance as get_user_service
from web.auth import require_auth
from models.people import People
from agents.extract_people_agent import ExtractPeopleAgent
from utils import obs, ocr
from utils.config import get_instance as get_config
from utils.error import ErrorCode
from utils import obs
from web.schemas import BaseResponse
from web.custom import router as custom_router
from web.people import router as people_router
from web.user import router as user_router
from web.recognition import router as recognition_router
api = FastAPI(title="Single People Management and Searching", version="0.1")
api.add_middleware(
@@ -26,18 +22,10 @@ api.add_middleware(
authorized_router = APIRouter(dependencies=[Depends(require_auth)])
class BaseResponse(BaseModel):
error_code: int
error_info: str
data: Optional[Any] = None
@api.post("/api/ping")
async def ping():
return BaseResponse(error_code=0, error_info="success")
class PostInputRequest(BaseModel):
text: str
@authorized_router.post("/api/upload/image")
async def post_upload_image(image: UploadFile = File(...)):
# 实现上传图片的处理
@@ -49,439 +37,24 @@ async def post_upload_image(image: UploadFile = File(...)):
# 保存文件到对象存储
file_path = f"uploads/{unique_filename}"
obs_util = obs.get_instance()
obs_util.put(file_path, await image.read())
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
return BaseResponse(error_code=0, error_info="success", data=obs_url)
@authorized_router.post("/api/recognition/input")
async def post_recognition_input(request: PostInputRequest):
people = extract_people(request.text)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
@authorized_router.post("/api/recognition/image")
async def post_recognition_image(image: UploadFile = File(...)):
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"uploads/{unique_filename}"
obs_util = obs.get_instance()
obs_util.put(file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
# 调用OCR处理图片
ocr_util = ocr.get_instance()
ocr_result = ocr_util.recognize_image_text(obs_url)
logging.info(f"ocr_result: {ocr_result}")
people = extract_people(ocr_result, obs_url)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
def extract_people(text: str, cover_link: str = None) -> People:
extra_agent = ExtractPeopleAgent()
people = extra_agent.extract_people_info(text)
people.cover = cover_link
logging.info(f"people: {people}")
return people
class PostPeopleRequest(BaseModel):
people: dict
@authorized_router.post("/api/people")
async def post_people(request: Request, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.user_id = getattr(request.state, 'user_id', '')
service = get_people_service()
people.id, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=people.id)
@authorized_router.put("/api/people/{people_id}")
async def update_people(request: Request, people_id: str, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.id = people_id
service = get_people_service()
res, error = service.get(people_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
people.user_id = res.user_id
_, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@authorized_router.delete("/api/people/{people_id}")
async def delete_people(request: Request, people_id: str):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
class GetPeopleRequest(BaseModel):
query: Optional[str] = None
conds: Optional[dict] = None
top_k: int = 5
@authorized_router.get("/api/peoples")
async def get_peoples(
request: Request,
name: Optional[str] = Query(None, description="姓名"),
gender: Optional[str] = Query(None, description="性别"),
age: Optional[int] = Query(None, description="年龄"),
height: Optional[int] = Query(None, description="身高"),
marital_status: Optional[str] = Query(None, description="婚姻状态"),
limit: int = Query(10, description="分页大小"),
offset: int = Query(0, description="分页偏移量"),
):
# 解析查询参数为字典
conds = {}
conds["user_id"] = getattr(request.state, 'user_id', '')
if name:
conds["name"] = name
if gender:
conds["gender"] = gender
if age:
conds["age"] = age
if height:
conds["height"] = height
if marital_status:
conds["marital_status"] = marital_status
logging.info(f"conds: , limit: {limit}, offset: {offset}")
results = []
service = get_people_service()
results, error = service.list(conds, limit=limit, offset=offset)
logging.info(f"query results: {results}")
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
peoples = [people.to_dict() for people in results]
return BaseResponse(error_code=0, error_info="success", data=peoples)
class RemarkRequest(BaseModel):
content: str
@authorized_router.post("/api/people/{people_id}/remark")
async def post_people_remark(request: Request, people_id: str, body: RemarkRequest):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.save_remark(people_id, body.content)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@authorized_router.delete("/api/people/{people_id}/remark")
async def delete_people_remark(request: Request, people_id: str):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete_remark(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@authorized_router.post("/api/people/{people_id}/image")
async def post_people_image(request: Request, people_id: str, image: UploadFile = File(...)):
# 检查 people id 是否存在
service = get_people_service()
people, err = service.get(people_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if people.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"peoples/{people_id}/images/{unique_filename}"
obs_util = obs.get_instance()
obs_util.put(file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
return BaseResponse(error_code=0, error_info="success", data=obs_url)
@authorized_router.delete("/api/people/{people_id}/image")
async def delete_people_image(request: Request, people_id: str, image_url: str):
# 检查 people id 是否存在
service = get_people_service()
people, err = service.get(people_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if people.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 检查 image_url 是否是该 people 名下的图片链接
obs_util = obs.get_instance()
obs_path, err = obs_util.get_obs_path_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if not obs_path.startswith(f"peoples/{people_id}/images/"):
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {people_id} 名下的图片链接")
# 实现删除图片的处理
# 删除对象存储中的文件
err = obs_util.delete_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
return BaseResponse(error_code=0, error_info="success")
class SendCodeRequest(BaseModel):
target_type: str
target: str
scene: Literal['register', 'update']
# scene: Literal['register', 'login']
@api.post("/api/user/send_code")
async def send_user_code(request: SendCodeRequest):
service = get_user_service()
err = service.send_code(request.target_type, request.target, request.scene)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success")
class RegisterRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
email: Optional[str] = None
phone: Optional[str] = None
password: str
code: str
@api.post("/api/user")
async def user_register(request: RegisterRequest):
service = get_user_service()
from models.user import User
u = User(
nickname=request.nickname or "",
avatar_link=request.avatar_link or "",
email=request.email or "",
phone=request.phone or "",
password_hash=request.password,
)
uid, err = service.register(u, request.code)
if not err.success:
logging.error(f"register failed: {err}")
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data=uid)
class LoginRequest(BaseModel):
email: Optional[str] = None
phone: Optional[str] = None
password: str
@api.post("/api/user/login")
async def user_login(request: LoginRequest, response: Response):
service = get_user_service()
data, err = service.login(request.email, request.phone, request.password)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30)
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False)
cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None)
response.set_cookie(
key="token",
value=data.get('token', ''),
max_age=ttl_days * 24 * 3600,
httponly=True,
secure=cookie_secure,
samesite=cookie_samesite,
domain=cookie_domain,
path="/",
)
return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')})
@authorized_router.delete("/api/user/me/login")
async def user_logout(response: Response, request: Request):
service = get_user_service()
err = service.logout(getattr(request.state, 'token', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@authorized_router.delete("/api/user/me")
async def user_delete(response: Response, request: Request):
service = get_user_service()
err = service.delete_user(getattr(request.state, 'user_id', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@authorized_router.get("/api/user/me")
async def user_me(request: Request):
service = get_user_service()
user, err = service.get(getattr(request.state, 'user_id', None))
if not err.success or not user:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateMeRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
phone: Optional[str] = None
email: Optional[str] = None
@authorized_router.put("/api/user/me")
async def update_user_me(request: Request, body: UpdateMeRequest):
service = get_user_service()
user, err = service.update_profile(
getattr(request.state, 'user_id', None),
nickname=body.nickname,
avatar_link=body.avatar_link,
phone=body.phone,
email=body.email,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
@authorized_router.put("/api/user/me/avatar")
async def upload_avatar(request: Request, avatar: UploadFile = File(...)):
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="unauthorized")
file_extension = os.path.splitext(avatar.filename)[1]
timestamp = int(time.time())
avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}"
try:
obs_util = obs.get_instance()
obs_util.Put(avatar_path, await avatar.read())
avatar_url = obs_util.Link(avatar_path)
user_service = get_user_service()
_, err = user_service.update_profile(user_id, avatar_link=avatar_url)
if not err.success:
raise HTTPException(status_code=500, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url})
except Exception as e:
logging.error(f"upload avatar failed: {e}")
raise HTTPException(status_code=500, detail="upload avatar failed")
class UpdatePhoneRequest(BaseModel):
phone: str
code: str
@authorized_router.put("/api/user/me/phone")
async def update_user_phone(request: Request, body: UpdatePhoneRequest):
service = get_user_service()
user, err = service.update_phone_with_code(
getattr(request.state, 'user_id', None),
body.phone,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateEmailRequest(BaseModel):
email: str
code: str
@authorized_router.put("/api/user/me/email")
async def update_user_email(request: Request, body: UpdateEmailRequest):
service = get_user_service()
user, err = service.update_email_with_code(
getattr(request.state, 'user_id', None),
body.email,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
api.include_router(authorized_router)
# Register custom router
api.include_router(custom_router, dependencies=[Depends(require_auth)])
# Register people router
api.include_router(people_router, dependencies=[Depends(require_auth)])
# Register user router
api.include_router(user_router)
# Register recognition router
api.include_router(recognition_router, dependencies=[Depends(require_auth)])

151
src/web/custom.py Normal file
View File

@@ -0,0 +1,151 @@
import logging
import os
import uuid
from fastapi import APIRouter, Depends, Request, Query, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from models.custom import Custom
from services.custom import get_instance as get_custom_service
from utils.error import ErrorCode
from utils import obs
from web.schemas import BaseResponse
router = APIRouter(tags=["custom"])
class PostCustomRequest(BaseModel):
custom: dict
@router.post("/api/custom")
def create_custom(request: Request, post_custom_request: PostCustomRequest):
logging.debug(f"post_custom_request: {post_custom_request}")
custom = Custom.from_dict(post_custom_request.custom)
# Validate custom data
err = custom.validate()
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
custom.user_id = getattr(request.state, 'user_id', '')
service = get_custom_service()
custom.id, error = service.save(custom)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=custom.id)
@router.put("/api/custom/{custom_id}")
def update_custom(request: Request, custom_id: str, post_custom_request: PostCustomRequest):
logging.debug(f"post_custom_request: {post_custom_request}")
custom = Custom.from_dict(post_custom_request.custom)
custom.id = custom_id
# Validate custom data
err = custom.validate()
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
service = get_custom_service()
# Check permission
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
custom.user_id = res.user_id # Ensure user_id is not changed or is set correctly
_, error = service.save(custom)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/custom/{custom_id}")
def delete_custom(request: Request, custom_id: str):
service = get_custom_service()
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete(custom_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=custom_id)
@router.get("/api/customs")
def get_customs(request: Request, limit: int = Query(10, ge=1, le=1000), offset: int = Query(0, ge=0)):
service = get_custom_service()
res, error = service.list({'user_id': getattr(request.state, 'user_id', '')}, limit=limit, offset=offset)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
# custom对象转换为字典
customs = [custom.to_dict() for custom in res]
return BaseResponse(error_code=0, error_info="success", data=customs)
@router.get("/api/custom/{custom_id}")
def get_custom(request: Request, custom_id: str):
service = get_custom_service()
res, error = service.get(custom_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
return BaseResponse(error_code=0, error_info="success", data=res.to_dict())
@router.post("/api/custom/{custom_id}/image")
async def post_custom_image(request: Request, custom_id: str, image: UploadFile = File(...)):
# 检查 custom id 是否存在
service = get_custom_service()
custom, err = service.get(custom_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if custom.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"customs/{custom_id}/images/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
return BaseResponse(error_code=0, error_info="success", data=obs_url)
@router.delete("/api/custom/{custom_id}/image")
async def delete_custom_image(request: Request, custom_id: str, image_url: str):
# 检查 custom id 是否存在
service = get_custom_service()
custom, err = service.get(custom_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if custom.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 检查 image_url 是否是该 custom 名下的图片链接
obs_util = obs.get_instance()
obs_path, err = obs_util.get_obs_path_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if not obs_path.startswith(f"customs/{custom_id}/images/"):
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {custom_id} 名下的图片链接")
# 实现删除图片的处理
# 删除对象存储中的文件
err = obs_util.delete_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
return BaseResponse(error_code=0, error_info="success")

192
src/web/people.py Normal file
View File

@@ -0,0 +1,192 @@
import os
import uuid
import logging
from typing import Optional
from fastapi import APIRouter, Request, UploadFile, File, Query
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from services.people import get_instance as get_people_service
from models.people import People
from utils import obs
from utils.error import ErrorCode
from web.schemas import BaseResponse
router = APIRouter(tags=["people"])
class PostPeopleRequest(BaseModel):
people: dict
@router.post("/api/people")
async def post_people(request: Request, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.user_id = getattr(request.state, 'user_id', '')
service = get_people_service()
people.id, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=people.id)
@router.put("/api/people/{people_id}")
async def update_people(request: Request, people_id: str, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
people.id = people_id
service = get_people_service()
res, error = service.get(people_id)
if not error.success or not res:
return BaseResponse(error_code=error.code, error_info=error.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
people.user_id = res.user_id
_, error = service.save(people)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/people/{people_id}")
async def delete_people(request: Request, people_id: str):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
class GetPeopleRequest(BaseModel):
query: Optional[str] = None
conds: Optional[dict] = None
top_k: int = 5
@router.get("/api/peoples")
async def get_peoples(
request: Request,
name: Optional[str] = Query(None, description="姓名"),
gender: Optional[str] = Query(None, description="性别"),
age: Optional[int] = Query(None, description="年龄"),
height: Optional[int] = Query(None, description="身高"),
marital_status: Optional[str] = Query(None, description="婚姻状态"),
limit: int = Query(10, description="分页大小"),
offset: int = Query(0, description="分页偏移量"),
):
# 解析查询参数为字典
conds = {}
conds["user_id"] = getattr(request.state, 'user_id', '')
if name:
conds["name"] = name
if gender:
conds["gender"] = gender
if age:
conds["age"] = age
if height:
conds["height"] = height
if marital_status:
conds["marital_status"] = marital_status
logging.info(f"conds: , limit: {limit}, offset: {offset}")
results = []
service = get_people_service()
results, error = service.list(conds, limit=limit, offset=offset)
logging.info(f"query results: {results}")
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
peoples = [people.to_dict() for people in results]
return BaseResponse(error_code=0, error_info="success", data=peoples)
class RemarkRequest(BaseModel):
content: str
@router.post("/api/people/{people_id}/remark")
async def post_people_remark(request: Request, people_id: str, body: RemarkRequest):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.save_remark(people_id, body.content)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/people/{people_id}/remark")
async def delete_people_remark(request: Request, people_id: str):
service = get_people_service()
res, err = service.get(people_id)
if not err.success or not res:
return BaseResponse(error_code=err.code, error_info=err.info)
if res.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
error = service.delete_remark(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@router.post("/api/people/{people_id}/image")
async def post_people_image(request: Request, people_id: str, image: UploadFile = File(...)):
# 检查 people id 是否存在
service = get_people_service()
people, err = service.get(people_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if people.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"peoples/{people_id}/images/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
return BaseResponse(error_code=0, error_info="success", data=obs_url)
@router.delete("/api/people/{people_id}/image")
async def delete_people_image(request: Request, people_id: str, image_url: str):
# 检查 people id 是否存在
service = get_people_service()
people, err = service.get(people_id)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if people.user_id != getattr(request.state, 'user_id', ''):
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
# 检查 image_url 是否是该 people 名下的图片链接
obs_util = obs.get_instance()
obs_path, err = obs_util.get_obs_path_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
if not obs_path.startswith(f"peoples/{people_id}/images/"):
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {people_id} 名下的图片链接")
# 实现删除图片的处理
# 删除对象存储中的文件
err = obs_util.delete_by_link(image_url)
if not err.success:
return BaseResponse(error_code=err.code, error_info=err.info)
return BaseResponse(error_code=0, error_info="success")

91
src/web/recognition.py Normal file
View File

@@ -0,0 +1,91 @@
import os
import uuid
import logging
from typing import Optional
from fastapi import APIRouter, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from models.people import People
from models.custom import Custom
from agents.extract_people_agent import ExtractPeopleAgent
from agents.extract_custom_agent import ExtractCustomAgent
from utils import obs, ocr
from web.schemas import BaseResponse
from utils.error import ErrorCode
router = APIRouter(tags=["recognition"])
def extract_people(text: str, cover_link: str = None) -> People:
extra_agent = ExtractPeopleAgent()
people = extra_agent.extract_people_info(text)
if people:
people.cover = cover_link
logging.info(f"people: {people}")
return people
def extract_custom(text: str, image_link: str = None) -> Custom:
extra_agent = ExtractCustomAgent()
custom = extra_agent.extract_custom_info(text)
if custom:
if image_link:
custom.images = [image_link]
logging.info(f"custom: {custom}")
return custom
class PostInputRequest(BaseModel):
text: str
@router.post("/api/recognition/{model}/input")
async def post_recognition_input(model: str, request: PostInputRequest):
if model == "people":
result = await run_in_threadpool(extract_people, request.text)
elif model == "custom":
result = await run_in_threadpool(extract_custom, request.text)
else:
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
if result is None:
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
resp = BaseResponse(error_code=0, error_info="success")
resp.data = result.to_dict()
return resp
@router.post("/api/recognition/{model}/image")
async def post_recognition_image(model: str, image: UploadFile = File(...)):
if model not in ["people", "custom"]:
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
# 实现上传图片的处理
# 保存上传的图片文件
# 生成唯一的文件名
file_extension = os.path.splitext(image.filename)[1]
unique_filename = f"{uuid.uuid4()}{file_extension}"
# 保存文件到对象存储
file_path = f"uploads/{model}/{unique_filename}"
obs_util = obs.get_instance()
await run_in_threadpool(obs_util.put, file_path, await image.read())
# 获取对象存储外链
obs_url = obs_util.get_link(file_path)
logging.info(f"obs_url: {obs_url}")
# 调用OCR处理图片
ocr_util = ocr.get_instance()
ocr_result = await run_in_threadpool(ocr_util.recognize_image_text, obs_url)
logging.info(f"ocr_result: {ocr_result}")
if model == "people":
result = await run_in_threadpool(extract_people, ocr_result, obs_url)
elif model == "custom":
result = await run_in_threadpool(extract_custom, ocr_result, obs_url)
if result is None:
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
resp = BaseResponse(error_code=0, error_info="success")
resp.data = result.to_dict()
return resp

7
src/web/schemas.py Normal file
View File

@@ -0,0 +1,7 @@
from typing import Any, Optional
from pydantic import BaseModel
class BaseResponse(BaseModel):
error_code: int
error_info: str
data: Optional[Any] = None

223
src/web/user.py Normal file
View File

@@ -0,0 +1,223 @@
import os
import time
import logging
from typing import Optional, Literal
from fastapi import APIRouter, Depends, Request, HTTPException, Response, UploadFile, File
from pydantic import BaseModel
from services.user import get_instance as get_user_service
from web.auth import require_auth
from utils import obs
from utils.config import get_instance as get_config
from web.schemas import BaseResponse
router = APIRouter(tags=["user"])
class SendCodeRequest(BaseModel):
target_type: str
target: str
scene: Literal['register', 'update']
# scene: Literal['register', 'login']
@router.post("/api/user/send_code")
async def send_user_code(request: SendCodeRequest):
service = get_user_service()
err = service.send_code(request.target_type, request.target, request.scene)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success")
class RegisterRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
email: Optional[str] = None
phone: Optional[str] = None
password: str
code: str
@router.post("/api/user")
async def user_register(request: RegisterRequest):
service = get_user_service()
from models.user import User
u = User(
nickname=request.nickname or "",
avatar_link=request.avatar_link or "",
email=request.email or "",
phone=request.phone or "",
password_hash=request.password,
)
uid, err = service.register(u, request.code)
if not err.success:
logging.error(f"register failed: {err}")
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data=uid)
class LoginRequest(BaseModel):
email: Optional[str] = None
phone: Optional[str] = None
password: str
@router.post("/api/user/login")
async def user_login(request: LoginRequest, response: Response):
service = get_user_service()
data, err = service.login(request.email, request.phone, request.password)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30)
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False)
cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None)
response.set_cookie(
key="token",
value=data.get('token', ''),
max_age=ttl_days * 24 * 3600,
httponly=True,
secure=cookie_secure,
samesite=cookie_samesite,
domain=cookie_domain,
path="/",
)
return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')})
@router.delete("/api/user/me/login", dependencies=[Depends(require_auth)])
async def user_logout(response: Response, request: Request):
service = get_user_service()
err = service.logout(getattr(request.state, 'token', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@router.delete("/api/user/me", dependencies=[Depends(require_auth)])
async def user_delete(response: Response, request: Request):
service = get_user_service()
err = service.delete_user(getattr(request.state, 'user_id', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@router.get("/api/user/me", dependencies=[Depends(require_auth)])
async def user_me(request: Request):
service = get_user_service()
user, err = service.get(getattr(request.state, 'user_id', None))
if not err.success or not user:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateMeRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
phone: Optional[str] = None
email: Optional[str] = None
@router.put("/api/user/me", dependencies=[Depends(require_auth)])
async def update_user_me(request: Request, body: UpdateMeRequest):
service = get_user_service()
user, err = service.update_profile(
getattr(request.state, 'user_id', None),
nickname=body.nickname,
avatar_link=body.avatar_link,
phone=body.phone,
email=body.email,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
@router.put("/api/user/me/avatar", dependencies=[Depends(require_auth)])
async def upload_avatar(request: Request, avatar: UploadFile = File(...)):
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="unauthorized")
file_extension = os.path.splitext(avatar.filename)[1]
timestamp = int(time.time())
avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}"
try:
obs_util = obs.get_instance()
obs_util.Put(avatar_path, await avatar.read())
avatar_url = obs_util.Link(avatar_path)
user_service = get_user_service()
_, err = user_service.update_profile(user_id, avatar_link=avatar_url)
if not err.success:
raise HTTPException(status_code=500, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url})
except Exception as e:
logging.error(f"upload avatar failed: {e}")
raise HTTPException(status_code=500, detail="upload avatar failed")
class UpdatePhoneRequest(BaseModel):
phone: str
code: str
@router.put("/api/user/me/phone", dependencies=[Depends(require_auth)])
async def update_user_phone(request: Request, body: UpdatePhoneRequest):
service = get_user_service()
user, err = service.update_phone_with_code(
getattr(request.state, 'user_id', None),
body.phone,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateEmailRequest(BaseModel):
email: str
code: str
@router.put("/api/user/me/email", dependencies=[Depends(require_auth)])
async def update_user_email(request: Request, body: UpdateEmailRequest):
service = get_user_service()
user, err = service.update_email_with_code(
getattr(request.state, 'user_id', None),
body.email,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)