Compare commits
7 Commits
dev
...
fca2b1449f
| Author | SHA1 | Date | |
|---|---|---|---|
| fca2b1449f | |||
| 01f6003d35 | |||
| d6d6bc3bc8 | |||
| 2e928310cf | |||
| 40a39a0f1a | |||
| dd4e0c24a8 | |||
| 52d1bc5cf4 |
@@ -1,18 +1,18 @@
|
||||
[project]
|
||||
name = "service"
|
||||
version = "0.1"
|
||||
version = "0.1.0"
|
||||
description = "This project is the web servcie sub-system for if.u projuect"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"alibabacloud-ocr-api20210707>=3.1.3",
|
||||
"alibabacloud-tea-openapi>=0.4.1",
|
||||
"fastapi>=0.118.3",
|
||||
"langchain==0.3.27",
|
||||
"langchain-openai==0.3.35",
|
||||
"chromadb>=1.1.1",
|
||||
"fastapi>=0.118.2",
|
||||
"langchain>=0.3.27",
|
||||
"langchain-openai>=0.3.35",
|
||||
"numpy>=2.3.3",
|
||||
"pymysql>=1.1.2",
|
||||
"python-multipart>=0.0.20",
|
||||
"qiniu>=7.17.0",
|
||||
"sqlalchemy>=2.0.44",
|
||||
"uvicorn>=0.38.0",
|
||||
"requests>=2.32.5",
|
||||
]
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from utils.config import get_instance as get_config
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
|
||||
config = get_config()
|
||||
llm_api_url = api_url or config.get("ai", "llm_api_url")
|
||||
llm_api_key = api_key or config.get("ai", "llm_api_key")
|
||||
llm_model_name = model_name or config.get("ai", "llm_model_name")
|
||||
self.llm = ChatOpenAI(
|
||||
openai_api_key=llm_api_key,
|
||||
openai_api_base=llm_api_url,
|
||||
model_name=llm_model_name,
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class SummaryPeopleAgent(BaseAgent):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
@@ -1,85 +0,0 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
from models.custom import Custom
|
||||
|
||||
class ExtractCustomAgent(BaseAgent):
|
||||
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
|
||||
super().__init__(api_url, api_key, model_name)
|
||||
self.prompt = ChatPromptTemplate.from_messages([
|
||||
(
|
||||
"system",
|
||||
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')},"
|
||||
"你是一个专业的客户信息录入助手,善于从一段文字描述中,精确获取客户的以下信息:\n"
|
||||
"姓名 name\n"
|
||||
"性别 gender (男/女/未知)\n"
|
||||
"出生年份 birth (整数年份,如 1990;若文本只提供了年龄,请根据当前日期计算出出生年份)\n"
|
||||
"手机号 phone\n"
|
||||
"邮箱 email\n"
|
||||
"身高(cm) height (整数)\n"
|
||||
"体重(kg) weight (整数)\n"
|
||||
"学历 degree\n"
|
||||
"毕业院校 academy\n"
|
||||
"职业 occupation\n"
|
||||
"年收入(万) income (整数)\n"
|
||||
"资产(万) assets (整数)\n"
|
||||
"流动资产(万) current_assets (整数)\n"
|
||||
"房产情况 house (必须为以下之一: '有房无贷', '有房有贷', '无自有房', 若未提及则不填)\n"
|
||||
"车辆情况 car (必须为以下之一: '有车无贷', '有车有贷', '无自有车', 若未提及则不填)\n"
|
||||
"户口城市 registered_city\n"
|
||||
"居住城市 live_city\n"
|
||||
"籍贯 native_place\n"
|
||||
"原生家庭情况 original_family\n"
|
||||
"是否独生子女 is_single_child (true/false)\n"
|
||||
"择偶要求 match_requirement\n"
|
||||
"\n"
|
||||
"以上信息需要严格按照 JSON 格式输出,字段名与条目中英文保持一致。\n"
|
||||
"若未识别到某项,则不返回该字段,不要自行填写“未知”、“未填写”等。\n"
|
||||
"\n"
|
||||
"除了上述基本信息,还有一个字段:\n"
|
||||
"其他介绍 introductions\n"
|
||||
"其余的信息需要按照字典的方式进行提炼和总结,都放在 introductions 字段中,key 使用提炼好的中文。\n"
|
||||
),
|
||||
("human", "{input}")
|
||||
])
|
||||
|
||||
def extract_custom_info(self, text: str) -> Custom:
|
||||
"""从文本中提取客户信息"""
|
||||
prompt = self.prompt.format_prompt(input=text)
|
||||
response = self.llm.invoke(prompt)
|
||||
logging.info(f"llm response: {response.content}")
|
||||
try:
|
||||
custom_dict = json.loads(response.content)
|
||||
|
||||
# 类型安全转换,防止LLM返回字符串类型的数字
|
||||
int_fields = ['birth', 'height', 'weight', 'income', 'assets', 'scores', 'current_assets']
|
||||
for field in int_fields:
|
||||
if field in custom_dict and isinstance(custom_dict[field], str):
|
||||
try:
|
||||
# 尝试提取数字,简单处理
|
||||
import re
|
||||
num = re.findall(r'\d+', custom_dict[field])
|
||||
if num:
|
||||
custom_dict[field] = int(num[0])
|
||||
else:
|
||||
del custom_dict[field] # 无法转换则移除
|
||||
except:
|
||||
del custom_dict[field]
|
||||
|
||||
custom = Custom.from_dict(custom_dict)
|
||||
err = custom.validate()
|
||||
if not err.success:
|
||||
logging.warning(f"Validation warning: {err.info}")
|
||||
# 即使校验失败(如某些必填项缺失),也尽可能返回已提取的对象,
|
||||
# 让上层业务逻辑决定是否接受或需要补充
|
||||
return custom
|
||||
except json.JSONDecodeError:
|
||||
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to process custom info: {e}")
|
||||
return None
|
||||
@@ -1,19 +1,25 @@
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
from .base_agent import BaseAgent
|
||||
from models.people import People
|
||||
|
||||
class BaseAgent:
|
||||
def __init__(self):
|
||||
self.llm = ChatOpenAI(
|
||||
openai_api_key="56d82040-85c7-4701-8f87-734985e27909",
|
||||
openai_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||
model_name="ep-20250722161445-n9lfq"
|
||||
)
|
||||
pass
|
||||
|
||||
class ExtractPeopleAgent(BaseAgent):
|
||||
def __init__(self, api_url: str = None, api_key: str = None, model_name: str = None):
|
||||
super().__init__(api_url, api_key, model_name)
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.prompt = ChatPromptTemplate.from_messages([
|
||||
(
|
||||
"system",
|
||||
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')},"
|
||||
"你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n"
|
||||
"姓名 name\n"
|
||||
"性别 gender\n"
|
||||
@@ -21,9 +27,7 @@ class ExtractPeopleAgent(BaseAgent):
|
||||
"身高(cm) height\n"
|
||||
"婚姻状况 marital_status\n"
|
||||
"择偶要求 match_requirement\n"
|
||||
"以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致; 若未识别到以上的某项,则不返回该字段,不要自行填写“未知”,“未填写”等类似字眼。\n"
|
||||
"其中,'年龄 age' 和 '身高(cm) height' 必须是一个整数,不能是一个字符串;\n"
|
||||
"并且,'性别 gender' 根据识别结果,必须从 男,女,未知 三选一填写。\n"
|
||||
"以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n"
|
||||
"除了上述基本信息,还有一个字段\n"
|
||||
"个人介绍 introduction\n"
|
||||
"其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n"
|
||||
@@ -38,15 +42,13 @@ class ExtractPeopleAgent(BaseAgent):
|
||||
response = self.llm.invoke(prompt)
|
||||
logging.info(f"llm response: {response.content}")
|
||||
try:
|
||||
people = People.from_dict(json.loads(response.content))
|
||||
err = people.validate()
|
||||
if not err.success:
|
||||
raise ValueError(f"Failed to validate people info: {err.info}")
|
||||
return people
|
||||
return People.from_dict(json.loads(response.content))
|
||||
except json.JSONDecodeError:
|
||||
logging.error(f"Failed to parse JSON from LLM response: {response.content}")
|
||||
return None
|
||||
except ValueError as e:
|
||||
logging.error(f"Failed to validate people info: {e}")
|
||||
return None
|
||||
pass
|
||||
|
||||
class SummaryPeopleAgent(BaseAgent):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
pass
|
||||
130
src/app/api.py
Normal file
130
src/app/api.py
Normal file
@@ -0,0 +1,130 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from typing import Any, Optional
|
||||
from fastapi import FastAPI, File, UploadFile, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ai.agent import ExtractPeopleAgent
|
||||
from models.people import People
|
||||
from utils import obs, ocr, vsdb
|
||||
from storage.people_store import get_instance as get_people_store
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
|
||||
api.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
error_code: int
|
||||
error_info: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
class PostInputRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
@api.post("/input")
|
||||
async def post_input(request: PostInputRequest):
|
||||
extra_agent = ExtractPeopleAgent()
|
||||
people = extra_agent.extract_people_info(request.text)
|
||||
logging.info(f"people: {people}")
|
||||
resp = BaseResponse(error_code=0, error_info="success")
|
||||
resp.data = people.to_dict()
|
||||
return resp
|
||||
|
||||
@api.post("/input_image")
|
||||
async def post_input_image(image: UploadFile = File(...)):
|
||||
# 实现上传图片的处理
|
||||
# 保存上传的图片文件
|
||||
# 生成唯一的文件名
|
||||
file_extension = os.path.splitext(image.filename)[1]
|
||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# 确保uploads目录存在
|
||||
os.makedirs("uploads", exist_ok=True)
|
||||
|
||||
# 保存文件到对象存储
|
||||
file_path = f"uploads/{unique_filename}"
|
||||
obs_util = obs.get_instance()
|
||||
obs_util.Put(file_path, await image.read())
|
||||
|
||||
# 获取对象存储外链
|
||||
obs_url = obs_util.Link(file_path)
|
||||
logging.info(f"obs_url: {obs_url}")
|
||||
|
||||
# 调用OCR处理图片
|
||||
ocr_util = ocr.get_instance()
|
||||
ocr_result = ocr_util.recognize_image_text(obs_url)
|
||||
logging.info(f"ocr_result: {ocr_result}")
|
||||
|
||||
post_input_request = PostInputRequest(text=ocr_result)
|
||||
return await post_input(post_input_request)
|
||||
|
||||
class PostPeopleRequest(BaseModel):
|
||||
people: dict
|
||||
|
||||
@api.post("/peoples")
|
||||
async def post_people(post_people_request: PostPeopleRequest):
|
||||
logging.debug(f"post_people_request: {post_people_request}")
|
||||
people = People.from_dict(post_people_request.people)
|
||||
store = get_people_store()
|
||||
people.id = store.save(people)
|
||||
return BaseResponse(error_code=0, error_info="success", data=people.id)
|
||||
|
||||
class GetPeopleRequest(BaseModel):
|
||||
query: Optional[str] = None
|
||||
conds: Optional[dict] = None
|
||||
top_k: int = 5
|
||||
|
||||
@api.get("/peoples")
|
||||
async def get_peoples(
|
||||
name: Optional[str] = Query(None, description="姓名"),
|
||||
gender: Optional[str] = Query(None, description="性别"),
|
||||
age: Optional[int] = Query(None, description="年龄"),
|
||||
height: Optional[int] = Query(None, description="身高"),
|
||||
marital_status: Optional[str] = Query(None, description="婚姻状态"),
|
||||
limit: int = Query(10, description="分页大小"),
|
||||
offset: int = Query(0, description="分页偏移量"),
|
||||
search: Optional[str] = Query(None, description="搜索内容"),
|
||||
top_k: int = Query(5, description="搜索结果数量"),
|
||||
):
|
||||
|
||||
# 解析查询参数为字典
|
||||
conds = {}
|
||||
if name:
|
||||
conds["name"] = name
|
||||
if gender:
|
||||
conds["gender"] = gender
|
||||
if age:
|
||||
conds["age"] = age
|
||||
if height:
|
||||
conds["height"] = height
|
||||
if marital_status:
|
||||
conds["marital_status"] = marital_status
|
||||
|
||||
logging.info(f"conds: , limit: {limit}, offset: {offset}, search: {search}, top_k: {top_k}")
|
||||
|
||||
results = []
|
||||
store = get_people_store()
|
||||
if search:
|
||||
results = store.search(search, conds, ids=None, top_k=top_k)
|
||||
logging.info(f"search results: {results}")
|
||||
else:
|
||||
results = store.query(conds, limit=limit, offset=offset)
|
||||
logging.info(f"query results: {results}")
|
||||
peoples = [people.to_dict() for people in results]
|
||||
return BaseResponse(error_code=0, error_info="success", data=peoples)
|
||||
|
||||
@api.delete("/peoples/{people_id}")
|
||||
async def delete_people(people_id: str):
|
||||
store = get_people_store()
|
||||
store.delete(people_id)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
52
src/app/app.py
Normal file
52
src/app/app.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# created by mmmy on 2025-09-27
|
||||
|
||||
import logging
|
||||
from ai.agent import ExtractPeopleAgent
|
||||
from utils import ocr, vsdb, obs
|
||||
from models.people import People
|
||||
from fastapi import FastAPI
|
||||
|
||||
api = FastAPI(title="Single People Management and Searching", version="1.0.0")
|
||||
|
||||
class App:
|
||||
def __init__(self):
|
||||
self.extract_people_agent = ExtractPeopleAgent()
|
||||
self.ocr = ocr.get_instance()
|
||||
self.vedb = vsdb.get_instance(db_type='chromadb')
|
||||
self.obs = obs.get_instance()
|
||||
|
||||
def run(self):
|
||||
pass
|
||||
|
||||
class InputForPeople:
|
||||
image: bytes
|
||||
text: str
|
||||
|
||||
def input_to_people(self, input: InputForPeople) -> People:
|
||||
if not input.image and not input.text:
|
||||
return None
|
||||
if input.image:
|
||||
content = self.ocr.recognize_image_text(input.image)
|
||||
else:
|
||||
content = input.text
|
||||
print(content)
|
||||
people = self._trans_text_to_people(content)
|
||||
return people
|
||||
|
||||
def _trans_text_to_people(self, text: str) -> People:
|
||||
if not text:
|
||||
return None
|
||||
person = self.extract_people_agent.extract_people_info(text)
|
||||
print(person)
|
||||
return person
|
||||
|
||||
def create_people(self, people: People) -> bool:
|
||||
if not people:
|
||||
return False
|
||||
try:
|
||||
people.save()
|
||||
except Exception as e:
|
||||
logging.error(f"保存人物到数据库失败: {e}")
|
||||
return False
|
||||
return True
|
||||
54
src/main.py
54
src/main.py
@@ -1,35 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# created by mmmy on 2025-09-27
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Add src directory to sys.path to ensure modules can be imported correctly when running with uvicorn
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
import argparse
|
||||
from venv import logger
|
||||
import uvicorn
|
||||
from services import people as people_service
|
||||
from services import user as user_service
|
||||
from services import custom as custom_service
|
||||
from utils import config, logger, obs, ocr, rldb, sms, mailer
|
||||
|
||||
from web.api import api
|
||||
|
||||
def initialize_app(config_path):
|
||||
"""Initialize application components with the given config path."""
|
||||
config.init(config_path)
|
||||
conf = config.get_instance()
|
||||
|
||||
logger.init()
|
||||
rldb.init()
|
||||
ocr.init()
|
||||
obs.init()
|
||||
mailer.init(conf.get('mailer', 'type', fallback='real'))
|
||||
sms.init(conf.get('sms', 'type', fallback='real'))
|
||||
|
||||
people_service.init()
|
||||
user_service.init()
|
||||
custom_service.init()
|
||||
from app.api import api
|
||||
from utils import obs, ocr, vsdb, logger, config
|
||||
from storage import people_store
|
||||
|
||||
# 主函数
|
||||
def main():
|
||||
@@ -37,20 +15,16 @@ def main():
|
||||
parser = argparse.ArgumentParser(description='IF.u 服务')
|
||||
parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径')
|
||||
args = parser.parse_args()
|
||||
|
||||
initialize_app(args.config)
|
||||
|
||||
config.init(args.config)
|
||||
logger.init()
|
||||
obs.init()
|
||||
ocr.init()
|
||||
vsdb.init()
|
||||
people_store.init()
|
||||
conf = config.get_instance()
|
||||
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
|
||||
host = conf.get('web_service', 'server_host', fallback='127.0.0.1')
|
||||
port = conf.getint('web_service', 'server_port', fallback=8099)
|
||||
uvicorn.run("src.main:api", host=host, port=port, reload=True) # Modified to string import for reload support in main too, though api object also works
|
||||
uvicorn.run(api, host=host, port=port)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
else:
|
||||
# Support for running via 'uvicorn src.main:api'
|
||||
# Use environment variable for config path or default
|
||||
main_path = os.path.dirname(os.path.abspath(__file__))
|
||||
default_config_path = os.path.join(main_path, '../configuration/test_conf.ini')
|
||||
config_path = os.environ.get('IFU_CONFIG_PATH', default_config_path)
|
||||
initialize_app(config_path)
|
||||
main()
|
||||
@@ -1,291 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# created by mmmy on 2025-11-27
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, func, Boolean
|
||||
from utils.rldb import RLDBBaseModel
|
||||
from utils.error import ErrorCode, error
|
||||
|
||||
class CustomRLDBModel(RLDBBaseModel):
|
||||
"""
|
||||
客户数据的数据库模型 (SQLAlchemy Model) - 更新版
|
||||
"""
|
||||
__tablename__ = 'customs'
|
||||
id = Column(String(36), primary_key=True)
|
||||
user_id = Column(String(36), index=True, nullable=False)
|
||||
|
||||
# 基本信息
|
||||
name = Column(String(255), index=True, nullable=False)
|
||||
gender = Column(String(10), nullable=False)
|
||||
birth = Column(Integer, nullable=False) # 出生年份
|
||||
phone = Column(String(50), index=True)
|
||||
email = Column(String(255), index=True)
|
||||
|
||||
# 外貌信息
|
||||
height = Column(Integer)
|
||||
weight = Column(Integer)
|
||||
images = Column(Text) # JSON string for list[str]
|
||||
scores = Column(Integer)
|
||||
|
||||
# 学历职业
|
||||
degree = Column(String(255))
|
||||
academy = Column(String(255))
|
||||
occupation = Column(String(255))
|
||||
income = Column(Integer) # 单位:万
|
||||
assets = Column(Integer) # 单位:万
|
||||
current_assets = Column(Integer) # 单位:万
|
||||
house = Column(String(50))
|
||||
car = Column(String(50))
|
||||
|
||||
# 户口家庭
|
||||
registered_city = Column(String(255))
|
||||
live_city = Column(String(255))
|
||||
native_place = Column(String(255))
|
||||
original_family = Column(Text)
|
||||
is_single_child = Column(Boolean, default=False)
|
||||
|
||||
match_requirement = Column(Text)
|
||||
|
||||
introductions = Column(Text) # JSON string for Dict[str, str]
|
||||
|
||||
# 客户信息
|
||||
custom_level = Column(String(255))
|
||||
comments = Column(Text) # JSON string for Dict[str, str]
|
||||
is_public = Column(Boolean, default=False)
|
||||
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
class Custom:
|
||||
"""
|
||||
客户数据的业务逻辑模型 (Business Logic Model) - 更新版
|
||||
"""
|
||||
id: str
|
||||
user_id: str
|
||||
|
||||
# 基本信息
|
||||
name: str
|
||||
gender: str
|
||||
birth: int
|
||||
phone: str
|
||||
email: str
|
||||
|
||||
# 外貌信息
|
||||
height: int
|
||||
weight: int
|
||||
images: List[str]
|
||||
scores: int
|
||||
|
||||
# 学历职业
|
||||
degree: str
|
||||
academy: str
|
||||
occupation: str
|
||||
income: int
|
||||
assets: int
|
||||
current_assets: int
|
||||
house: str
|
||||
car: str
|
||||
|
||||
# 户口家庭
|
||||
registered_city: str
|
||||
live_city: str
|
||||
native_place: str
|
||||
original_family: str
|
||||
is_single_child: bool
|
||||
|
||||
match_requirement: str
|
||||
|
||||
introductions: Dict[str, str]
|
||||
|
||||
# 客户信息
|
||||
custom_level: str
|
||||
comments: Dict[str, str]
|
||||
is_public: bool
|
||||
created_at: datetime = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# 初始化所有属性
|
||||
self.id = kwargs.get('id', '')
|
||||
self.user_id = kwargs.get('user_id', '')
|
||||
self.name = kwargs.get('name', '')
|
||||
self.gender = kwargs.get('gender', '未知')
|
||||
self.birth = kwargs.get('birth', 0)
|
||||
self.phone = kwargs.get('phone', '')
|
||||
self.email = kwargs.get('email', '')
|
||||
self.height = kwargs.get('height', 0)
|
||||
self.weight = kwargs.get('weight', 0)
|
||||
self.images = kwargs.get('images', [])
|
||||
self.scores = kwargs.get('scores', 0)
|
||||
self.degree = kwargs.get('degree', '')
|
||||
self.academy = kwargs.get('academy', '')
|
||||
self.occupation = kwargs.get('occupation', '')
|
||||
self.income = kwargs.get('income', 0)
|
||||
self.assets = kwargs.get('assets', 0)
|
||||
self.current_assets = kwargs.get('current_assets', 0)
|
||||
self.house = kwargs.get('house', '')
|
||||
self.car = kwargs.get('car', '')
|
||||
self.registered_city = kwargs.get('registered_city', '')
|
||||
self.live_city = kwargs.get('live_city', '')
|
||||
self.native_place = kwargs.get('native_place', '')
|
||||
self.original_family = kwargs.get('original_family', '')
|
||||
self.is_single_child = kwargs.get('is_single_child', False)
|
||||
self.match_requirement = kwargs.get('match_requirement', '')
|
||||
self.introductions = kwargs.get('introductions', {})
|
||||
self.custom_level = kwargs.get('custom_level', '')
|
||||
self.comments = kwargs.get('comments', {})
|
||||
self.is_public = kwargs.get('is_public', False)
|
||||
self.created_at = kwargs.get('created_at')
|
||||
|
||||
def __str__(self) -> str:
|
||||
# 返回对象的字符串表示
|
||||
attributes = ", ".join(f"{k}={v}" for k, v in self.to_dict().items())
|
||||
return f"Custom({attributes})"
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
# 从字典创建对象实例
|
||||
if 'created_at' in data and data['created_at'] is not None:
|
||||
data['created_at'] = datetime.fromtimestamp(data['created_at'])
|
||||
# 移除ORM特有的时间戳,避免初始化错误
|
||||
data.pop('updated_at', None)
|
||||
data.pop('deleted_at', None)
|
||||
return cls(**data)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# 将对象转换为字典
|
||||
return {
|
||||
'id': self.id,
|
||||
'user_id': self.user_id,
|
||||
'name': self.name,
|
||||
'gender': self.gender,
|
||||
'birth': self.birth,
|
||||
'phone': self.phone,
|
||||
'email': self.email,
|
||||
'height': self.height,
|
||||
'weight': self.weight,
|
||||
'images': self.images,
|
||||
'scores': self.scores,
|
||||
'degree': self.degree,
|
||||
'academy': self.academy,
|
||||
'occupation': self.occupation,
|
||||
'income': self.income,
|
||||
'assets': self.assets,
|
||||
'current_assets': self.current_assets,
|
||||
'house': self.house,
|
||||
'car': self.car,
|
||||
'registered_city': self.registered_city,
|
||||
'live_city': self.live_city,
|
||||
'native_place': self.native_place,
|
||||
'original_family': self.original_family,
|
||||
'is_single_child': self.is_single_child,
|
||||
'match_requirement': self.match_requirement,
|
||||
'introductions': self.introductions,
|
||||
'custom_level': self.custom_level,
|
||||
'comments': self.comments,
|
||||
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_rldb_model(cls, data: CustomRLDBModel):
|
||||
# 从数据库模型转换
|
||||
return cls(
|
||||
id=data.id,
|
||||
user_id=data.user_id,
|
||||
name=data.name,
|
||||
gender=data.gender,
|
||||
birth=data.birth,
|
||||
phone=data.phone,
|
||||
email=data.email,
|
||||
height=data.height,
|
||||
weight=data.weight,
|
||||
images=json.loads(data.images) if data.images else [],
|
||||
scores=data.scores,
|
||||
degree=data.degree,
|
||||
academy=data.academy,
|
||||
occupation=data.occupation,
|
||||
income=data.income,
|
||||
assets=data.assets,
|
||||
house=data.house,
|
||||
car=data.car,
|
||||
registered_city=data.registered_city,
|
||||
live_city=data.live_city,
|
||||
native_place=data.native_place,
|
||||
original_family=data.original_family,
|
||||
is_single_child=data.is_single_child,
|
||||
match_requirement=data.match_requirement,
|
||||
introductions=json.loads(data.introductions) if data.introductions else {},
|
||||
custom_level=data.custom_level,
|
||||
comments=json.loads(data.comments) if data.comments else {},
|
||||
is_public=data.is_public,
|
||||
created_at=data.created_at,
|
||||
)
|
||||
|
||||
def to_rldb_model(self) -> CustomRLDBModel:
|
||||
# 转换为数据库模型
|
||||
return CustomRLDBModel(
|
||||
id=self.id,
|
||||
user_id=self.user_id,
|
||||
name=self.name,
|
||||
gender=self.gender,
|
||||
birth=self.birth,
|
||||
phone=self.phone,
|
||||
email=self.email,
|
||||
height=self.height,
|
||||
weight=self.weight,
|
||||
images=json.dumps(self.images, ensure_ascii=False),
|
||||
scores=self.scores,
|
||||
degree=self.degree,
|
||||
academy=self.academy,
|
||||
occupation=self.occupation,
|
||||
income=self.income,
|
||||
assets=self.assets,
|
||||
current_assets=self.current_assets,
|
||||
house=self.house,
|
||||
car=self.car,
|
||||
registered_city=self.registered_city,
|
||||
live_city=self.live_city,
|
||||
native_place=self.native_place,
|
||||
original_family=self.original_family,
|
||||
is_single_child=self.is_single_child,
|
||||
match_requirement=self.match_requirement,
|
||||
introductions=json.dumps(self.introductions, ensure_ascii=False),
|
||||
custom_level=self.custom_level,
|
||||
comments=json.dumps(self.comments, ensure_ascii=False),
|
||||
is_public=self.is_public,
|
||||
)
|
||||
|
||||
def validate(self) -> error:
|
||||
# 数据校验逻辑
|
||||
if not self.name:
|
||||
return error(ErrorCode.INVALID_PARAMS, "Name cannot be empty.")
|
||||
|
||||
if not self.gender:
|
||||
return error(ErrorCode.INVALID_PARAMS, "Gender cannot be empty.")
|
||||
|
||||
if self.gender not in ['男', '女']:
|
||||
return error(ErrorCode.INVALID_PARAMS, "Gender must be '男' or '女'.")
|
||||
|
||||
current_year = datetime.now().year
|
||||
min_birth_year = 1950
|
||||
max_birth_year = current_year - 18
|
||||
|
||||
if not isinstance(self.birth, int):
|
||||
return error(ErrorCode.INVALID_PARAMS, "Birth year must be an integer.")
|
||||
|
||||
if self.birth < min_birth_year or self.birth > max_birth_year:
|
||||
return error(ErrorCode.INVALID_PARAMS, f"Birth year must be between {min_birth_year} and {max_birth_year}.")
|
||||
|
||||
valid_houses = ["", "有房无贷", "有房有贷", "无自有房"]
|
||||
if self.house not in valid_houses:
|
||||
return error(ErrorCode.INVALID_PARAMS, f"House must be one of {valid_houses}")
|
||||
|
||||
valid_cars = ["", "有车无贷", "有车有贷", "无自有车"]
|
||||
if self.car not in valid_cars:
|
||||
return error(ErrorCode.INVALID_PARAMS, f"Car must be one of {valid_cars}")
|
||||
|
||||
# ... 可根据需要添加更多校验 ...
|
||||
return error(ErrorCode.SUCCESS, "")
|
||||
@@ -1,69 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# created by mmmy on 2025-09-30
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, func
|
||||
from utils.rldb import RLDBBaseModel
|
||||
from utils.error import ErrorCode, error
|
||||
|
||||
class PeopleRLDBModel(RLDBBaseModel):
|
||||
__tablename__ = 'peoples'
|
||||
id = Column(String(36), primary_key=True)
|
||||
user_id = Column(String(36), index=True)
|
||||
name = Column(String(255), index=True)
|
||||
contact = Column(String(255), index=True)
|
||||
gender = Column(String(10))
|
||||
age = Column(Integer)
|
||||
height = Column(Integer)
|
||||
marital_status = Column(String(20))
|
||||
match_requirement = Column(Text)
|
||||
introduction = Column(Text)
|
||||
comments = Column(Text)
|
||||
cover = Column(String(255), nullable=True)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
|
||||
class Comment:
|
||||
# 评论内容
|
||||
content: str
|
||||
# 评论人
|
||||
author: str
|
||||
# 创建时间
|
||||
created_at: datetime
|
||||
# 更新时间
|
||||
updated_at: datetime
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.content = kwargs.get('content', '')
|
||||
self.author = kwargs.get('author', '')
|
||||
self.created_at = kwargs.get('created_at', datetime.now())
|
||||
self.updated_at = kwargs.get('updated_at', datetime.now())
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'content': self.content,
|
||||
'author': self.author,
|
||||
'created_at': int(self.created_at.timestamp()),
|
||||
'updated_at': int(self.updated_at.timestamp()),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
data['created_at'] = datetime.fromtimestamp(data['created_at'])
|
||||
data['updated_at'] = datetime.fromtimestamp(data['updated_at'])
|
||||
return cls(**data)
|
||||
|
||||
|
||||
class People:
|
||||
# 数据库 ID
|
||||
id: str
|
||||
# 所属用户 ID
|
||||
user_id: str
|
||||
# 姓名
|
||||
name: str
|
||||
# 联系人
|
||||
@@ -74,23 +18,27 @@ class People:
|
||||
age: int
|
||||
# 身高(cm)
|
||||
height: int
|
||||
# 体重(kg)
|
||||
# weight: int
|
||||
# 婚姻状况
|
||||
# [
|
||||
# "未婚(single)",
|
||||
# "已婚(married)",
|
||||
# "离异(divorced)",
|
||||
# "丧偶(widowed)"
|
||||
# ]
|
||||
marital_status: str
|
||||
# 择偶要求
|
||||
match_requirement: str
|
||||
# 个人介绍
|
||||
introduction: Dict[str, str]
|
||||
# 总结评价
|
||||
comments: Dict[str, "Comment"]
|
||||
# 封面
|
||||
cover: str = None
|
||||
# 创建时间
|
||||
created_at: datetime = None
|
||||
comments: Dict[str, str]
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# 初始化所有属性,从kwargs中获取值,如果不存在则设置默认值
|
||||
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
|
||||
self.user_id = kwargs.get('user_id', '') if kwargs.get('user_id', '') is not None else ''
|
||||
self.name = kwargs.get('name', '') if kwargs.get('name', '') is not None else ''
|
||||
self.contact = kwargs.get('contact', '') if kwargs.get('contact', '') is not None else ''
|
||||
self.gender = kwargs.get('gender', '') if kwargs.get('gender', '') is not None else ''
|
||||
@@ -100,18 +48,15 @@ class People:
|
||||
self.match_requirement = kwargs.get('match_requirement', '') if kwargs.get('match_requirement', '') is not None else ''
|
||||
self.introduction = kwargs.get('introduction', {}) if kwargs.get('introduction', {}) is not None else {}
|
||||
self.comments = kwargs.get('comments', {}) if kwargs.get('comments', {}) is not None else {}
|
||||
self.cover = kwargs.get('cover', None) if kwargs.get('cover', None) is not None else None
|
||||
self.created_at = kwargs.get('created_at', None)
|
||||
|
||||
def __str__(self) -> str:
|
||||
# 返回对象的字符串表示,包含所有属性
|
||||
return (f"People(id={self.id}, name={self.name}, contact={self.contact}, gender={self.gender}, "
|
||||
f"age={self.age}, height={self.height}, marital_status={self.marital_status}, "
|
||||
f"match_requirement={self.match_requirement}, introduction={self.introduction}, "
|
||||
f"comments={self.comments}, cover={self.cover}, created_at={self.created_at})")
|
||||
return self.tonl()
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
if 'created_at' in data:
|
||||
# 移除 created_at 字段,避免类型错误
|
||||
del data['created_at']
|
||||
if 'updated_at' in data:
|
||||
# 移除 updated_at 字段,避免类型错误
|
||||
del data['updated_at']
|
||||
@@ -120,30 +65,10 @@ class People:
|
||||
del data['deleted_at']
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_rldb_model(cls, data: PeopleRLDBModel):
|
||||
# 将关系数据库模型转换为对象
|
||||
return cls(
|
||||
id=data.id,
|
||||
user_id=data.user_id,
|
||||
name=data.name,
|
||||
contact=data.contact,
|
||||
gender=data.gender,
|
||||
age=data.age,
|
||||
height=data.height,
|
||||
marital_status=data.marital_status,
|
||||
match_requirement=data.match_requirement,
|
||||
introduction=json.loads(data.introduction) if data.introduction else {},
|
||||
comments={k: Comment.from_dict(v) for k, v in json.loads(data.comments).items()} if data.comments else {},
|
||||
cover=data.cover,
|
||||
created_at=data.created_at,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# 将对象转换为字典格式
|
||||
return {
|
||||
'id': self.id,
|
||||
'user_id': self.user_id,
|
||||
'name': self.name,
|
||||
'contact': self.contact,
|
||||
'gender': self.gender,
|
||||
@@ -152,40 +77,45 @@ class People:
|
||||
'marital_status': self.marital_status,
|
||||
'match_requirement': self.match_requirement,
|
||||
'introduction': self.introduction,
|
||||
'comments': {k: v.to_dict() for k, v in self.comments.items()},
|
||||
'cover': self.cover,
|
||||
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
|
||||
}
|
||||
|
||||
def to_rldb_model(self) -> PeopleRLDBModel:
|
||||
# 将对象转换为关系数据库模型
|
||||
return PeopleRLDBModel(
|
||||
id=self.id,
|
||||
user_id=self.user_id,
|
||||
name=self.name,
|
||||
contact=self.contact,
|
||||
gender=self.gender,
|
||||
age=self.age,
|
||||
height=self.height,
|
||||
marital_status=self.marital_status,
|
||||
match_requirement=self.match_requirement,
|
||||
introduction=json.dumps(self.introduction, ensure_ascii=False),
|
||||
comments=json.dumps({k: v.to_dict() for k, v in self.comments.items()}, ensure_ascii=False),
|
||||
cover=self.cover,
|
||||
)
|
||||
'comments': self.comments,
|
||||
}
|
||||
|
||||
def validate(self) -> error:
|
||||
err = error(ErrorCode.SUCCESS, "")
|
||||
if not self.name:
|
||||
logging.error("Name is required, use default")
|
||||
self.name = ""
|
||||
if not self.gender in ['男', '女', '未知']:
|
||||
logging.error("Gender must be '男', '女', or '未知', use default")
|
||||
self.gender = "未知"
|
||||
if not isinstance(self.age, int) or self.age < 0:
|
||||
logging.error("Age must be an integer and greater than 0, use default")
|
||||
self.age = 0
|
||||
if not isinstance(self.height, int) or self.height < 0:
|
||||
logging.error("Height must be an integer and greater than 0, use default")
|
||||
self.height = 0
|
||||
return err
|
||||
def meta(self) -> Dict[str, str]:
|
||||
# 返回对象的元数据信息
|
||||
meta = {
|
||||
'id': self.id,
|
||||
'name': self.name,
|
||||
'gender': self.gender,
|
||||
'age': self.age,
|
||||
'height': self.height,
|
||||
'marital_status': self.marital_status,
|
||||
}
|
||||
logging.info(f"people meta: {meta}")
|
||||
return meta
|
||||
|
||||
def tonl(self) -> str:
|
||||
# 将对象转换为文档格式的字符串
|
||||
doc = []
|
||||
doc.append(f"姓名: {self.name}")
|
||||
doc.append(f"性别: {self.gender}")
|
||||
if self.age:
|
||||
doc.append(f"年龄: {self.age}")
|
||||
if self.height:
|
||||
doc.append(f"身高: {self.height}cm")
|
||||
if self.marital_status:
|
||||
doc.append(f"婚姻状况: {self.marital_status}")
|
||||
if self.match_requirement:
|
||||
doc.append(f"择偶要求: {self.match_requirement}")
|
||||
if self.introduction:
|
||||
doc.append("个人介绍:")
|
||||
for key, value in self.introduction.items():
|
||||
doc.append(f" - {key}: {value}")
|
||||
if self.comments:
|
||||
doc.append("总结评价:")
|
||||
for key, value in self.comments.items():
|
||||
doc.append(f" - {key}: {value}")
|
||||
return '\n'.join(doc)
|
||||
|
||||
def comment(self, comment: Dict[str, str]):
|
||||
# 添加总结评价
|
||||
self.comments.update(comment)
|
||||
|
||||
@@ -1,184 +0,0 @@
|
||||
from typing import Optional
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy import Column, String, Text, DateTime, Integer, Boolean, func, UniqueConstraint
|
||||
from utils.rldb import RLDBBaseModel
|
||||
from utils.error import ErrorCode, error
|
||||
|
||||
|
||||
class UserRLDBModel(RLDBBaseModel):
|
||||
__tablename__ = 'users'
|
||||
id = Column(String(36), primary_key=True)
|
||||
nickname = Column(String(255))
|
||||
avatar_link = Column(String(255))
|
||||
email = Column(String(127), unique=True, index=True)
|
||||
phone = Column(String(32), unique=True, index=True)
|
||||
password_hash = Column(String(255))
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
|
||||
|
||||
class VerificationCodeRLDBModel(RLDBBaseModel):
|
||||
__tablename__ = 'verification_codes'
|
||||
id = Column(String(36), primary_key=True)
|
||||
target_type = Column(String(16))
|
||||
target = Column(String(255), index=True)
|
||||
code = Column(String(16))
|
||||
scene = Column(String(32))
|
||||
expires_at = Column(DateTime(timezone=True))
|
||||
used_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
|
||||
class UserTokenRLDBModel(RLDBBaseModel):
|
||||
__tablename__ = 'user_tokens'
|
||||
id = Column(String(36), primary_key=True)
|
||||
user_id = Column(String(36), index=True)
|
||||
token = Column(Text)
|
||||
expired_at = Column(DateTime(timezone=True))
|
||||
revoked = Column(Boolean, default=False)
|
||||
|
||||
|
||||
class User:
|
||||
id: str
|
||||
nickname: str
|
||||
avatar_link: str
|
||||
email: str
|
||||
phone: str
|
||||
password_hash: str
|
||||
created_at: datetime = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
|
||||
self.nickname = kwargs.get('nickname', '') if kwargs.get('nickname', '') is not None else ''
|
||||
self.avatar_link = kwargs.get('avatar_link', '') if kwargs.get('avatar_link', '') is not None else ''
|
||||
self.email = kwargs.get('email', '') if kwargs.get('email', '') is not None else ''
|
||||
self.phone = kwargs.get('phone', '') if kwargs.get('phone', '') is not None else ''
|
||||
self.password_hash = kwargs.get('password_hash', '') if kwargs.get('password_hash', '') is not None else ''
|
||||
self.created_at = kwargs.get('created_at', None)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (f"User(id={self.id}, nickname={self.nickname}, avatar_link={self.avatar_link}, "
|
||||
f"email={self.email}, phone={self.phone}, created_at={self.created_at})")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict):
|
||||
if 'updated_at' in data:
|
||||
del data['updated_at']
|
||||
if 'deleted_at' in data:
|
||||
del data['deleted_at']
|
||||
return cls(**data)
|
||||
|
||||
@classmethod
|
||||
def from_rldb_model(cls, data: UserRLDBModel):
|
||||
return cls(
|
||||
id=data.id,
|
||||
nickname=data.nickname,
|
||||
avatar_link=data.avatar_link,
|
||||
email=data.email,
|
||||
phone=data.phone,
|
||||
password_hash=data.password_hash,
|
||||
created_at=data.created_at,
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
'id': self.id,
|
||||
'nickname': self.nickname,
|
||||
'avatar_link': self.avatar_link,
|
||||
'email': self.email,
|
||||
'phone': self.phone,
|
||||
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
|
||||
}
|
||||
|
||||
def to_rldb_model(self) -> UserRLDBModel:
|
||||
return UserRLDBModel(
|
||||
id=self.id,
|
||||
nickname=self.nickname,
|
||||
avatar_link=self.avatar_link,
|
||||
email=self.email,
|
||||
phone=self.phone,
|
||||
password_hash=self.password_hash,
|
||||
)
|
||||
|
||||
def validate(self) -> error:
|
||||
err = error(ErrorCode.SUCCESS, "")
|
||||
if not self.email and not self.phone:
|
||||
return error(ErrorCode.MODEL_ERROR, "email or phone required")
|
||||
return err
|
||||
|
||||
|
||||
class VerificationCode:
|
||||
id: str
|
||||
target_type: str
|
||||
target: str
|
||||
code: str
|
||||
scene: str
|
||||
expires_at: datetime
|
||||
used_at: Optional[datetime] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
|
||||
self.target_type = kwargs.get('target_type', '')
|
||||
self.target = kwargs.get('target', '')
|
||||
self.code = kwargs.get('code', '')
|
||||
self.scene = kwargs.get('scene', '')
|
||||
self.expires_at = kwargs.get('expires_at')
|
||||
self.used_at = kwargs.get('used_at', None)
|
||||
|
||||
@classmethod
|
||||
def from_rldb_model(cls, data: VerificationCodeRLDBModel):
|
||||
return cls(
|
||||
id=data.id,
|
||||
target_type=data.target_type,
|
||||
target=data.target,
|
||||
code=data.code,
|
||||
scene=data.scene,
|
||||
expires_at=data.expires_at,
|
||||
used_at=data.used_at,
|
||||
)
|
||||
|
||||
def to_rldb_model(self) -> VerificationCodeRLDBModel:
|
||||
return VerificationCodeRLDBModel(
|
||||
id=self.id,
|
||||
target_type=self.target_type,
|
||||
target=self.target,
|
||||
code=self.code,
|
||||
scene=self.scene,
|
||||
expires_at=self.expires_at,
|
||||
used_at=self.used_at,
|
||||
)
|
||||
|
||||
|
||||
class UserToken:
|
||||
id: str
|
||||
user_id: str
|
||||
token: str
|
||||
expired_at: datetime
|
||||
revoked: bool
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
|
||||
self.user_id = kwargs.get('user_id', '')
|
||||
self.token = kwargs.get('token', '')
|
||||
self.expired_at = kwargs.get('expired_at')
|
||||
self.revoked = kwargs.get('revoked', False)
|
||||
|
||||
@classmethod
|
||||
def from_rldb_model(cls, data: UserTokenRLDBModel):
|
||||
return cls(
|
||||
id=data.id,
|
||||
user_id=data.user_id,
|
||||
token=data.token,
|
||||
expired_at=data.expired_at,
|
||||
revoked=data.revoked,
|
||||
)
|
||||
|
||||
def to_rldb_model(self) -> UserTokenRLDBModel:
|
||||
return UserTokenRLDBModel(
|
||||
id=self.id,
|
||||
user_id=self.user_id,
|
||||
token=self.token,
|
||||
expired_at=self.expired_at,
|
||||
revoked=self.revoked,
|
||||
)
|
||||
@@ -1,98 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# created by mmmy on 2025-11-27
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from models.custom import Custom, CustomRLDBModel
|
||||
from utils.error import ErrorCode, error
|
||||
from utils import rldb
|
||||
|
||||
class CustomService:
|
||||
def __init__(self):
|
||||
self.rldb = rldb.get_instance()
|
||||
|
||||
def save(self, custom: Custom) -> (str, error):
|
||||
"""
|
||||
保存客户到数据库。
|
||||
如果 custom.id 存在,则更新;否则,创建。
|
||||
|
||||
:param custom: 客户对象
|
||||
:return: 客户ID 和 错误对象
|
||||
"""
|
||||
# 0. 生成 custom id
|
||||
custom.id = custom.id if custom.id else uuid.uuid4().hex
|
||||
|
||||
# 1. 转换模型,并保存到 SQL 数据库
|
||||
try:
|
||||
custom_orm = custom.to_rldb_model()
|
||||
self.rldb.upsert(custom_orm)
|
||||
return custom.id, error(ErrorCode.SUCCESS, "")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to save custom {custom.id}: {e}")
|
||||
return "", error(ErrorCode.RLDB_ERROR, f"Failed to save custom data: {str(e)}")
|
||||
|
||||
def delete(self, custom_id: str) -> error:
|
||||
"""
|
||||
从数据库删除客户。
|
||||
|
||||
:param custom_id: 客户ID
|
||||
:return: 错误对象
|
||||
"""
|
||||
try:
|
||||
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
|
||||
if not custom_orm:
|
||||
return error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
|
||||
self.rldb.delete(custom_orm)
|
||||
return error(ErrorCode.SUCCESS, "")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to delete custom {custom_id}: {e}")
|
||||
return error(ErrorCode.RLDB_ERROR, f"Failed to delete custom data: {str(e)}")
|
||||
|
||||
def get(self, custom_id: str) -> (Custom, error):
|
||||
"""
|
||||
从数据库获取单个客户。
|
||||
|
||||
:param custom_id: 客户ID
|
||||
:return: 客户对象 和 错误对象
|
||||
"""
|
||||
try:
|
||||
custom_orm = self.rldb.get(CustomRLDBModel, custom_id)
|
||||
if not custom_orm:
|
||||
return None, error(ErrorCode.RLDB_NOT_FOUND, f"Custom {custom_id} not found.")
|
||||
|
||||
custom = Custom.from_rldb_model(custom_orm)
|
||||
return custom, error(ErrorCode.SUCCESS, "")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get custom {custom_id}: {e}")
|
||||
return None, error(ErrorCode.RLDB_ERROR, f"Failed to retrieve custom data: {str(e)}")
|
||||
|
||||
def list(self, conds: dict = None, limit: int = 10, offset: int = 0) -> (list[Custom], error):
|
||||
"""
|
||||
根据条件从数据库列出客户(支持分页)。
|
||||
|
||||
:param conds: 查询条件字典
|
||||
:param limit: 每页数量
|
||||
:param offset: 偏移量
|
||||
:return: 客户对象列表 和 错误对象
|
||||
"""
|
||||
if conds is None:
|
||||
conds = {}
|
||||
try:
|
||||
custom_orms = self.rldb.query(CustomRLDBModel, limit=limit, offset=offset, **conds)
|
||||
customs = [Custom.from_rldb_model(orm) for orm in custom_orms]
|
||||
return customs, error(ErrorCode.SUCCESS, "")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to list customs with conds {conds}: {e}")
|
||||
return [], error(ErrorCode.RLDB_ERROR, f"Failed to list custom data: {str(e)}")
|
||||
|
||||
# --- Singleton Pattern ---
|
||||
custom_service = None
|
||||
|
||||
def init():
|
||||
"""初始化 CustomService 单例"""
|
||||
global custom_service
|
||||
custom_service = CustomService()
|
||||
|
||||
def get_instance() -> CustomService:
|
||||
"""获取 CustomService 单例"""
|
||||
return custom_service
|
||||
@@ -1,124 +0,0 @@
|
||||
|
||||
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from models.people import People, PeopleRLDBModel, Comment
|
||||
from datetime import datetime
|
||||
from utils.error import ErrorCode, error
|
||||
from utils import rldb
|
||||
|
||||
|
||||
class PeopleService:
|
||||
def __init__(self):
|
||||
self.rldb = rldb.get_instance()
|
||||
|
||||
def save(self, people: People) -> (str, error):
|
||||
"""
|
||||
保存人物到数据库和向量数据库
|
||||
|
||||
:param people: 人物对象
|
||||
:return: 人物ID
|
||||
"""
|
||||
# 0. 生成 people id
|
||||
people.id = people.id if people.id else uuid.uuid4().hex
|
||||
|
||||
# 1. 转换模型,并保存到 SQL 数据库
|
||||
people_orm = people.to_rldb_model()
|
||||
self.rldb.upsert(people_orm)
|
||||
|
||||
return people.id, error(ErrorCode.SUCCESS, "")
|
||||
|
||||
def delete(self, people_id: str) -> error:
|
||||
"""
|
||||
删除人物从数据库和向量数据库
|
||||
|
||||
:param people_id: 人物ID
|
||||
:return: 错误对象
|
||||
"""
|
||||
people_orm = self.rldb.get(PeopleRLDBModel, people_id)
|
||||
if not people_orm:
|
||||
return error(ErrorCode.RLDB_ERROR, f"people {people_id} not found")
|
||||
self.rldb.delete(people_orm)
|
||||
return error(ErrorCode.SUCCESS, "")
|
||||
|
||||
def get(self, people_id: str) -> (People, error):
|
||||
"""
|
||||
从数据库获取人物
|
||||
|
||||
:param people_id: 人物ID
|
||||
:return: 人物对象
|
||||
"""
|
||||
people_orm = self.rldb.get(PeopleRLDBModel, people_id)
|
||||
if not people_orm:
|
||||
return None, error(ErrorCode.MODEL_ERROR, f"people {people_id} not found")
|
||||
return People.from_rldb_model(people_orm), error(ErrorCode.SUCCESS, "")
|
||||
|
||||
def list(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> (list[People], error):
|
||||
"""
|
||||
从数据库列出人物
|
||||
|
||||
:param conds: 查询条件字典
|
||||
:param limit: 分页大小
|
||||
:param offset: 分页偏移量
|
||||
:return: 人物对象列表
|
||||
"""
|
||||
people_orms = self.rldb.query(PeopleRLDBModel, **conds)
|
||||
peoples = [People.from_rldb_model(people_orm) for people_orm in people_orms]
|
||||
|
||||
return peoples, error(ErrorCode.SUCCESS, "")
|
||||
|
||||
def save_remark(self, people_id: str, content: str) -> error:
|
||||
"""
|
||||
为人物添加或更新备注
|
||||
|
||||
:param people_id: 人物ID
|
||||
:param content: 备注内容
|
||||
:return: 错误对象
|
||||
"""
|
||||
people: People
|
||||
err: error
|
||||
people, err = self.get(people_id)
|
||||
logging.info(f"get people before save remark: {people}")
|
||||
if not err.success:
|
||||
return err
|
||||
remark = people.comments.get("remark", None)
|
||||
if remark is not None:
|
||||
remark.content = content
|
||||
remark.updated_at = datetime.now()
|
||||
else:
|
||||
people.comments["remark"] = Comment(content=content)
|
||||
logging.info(f"save remark for people {people}")
|
||||
_, err = self.save(people)
|
||||
return err
|
||||
|
||||
def delete_remark(self, people_id: str) -> error:
|
||||
"""
|
||||
删除人物备注
|
||||
|
||||
:param people_id: 人物ID
|
||||
:return: 错误对象
|
||||
"""
|
||||
people: People
|
||||
err: error
|
||||
people, err = self.get(people_id)
|
||||
if not err.success:
|
||||
return err
|
||||
|
||||
if "remark" in people.comments:
|
||||
del people.comments["remark"]
|
||||
_, err = self.save(people)
|
||||
return err
|
||||
|
||||
return error(ErrorCode.SUCCESS, "")
|
||||
|
||||
|
||||
|
||||
people_service = None
|
||||
|
||||
def init():
|
||||
global people_service
|
||||
people_service = PeopleService()
|
||||
|
||||
def get_instance() -> PeopleService:
|
||||
return people_service
|
||||
@@ -1,220 +0,0 @@
|
||||
import uuid
|
||||
import hmac
|
||||
import base64
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from utils.error import ErrorCode, error
|
||||
from utils import rldb, mailer, sms, config
|
||||
from models.user import (
|
||||
User,
|
||||
UserRLDBModel,
|
||||
VerificationCode,
|
||||
VerificationCodeRLDBModel,
|
||||
UserToken,
|
||||
UserTokenRLDBModel,
|
||||
)
|
||||
|
||||
|
||||
class UserService:
|
||||
def __init__(self):
|
||||
self.rldb = rldb.get_instance()
|
||||
self.mailer = mailer.get_instance()
|
||||
self.sms = sms.get_instance()
|
||||
self.conf = config.get_instance()
|
||||
|
||||
def _hash_password(self, password: str, salt: Optional[str] = None) -> str:
|
||||
salt = salt if salt else base64.urlsafe_b64encode(os.urandom(16)).decode('utf-8')
|
||||
digest = hmac.new(salt.encode('utf-8'), password.encode('utf-8'), 'sha256').digest()
|
||||
return f"{salt}:{base64.urlsafe_b64encode(digest).decode('utf-8')}"
|
||||
|
||||
def _verify_password(self, password: str, password_hash: str) -> bool:
|
||||
parts = password_hash.split(':')
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
salt = parts[0]
|
||||
return self._hash_password(password, salt) == password_hash
|
||||
|
||||
def send_code(self, target_type: str, target: str, scene: str) -> error:
|
||||
scens = {
|
||||
"register": "注册",
|
||||
"update": "信息更新",
|
||||
# "login": "登录",
|
||||
}
|
||||
if scene not in scens:
|
||||
return error(ErrorCode.MODEL_ERROR, f'scene {scene} not supported')
|
||||
scene_name = scens.get(scene, scene)
|
||||
code = f"{uuid.uuid4().int % 1000000:06d}"
|
||||
expires = datetime.now() + timedelta(minutes=10)
|
||||
vc = VerificationCode(
|
||||
id=uuid.uuid4().hex,
|
||||
target_type=target_type,
|
||||
target=target,
|
||||
code=code,
|
||||
scene=scene,
|
||||
expires_at=expires,
|
||||
)
|
||||
self.rldb.upsert(vc.to_rldb_model())
|
||||
content = f"IF.U服务{scene_name}验证码: {code}, 10分钟内有效"
|
||||
sent = True
|
||||
if target_type == 'email':
|
||||
sent = self.mailer.send(target, f'IF.U服务{scene_name}验证码', content) if self.mailer else False
|
||||
elif target_type == 'phone':
|
||||
sent = self.sms.send(target, content) if self.sms else False
|
||||
if not sent:
|
||||
return error(ErrorCode.RLDB_ERROR, 'send code failed')
|
||||
return error(ErrorCode.SUCCESS, '')
|
||||
|
||||
def _get_user_by_identifier(self, email: Optional[str], phone: Optional[str]) -> Optional[User]:
|
||||
if email:
|
||||
users = self.rldb.query(UserRLDBModel, email=email, limit=1)
|
||||
if users:
|
||||
return User.from_rldb_model(users[0])
|
||||
if phone:
|
||||
users = self.rldb.query(UserRLDBModel, phone=phone, limit=1)
|
||||
if users:
|
||||
return User.from_rldb_model(users[0])
|
||||
return None
|
||||
|
||||
def register(self, user: User, code: str) -> (str, error):
|
||||
if not user.email and not user.phone:
|
||||
return '', error(ErrorCode.MODEL_ERROR, 'email or phone required')
|
||||
existed = self._get_user_by_identifier(user.email, user.phone)
|
||||
if existed:
|
||||
return '', error(ErrorCode.MODEL_ERROR, 'user existed')
|
||||
target_type = 'phone' if user.phone else 'email'
|
||||
target = user.phone if user.phone else user.email
|
||||
vc_list = self.rldb.query(
|
||||
VerificationCodeRLDBModel,
|
||||
target_type=target_type,
|
||||
target=target,
|
||||
scene='register',
|
||||
limit=1,
|
||||
)
|
||||
if not vc_list:
|
||||
return '', error(ErrorCode.MODEL_ERROR, 'code not found')
|
||||
vc = vc_list[0]
|
||||
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
|
||||
return '', error(ErrorCode.MODEL_ERROR, 'invalid code')
|
||||
vc.used_at = datetime.now()
|
||||
self.rldb.upsert(vc)
|
||||
user.id = uuid.uuid4().hex
|
||||
hashed = self._hash_password(user.password_hash)
|
||||
user.password_hash = hashed
|
||||
self.rldb.upsert(user.to_rldb_model())
|
||||
return user.id, error(ErrorCode.SUCCESS, '')
|
||||
|
||||
def login(self, email: Optional[str], phone: Optional[str], password: str) -> (dict, error):
|
||||
u = self._get_user_by_identifier(email, phone)
|
||||
if not u:
|
||||
return {}, error(ErrorCode.MODEL_ERROR, 'user not found')
|
||||
if not self._verify_password(password, u.password_hash):
|
||||
return {}, error(ErrorCode.MODEL_ERROR, 'invalid password')
|
||||
ttl_days = self.conf.getint('auth', 'token_ttl_days', fallback=30)
|
||||
expired_at = datetime.now() + timedelta(days=ttl_days)
|
||||
token_raw = f"{u.id}.{uuid.uuid4().hex}.{int(expired_at.timestamp())}"
|
||||
secret = self.conf.get('auth', 'jwt_secret', fallback='dev-secret')
|
||||
signature = hmac.new(secret.encode('utf-8'), token_raw.encode('utf-8'), 'sha256').digest()
|
||||
token = base64.urlsafe_b64encode(token_raw.encode('utf-8')).decode('utf-8') + '.' + base64.urlsafe_b64encode(signature).decode('utf-8')
|
||||
ut = UserToken(id=uuid.uuid4().hex, user_id=u.id, token=token, expired_at=expired_at, revoked=False)
|
||||
self.rldb.upsert(ut.to_rldb_model())
|
||||
return {'token': token, 'expired_at': int(expired_at.timestamp())}, error(ErrorCode.SUCCESS, '')
|
||||
|
||||
def logout(self, token: str) -> error:
|
||||
tokens = self.rldb.query(UserTokenRLDBModel, token=token, limit=1)
|
||||
if not tokens:
|
||||
return error(ErrorCode.MODEL_ERROR, 'token not found')
|
||||
t = tokens[0]
|
||||
t.revoked = True
|
||||
self.rldb.upsert(t)
|
||||
return error(ErrorCode.SUCCESS, '')
|
||||
|
||||
def delete_user(self, user_id: str) -> error:
|
||||
u = self.rldb.get(UserRLDBModel, user_id)
|
||||
if not u:
|
||||
return error(ErrorCode.MODEL_ERROR, 'user not found')
|
||||
self.rldb.delete(u)
|
||||
return error(ErrorCode.SUCCESS, '')
|
||||
|
||||
def get(self, user_id: str) -> (User, error):
|
||||
u = self.rldb.get(UserRLDBModel, user_id)
|
||||
if not u:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'user not found')
|
||||
return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '')
|
||||
|
||||
def update_profile(self, user_id: str, nickname: str = None, avatar_link: str = None, phone: str = None, email: str = None) -> (User, error):
|
||||
u = self.rldb.get(UserRLDBModel, user_id)
|
||||
if not u:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'user not found')
|
||||
has_email = bool(u.email)
|
||||
has_phone = bool(u.phone)
|
||||
if nickname is not None:
|
||||
u.nickname = nickname
|
||||
if avatar_link is not None:
|
||||
u.avatar_link = avatar_link
|
||||
if email is not None:
|
||||
new_email = email
|
||||
if has_email:
|
||||
if not has_phone:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'email update requires phone exists')
|
||||
conflicts = self.rldb.query(UserRLDBModel, email=new_email, limit=1)
|
||||
if conflicts and conflicts[0].id != user_id:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'email existed')
|
||||
u.email = new_email
|
||||
if phone is not None:
|
||||
new_phone = phone
|
||||
if has_phone:
|
||||
if not has_email:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'phone update requires email exists')
|
||||
conflicts = self.rldb.query(UserRLDBModel, phone=new_phone, limit=1)
|
||||
if conflicts and conflicts[0].id != user_id:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'phone existed')
|
||||
u.phone = new_phone
|
||||
self.rldb.upsert(u)
|
||||
return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '')
|
||||
|
||||
def update_phone_with_code(self, user_id: str, new_phone: str, code: str) -> (User, error):
|
||||
vc_list = self.rldb.query(
|
||||
VerificationCodeRLDBModel,
|
||||
target_type='phone',
|
||||
target=new_phone,
|
||||
scene='update',
|
||||
limit=1,
|
||||
)
|
||||
if not vc_list:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'code not found')
|
||||
vc = vc_list[0]
|
||||
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'invalid code')
|
||||
vc.used_at = datetime.now()
|
||||
self.rldb.upsert(vc)
|
||||
return self.update_profile(user_id, phone=new_phone)
|
||||
|
||||
def update_email_with_code(self, user_id: str, new_email: str, code: str) -> (User, error):
|
||||
vc_list = self.rldb.query(
|
||||
VerificationCodeRLDBModel,
|
||||
target_type='email',
|
||||
target=new_email,
|
||||
scene='update',
|
||||
limit=1,
|
||||
)
|
||||
if not vc_list:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'code not found')
|
||||
vc = vc_list[0]
|
||||
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
|
||||
return None, error(ErrorCode.MODEL_ERROR, 'invalid code')
|
||||
vc.used_at = datetime.now()
|
||||
self.rldb.upsert(vc)
|
||||
return self.update_profile(user_id, email=new_email)
|
||||
|
||||
|
||||
user_service = None
|
||||
|
||||
|
||||
def init():
|
||||
global user_service
|
||||
user_service = UserService()
|
||||
|
||||
|
||||
def get_instance() -> UserService:
|
||||
return user_service
|
||||
216
src/storage/people_store.py
Normal file
216
src/storage/people_store.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, func
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from utils.config import get_instance as get_config
|
||||
from utils.vsdb import VectorDB, get_instance as get_vsdb
|
||||
from utils.obs import OBS, get_instance as get_obs
|
||||
|
||||
from models.people import People
|
||||
|
||||
people_store = None
|
||||
Base = declarative_base()
|
||||
class PeopleORM(Base):
|
||||
__tablename__ = 'peoples'
|
||||
id = Column(String(36), primary_key=True)
|
||||
name = Column(String(255), index=True)
|
||||
contact = Column(String(255), index=True)
|
||||
gender = Column(String(10))
|
||||
age = Column(Integer)
|
||||
height = Column(Integer)
|
||||
marital_status = Column(String(20))
|
||||
match_requirement = Column(Text)
|
||||
introduction = Column(Text)
|
||||
comments = Column(Text)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
def parse_from_people(self, people: People):
|
||||
import json
|
||||
self.id = people.id
|
||||
self.name = people.name
|
||||
self.contact = people.contact
|
||||
self.gender = people.gender
|
||||
self.age = people.age
|
||||
self.height = people.height
|
||||
self.marital_status = people.marital_status
|
||||
self.match_requirement = people.match_requirement
|
||||
# 将字典类型字段序列化为JSON字符串存储
|
||||
self.introduction = json.dumps(people.introduction, ensure_ascii=False)
|
||||
self.comments = json.dumps(people.comments, ensure_ascii=False)
|
||||
|
||||
def to_people(self) -> People:
|
||||
import json
|
||||
people = People()
|
||||
people.id = self.id
|
||||
people.name = self.name
|
||||
people.contact = self.contact
|
||||
people.gender = self.gender
|
||||
people.age = self.age
|
||||
people.height = self.height
|
||||
people.marital_status = self.marital_status
|
||||
people.match_requirement = self.match_requirement
|
||||
# 将JSON字符串反序列化为字典类型字段
|
||||
try:
|
||||
people.introduction = json.loads(self.introduction) if self.introduction else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
people.introduction = {}
|
||||
|
||||
try:
|
||||
people.comments = json.loads(self.comments) if self.comments else {}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
people.comments = {}
|
||||
|
||||
return people
|
||||
|
||||
class PeopleStore:
|
||||
def __init__(self):
|
||||
config = get_config()
|
||||
self.sqldb_engine = create_engine(config.get("sqlalchemy", "database_dsn"))
|
||||
Base.metadata.create_all(self.sqldb_engine)
|
||||
self.session_maker = sessionmaker(bind=self.sqldb_engine)
|
||||
self.vsdb: VectorDB = get_vsdb()
|
||||
self.obs: OBS = get_obs()
|
||||
|
||||
def save(self, people: People) -> str:
|
||||
"""
|
||||
保存人物到数据库和向量数据库
|
||||
|
||||
:param people: 人物对象
|
||||
:return: 人物ID
|
||||
"""
|
||||
# 0. 生成 people id
|
||||
people.id = people.id if people.id else uuid.uuid4().hex
|
||||
|
||||
# 1. 转换模型,并保存到 SQL 数据库
|
||||
people_orm = PeopleORM()
|
||||
people_orm.parse_from_people(people)
|
||||
with self.session_maker() as session:
|
||||
session.add(people_orm)
|
||||
session.commit()
|
||||
|
||||
# 2. 保存到向量数据库
|
||||
people_metadata = people.meta()
|
||||
people_document = people.tonl()
|
||||
logging.info(f"people: {people}")
|
||||
logging.info(f"people_metadata: {people_metadata}")
|
||||
logging.info(f"people_document: {people_document}")
|
||||
results = self.vsdb.insert(metadatas=[people_metadata], documents=[people_document], ids=[people.id])
|
||||
logging.info(f"results: {results}")
|
||||
if len(results) == 0:
|
||||
raise Exception("insert failed")
|
||||
|
||||
# 3. 保存到 OBS 存储
|
||||
people_dict = people.to_dict()
|
||||
people_json = json.dumps(people_dict, ensure_ascii=False)
|
||||
obs_url = self.obs.Put(f"peoples/{people.id}/detail.json", people_json.encode('utf-8'))
|
||||
logging.info(f"obs_url: {obs_url}")
|
||||
|
||||
return people.id
|
||||
|
||||
def update(self, people: People) -> None:
|
||||
raise Exception("update not implemented")
|
||||
return None
|
||||
|
||||
def find(self, people_id: str) -> People:
|
||||
"""
|
||||
根据人物ID查询人物
|
||||
|
||||
:param people_id: 人物ID
|
||||
:return: 人物对象
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
people_orm = session.query(PeopleORM).filter(
|
||||
PeopleORM.id == people_id,
|
||||
PeopleORM.deleted_at.is_(None)
|
||||
).first()
|
||||
if not people_orm:
|
||||
raise Exception(f"people not found, people_id: {people_id}")
|
||||
return people_orm.to_people()
|
||||
|
||||
def query(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> list[People]:
|
||||
"""
|
||||
根据查询条件查询人物
|
||||
|
||||
:param conds: 查询条件字典
|
||||
:param limit: 分页大小
|
||||
:param offset: 分页偏移量
|
||||
:return: 人物对象列表
|
||||
"""
|
||||
if conds is None:
|
||||
conds = {}
|
||||
with self.session_maker() as session:
|
||||
people_orms = session.query(PeopleORM).filter_by(**conds).filter(
|
||||
PeopleORM.deleted_at.is_(None)
|
||||
).limit(limit).offset(offset).all()
|
||||
return [people_orm.to_people() for people_orm in people_orms]
|
||||
|
||||
def search(self, search: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[People]:
|
||||
"""
|
||||
根据搜索内容和查询条件查询人物
|
||||
|
||||
:param search: 搜索内容
|
||||
:param metadatas: 查询条件字典
|
||||
:param ids: 可选的人物ID列表,用于过滤结果
|
||||
:param top_k: 返回结果数量
|
||||
:return: 人物对象列表
|
||||
"""
|
||||
peoples = []
|
||||
results = self.vsdb.search(document=search, metadatas=metadatas, ids=ids, top_k=top_k)
|
||||
logging.info(f"results: {results}")
|
||||
people_id_list = []
|
||||
for result in results:
|
||||
people_id = result.get('id', '')
|
||||
if not people_id:
|
||||
continue
|
||||
people_id_list.append(people_id)
|
||||
logging.info(f"people_id_list: {people_id_list}")
|
||||
with self.session_maker() as session:
|
||||
people_orms = session.query(PeopleORM).filter(PeopleORM.id.in_(people_id_list)).filter(
|
||||
PeopleORM.deleted_at.is_(None)
|
||||
).all()
|
||||
# 根据 people_id_list 的顺序对查询结果进行排序
|
||||
order_map = {pid: idx for idx, pid in enumerate(people_id_list)}
|
||||
people_orms.sort(key=lambda orm: order_map.get(orm.id, len(order_map)))
|
||||
for people_orm in people_orms:
|
||||
people = people_orm.to_people()
|
||||
peoples.append(people)
|
||||
return peoples
|
||||
|
||||
def delete(self, people_id: str) -> None:
|
||||
"""
|
||||
删除人物从数据库和向量数据库
|
||||
|
||||
:param people_id: 人物ID
|
||||
"""
|
||||
# 1. 从 SQL 数据库软删除人物
|
||||
with self.session_maker() as session:
|
||||
session.query(PeopleORM).filter(
|
||||
PeopleORM.id == people_id,
|
||||
PeopleORM.deleted_at.is_(None)
|
||||
).update({PeopleORM.deleted_at: func.now()}, synchronize_session=False)
|
||||
session.commit()
|
||||
logging.info(f"人物 {people_id} 标记删除 SQL 成功")
|
||||
|
||||
# 2. 删除向量数据库中的记录
|
||||
self.vsdb.delete(ids=[people_id])
|
||||
logging.info(f"人物 {people_id} 删除向量数据库成功")
|
||||
|
||||
# 3. 删除 OBS 存储中的文件
|
||||
keys = self.obs.List(f"peoples/{people_id}/")
|
||||
for key in keys:
|
||||
self.obs.Del(key)
|
||||
logging.info(f"文件 {key} 删除 OBS 成功")
|
||||
|
||||
def init():
|
||||
global people_store
|
||||
people_store = PeopleStore()
|
||||
|
||||
|
||||
def get_instance() -> PeopleStore:
|
||||
return people_store
|
||||
@@ -0,0 +1,3 @@
|
||||
# 导出utils模块中的子模块
|
||||
from . import config, obs, ocr, vsdb, logger
|
||||
__all__ = ['config', 'obs', 'ocr', 'vsdb', 'logger']
|
||||
|
||||
@@ -2,6 +2,7 @@ import configparser
|
||||
|
||||
config = None
|
||||
|
||||
|
||||
def init(config_file: str):
|
||||
global config
|
||||
config = configparser.ConfigParser()
|
||||
|
||||
@@ -1,33 +0,0 @@
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import Protocol
|
||||
|
||||
class ErrorCode(Enum):
|
||||
SUCCESS = 0
|
||||
MODEL_ERROR = 1000
|
||||
RLDB_ERROR = 2100
|
||||
RLDB_NOT_FOUND = 2101
|
||||
OBS_ERROR = 3100
|
||||
OBS_INPUT_ERROR = 3102
|
||||
OBS_SERVICE_ERROR = 3103
|
||||
|
||||
class error(Protocol):
|
||||
_error_code: int = 0
|
||||
_error_info: str = ""
|
||||
|
||||
def __init__(self, error_code: ErrorCode, error_info: str):
|
||||
self._error_code = int(error_code.value)
|
||||
self._error_info = error_info
|
||||
def __str__(self) -> str:
|
||||
return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"
|
||||
|
||||
@property
|
||||
def code(self) -> int:
|
||||
return self._error_code
|
||||
@property
|
||||
def info(self) -> str:
|
||||
return self._error_info
|
||||
@property
|
||||
def success(self) -> bool:
|
||||
return self._error_code == 0
|
||||
@@ -1,60 +0,0 @@
|
||||
import logging
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Protocol
|
||||
from .config import get_instance as get_config
|
||||
|
||||
class Mailer(Protocol):
|
||||
def send(self, to_email: str, subject: str, content: str) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class FakeMailer:
|
||||
def __init__(self) -> None:
|
||||
conf = get_config()
|
||||
self.fake_message = conf.get('fake_mailer', 'message', fallback="FakeEmail")
|
||||
def send(self, to_email: str, subject: str, content: str) -> bool:
|
||||
logging.info(f"{self.fake_message}: to_email={to_email}, subject={subject}, content={content}")
|
||||
return True
|
||||
|
||||
|
||||
class RealMailer:
|
||||
def __init__(self):
|
||||
conf = get_config()
|
||||
self.smtp_host = conf.get('real_mailer', 'smtp_host', fallback=None)
|
||||
self.smtp_port = conf.getint('real_mailer', 'smtp_port', fallback=587)
|
||||
self.smtp_user = conf.get('real_mailer', 'smtp_user', fallback=None)
|
||||
self.smtp_pass = conf.get('real_mailer', 'smtp_pass', fallback=None)
|
||||
self.from_email = conf.get('real_mailer', 'from_email', fallback=self.smtp_user)
|
||||
|
||||
def send(self, to_email: str, subject: str, content: str) -> bool:
|
||||
if not self.smtp_host or not self.smtp_user or not self.smtp_pass:
|
||||
return False
|
||||
msg = MIMEText(content, 'plain', 'utf-8')
|
||||
msg['Subject'] = subject
|
||||
msg['From'] = self.from_email
|
||||
msg['To'] = to_email
|
||||
try:
|
||||
server = smtplib.SMTP(self.smtp_host, self.smtp_port)
|
||||
server.starttls()
|
||||
server.login(self.smtp_user, self.smtp_pass)
|
||||
server.sendmail(self.from_email, [to_email], msg.as_string())
|
||||
server.quit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
_mailer: Mailer = None
|
||||
|
||||
|
||||
def init(type: str = 'real'):
|
||||
global _mailer
|
||||
if type == 'real':
|
||||
_mailer = RealMailer()
|
||||
else:
|
||||
_mailer = FakeMailer()
|
||||
|
||||
|
||||
def get_instance() -> Mailer:
|
||||
return _mailer
|
||||
@@ -4,13 +4,11 @@ import logging
|
||||
from typing import Protocol
|
||||
import qiniu
|
||||
import requests
|
||||
|
||||
from .error import ErrorCode, error
|
||||
from .config import get_instance as get_config
|
||||
|
||||
|
||||
class OBS(Protocol):
|
||||
def put(self, obs_path: str, content: bytes) -> str:
|
||||
def Put(self, obs_path: str, content: bytes) -> str:
|
||||
"""
|
||||
上传文件到OBS
|
||||
|
||||
@@ -23,7 +21,7 @@ class OBS(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get(self, obs_path: str) -> bytes:
|
||||
def Get(self, obs_path: str) -> bytes:
|
||||
"""
|
||||
从OBS下载文件
|
||||
|
||||
@@ -35,7 +33,7 @@ class OBS(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def list(self, obs_path: str) -> list:
|
||||
def List(self, obs_path: str) -> list:
|
||||
"""
|
||||
列出OBS目录下的所有文件
|
||||
|
||||
@@ -47,7 +45,7 @@ class OBS(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def delete(self, obs_path: str) -> error:
|
||||
def Del(self, obs_path: str) -> bool:
|
||||
"""
|
||||
删除OBS文件
|
||||
|
||||
@@ -59,7 +57,7 @@ class OBS(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_link(self, obs_path: str) -> str:
|
||||
def Link(self, obs_path: str) -> str:
|
||||
"""
|
||||
获取OBS文件链接
|
||||
|
||||
@@ -70,31 +68,6 @@ class OBS(Protocol):
|
||||
str: OBS文件链接
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_by_link(self, obs_link: str) -> error:
|
||||
"""
|
||||
根据OBS文件链接删除文件
|
||||
|
||||
Args:
|
||||
obs_link (str): OBS文件链接
|
||||
|
||||
Returns:
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
...
|
||||
|
||||
def get_obs_path_by_link(self, obs_link: str) -> (str, error):
|
||||
"""
|
||||
从OBS文件链接获取OBS路径
|
||||
|
||||
Args:
|
||||
obs_link (str): OBS文件链接
|
||||
|
||||
Returns:
|
||||
str: OBS文件路径
|
||||
error: 错误信息
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class Koodo:
|
||||
@@ -109,7 +82,7 @@ class Koodo:
|
||||
self.bucket = qiniu.BucketManager(self.auth)
|
||||
pass
|
||||
|
||||
def put(self, obs_path: str, content: bytes) -> str:
|
||||
def Put(self, obs_path: str, content: bytes) -> str:
|
||||
"""
|
||||
上传文件到OBS
|
||||
|
||||
@@ -130,7 +103,7 @@ class Koodo:
|
||||
logging.info(f"文件 {obs_path} 上传成功, OBS路径: {full_path}")
|
||||
return f"{self.outer_domain}/{full_path}"
|
||||
|
||||
def get(self, obs_path: str) -> bytes:
|
||||
def Get(self, obs_path: str) -> bytes:
|
||||
"""
|
||||
从OBS下载文件
|
||||
|
||||
@@ -148,7 +121,7 @@ class Koodo:
|
||||
return None
|
||||
return resp.content
|
||||
|
||||
def list(self, prefix: str = "") -> list[str]:
|
||||
def List(self, prefix: str = "") -> list[str]:
|
||||
"""
|
||||
列出OBS目录下的所有文件
|
||||
|
||||
@@ -170,7 +143,7 @@ class Koodo:
|
||||
# logging.debug(f"info: {info}")
|
||||
return keys
|
||||
|
||||
def delete(self, obs_path: str) -> error:
|
||||
def Del(self, obs_path: str) -> bool:
|
||||
"""
|
||||
删除OBS文件
|
||||
|
||||
@@ -178,17 +151,17 @@ class Koodo:
|
||||
obs_path (str): OBS文件路径
|
||||
|
||||
Returns:
|
||||
error: 删除结果
|
||||
bool: 是否删除成功
|
||||
"""
|
||||
ret, info = self.bucket.delete(self.bucket_name, f"{self.prefix_path}{obs_path}")
|
||||
logging.debug(f"文件 {self.prefix_path}{obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
|
||||
logging.debug(f"文件 {obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}")
|
||||
if ret is None or info.status_code != 200:
|
||||
logging.error(f"文件 {obs_path} 删除 OBS 失败, 错误信息: {info.text_body}")
|
||||
return error(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {self.prefix_path}{obs_path} 删除 OBS 失败, 错误信息: {info.text_body}")
|
||||
return False
|
||||
logging.info(f"文件 {obs_path} 删除 OBS 成功")
|
||||
return error(error_code=ErrorCode.SUCCESS, error_info="success")
|
||||
return True
|
||||
|
||||
def get_link(self, obs_path: str) -> str:
|
||||
def Link(self, obs_path: str) -> str:
|
||||
"""
|
||||
获取OBS文件链接
|
||||
|
||||
@@ -200,38 +173,6 @@ class Koodo:
|
||||
"""
|
||||
return f"{self.outer_domain}/{self.prefix_path}{obs_path}"
|
||||
|
||||
def delete_by_link(self, obs_link: str) -> error:
|
||||
"""
|
||||
根据OBS文件链接删除文件
|
||||
|
||||
Args:
|
||||
obs_link (str): OBS文件链接
|
||||
|
||||
Returns:
|
||||
error: 删除结果
|
||||
"""
|
||||
obs_path, err = self.get_obs_path_by_link(obs_link)
|
||||
if not err.success:
|
||||
return err
|
||||
return self.delete(obs_path)
|
||||
|
||||
def get_obs_path_by_link(self, obs_link: str) -> (str, error):
|
||||
"""
|
||||
从OBS文件链接获取OBS路径
|
||||
|
||||
Args:
|
||||
obs_link (str): OBS文件链接
|
||||
|
||||
Returns:
|
||||
str: OBS文件路径
|
||||
error: 错误信息
|
||||
"""
|
||||
if not obs_link.startswith(f"{self.outer_domain}/{self.prefix_path}"):
|
||||
logging.error(f"文件 {obs_link} 不是 OBS 文件链接")
|
||||
return "", error(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {obs_link} 不是 OBS 文件链接")
|
||||
obs_path = obs_link[len(self.outer_domain) + len(self.prefix_path) + 1:]
|
||||
return obs_path, error(error_code=ErrorCode.SUCCESS, error_info="success")
|
||||
|
||||
|
||||
_obs_instance: OBS = None
|
||||
|
||||
@@ -272,8 +213,8 @@ if __name__ == "__main__":
|
||||
# print(f"文件 {obs_path} 链接: {link}")
|
||||
|
||||
# 列出OBS目录下的所有文件
|
||||
keys = obs.list("")
|
||||
keys = obs.List("")
|
||||
print(f"OBS 目录下的所有文件: {keys}")
|
||||
for key in keys:
|
||||
link = obs.delete(key)
|
||||
link = obs.Del(key)
|
||||
print(f"文件 {key} 删除 OBS 成功: {link}")
|
||||
|
||||
@@ -1,168 +0,0 @@
|
||||
|
||||
from re import S
|
||||
from typing import Protocol
|
||||
import uuid
|
||||
from sqlalchemy import Column, DateTime, String, create_engine, func
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from .config import get_instance as get_config
|
||||
|
||||
SQLAlchemyBase = declarative_base()
|
||||
|
||||
|
||||
class RLDBBaseModel(SQLAlchemyBase):
|
||||
__abstract__ = True
|
||||
id = Column(String(36), primary_key=True, default=lambda: uuid.uuid4().hex)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
|
||||
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
def __str__(self) -> str:
|
||||
# 遍历所有的field,打印出所有的field和value, id 永远排在第一, 三个时间戳排在最后, 其余字段按定义顺序排序
|
||||
fields = [field for field in self.__dict__ if not field.startswith('_')]
|
||||
fields.remove("id") if "id" in fields else None
|
||||
fields.remove("created_at") if "created_at" in fields else None
|
||||
fields.remove("updated_at") if "updated_at" in fields else None
|
||||
fields.remove("deleted_at") if "deleted_at" in fields else None
|
||||
fields = ["id"] + fields + ["created_at", "updated_at", "deleted_at"]
|
||||
field_values = [f"{field}={getattr(self, field)}" for field in fields]
|
||||
return f"{self.__class__.__name__}({', '.join(field_values)})"
|
||||
|
||||
|
||||
class RelationalDB(Protocol):
|
||||
def insert(self, data: RLDBBaseModel) -> str:
|
||||
...
|
||||
|
||||
def update(self, data: RLDBBaseModel) -> str:
|
||||
...
|
||||
|
||||
def upsert(self, data: RLDBBaseModel) -> str:
|
||||
...
|
||||
|
||||
def delete(self, data: RLDBBaseModel) -> str:
|
||||
...
|
||||
|
||||
def get(self,
|
||||
model: type[RLDBBaseModel],
|
||||
id: str,
|
||||
include_deleted: bool = False
|
||||
) -> RLDBBaseModel:
|
||||
...
|
||||
|
||||
def query(self,
|
||||
model: type[RLDBBaseModel],
|
||||
include_deleted: bool = False,
|
||||
limit: int = None,
|
||||
offset: int = None,
|
||||
**filters
|
||||
) -> list[RLDBBaseModel]:
|
||||
...
|
||||
|
||||
|
||||
class SqlAlchemyDB():
|
||||
def __init__(self, dsn: str = None) -> None:
|
||||
config = get_config()
|
||||
dsn = dsn if dsn else config.get("sqlalchemy", "database_dsn")
|
||||
self.sqldb_engine = create_engine(dsn)
|
||||
SQLAlchemyBase.metadata.create_all(self.sqldb_engine)
|
||||
self.session_maker = sessionmaker(bind=self.sqldb_engine)
|
||||
|
||||
def insert(self, data: RLDBBaseModel) -> str:
|
||||
with self.session_maker() as session:
|
||||
session.add(data)
|
||||
session.commit()
|
||||
return data.id
|
||||
|
||||
def update(self, data: RLDBBaseModel) -> str:
|
||||
with self.session_maker() as session:
|
||||
session.merge(data)
|
||||
session.commit()
|
||||
return data.id
|
||||
|
||||
def upsert(self, data: RLDBBaseModel) -> str:
|
||||
existed = data.id and data.id != "" and self.get(data.__class__, data.id) is not None
|
||||
with self.session_maker() as session:
|
||||
session.merge(data) if existed else session.add(data)
|
||||
session.commit()
|
||||
return data.id
|
||||
|
||||
def delete(self, data: RLDBBaseModel) -> str:
|
||||
with self.session_maker() as session:
|
||||
session.delete(data)
|
||||
session.commit()
|
||||
return data.id
|
||||
|
||||
def get(self,
|
||||
model: type[RLDBBaseModel],
|
||||
id: str,
|
||||
) -> RLDBBaseModel:
|
||||
with self.session_maker() as session:
|
||||
sel = session.query(model)
|
||||
sel = sel.filter(model.id == id)
|
||||
sel = sel.filter(model.deleted_at.is_(None))
|
||||
result = sel.first()
|
||||
return result
|
||||
|
||||
def query(self,
|
||||
model: type[RLDBBaseModel],
|
||||
limit: int = None,
|
||||
offset: int = None,
|
||||
**filters
|
||||
) -> list[RLDBBaseModel]:
|
||||
results: list[RLDBBaseModel] = []
|
||||
with self.session_maker() as session:
|
||||
sel = session.query(model)
|
||||
sel = sel.filter(model.deleted_at.is_(None))
|
||||
if filters:
|
||||
sel = sel.filter_by(**filters)
|
||||
if limit:
|
||||
sel = sel.limit(limit)
|
||||
if offset:
|
||||
sel = sel.offset(offset)
|
||||
results = sel.all()
|
||||
results.sort(key=lambda x: x.created_at, reverse=True)
|
||||
return results
|
||||
|
||||
_rldb_instance: RelationalDB = None
|
||||
|
||||
def init(type: str = "sqlalchemy", dsn: str = None):
|
||||
global _rldb_instance
|
||||
if type == "sqlalchemy":
|
||||
_rldb_instance = SqlAlchemyDB(dsn)
|
||||
else:
|
||||
raise ValueError(f"RelationalDB type {type} not supported")
|
||||
|
||||
def get_instance() -> RelationalDB:
|
||||
global _rldb_instance
|
||||
return _rldb_instance
|
||||
|
||||
if __name__ == "__main__":
|
||||
class TestModel(RLDBBaseModel):
|
||||
__tablename__ = "test_model"
|
||||
name = Column(String(36), nullable=True)
|
||||
conf = Column(String(96), nullable=True)
|
||||
init("sqlalchemy", dsn="sqlite:///./demo_storage/rldb.db")
|
||||
db = get_instance()
|
||||
|
||||
test_data = TestModel(name="test", conf="test.config")
|
||||
print(f"before insert: {test_data}")
|
||||
ret = db.insert(test_data)
|
||||
print(f"after insert: {test_data}")
|
||||
|
||||
print(f"before update: {test_data}")
|
||||
test_data.conf = "test.config.new"
|
||||
ret = db.update(test_data)
|
||||
print(f"after update: {test_data}")
|
||||
|
||||
test2_data = TestModel(name="test", conf="test2.config")
|
||||
print(f"before upsert: {test2_data}")
|
||||
ret = db.upsert(test2_data)
|
||||
print(f"after upsert: {test2_data}")
|
||||
|
||||
get_data = db.get(TestModel, test_data.id)
|
||||
print(f"get data: {get_data}")
|
||||
|
||||
query_data = db.query(TestModel, name="test")
|
||||
for data in query_data:
|
||||
print(data.id, data.name, data.conf)
|
||||
print(f"query data: {data}")
|
||||
ret = db.delete(data)
|
||||
print(f"delete data.id: {ret}")
|
||||
@@ -1,51 +0,0 @@
|
||||
import logging
|
||||
from typing import Protocol
|
||||
import requests
|
||||
from .config import get_instance as get_config
|
||||
|
||||
|
||||
class SMS(Protocol):
|
||||
def send(self, phone: str, content: str) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class FakeSMS:
|
||||
def __init__(self) -> None:
|
||||
conf = get_config()
|
||||
self.fake_message = conf.get('fake_sms', 'message', fallback="FakeSMS")
|
||||
def send(self, phone: str, content: str) -> bool:
|
||||
logging.info(f"{self.fake_message}: phone={phone}, content={content}")
|
||||
return True
|
||||
|
||||
|
||||
class RealSMS:
|
||||
def __init__(self):
|
||||
conf = get_config()
|
||||
self.webhook_url = conf.get('real_sms', 'webhook_url', fallback=None)
|
||||
self.webhook_token = conf.get('real_sms', 'webhook_token', fallback=None)
|
||||
|
||||
def send(self, phone: str, content: str) -> bool:
|
||||
if not self.webhook_url:
|
||||
return False
|
||||
try:
|
||||
headers = {'Authorization': f'Bearer {self.webhook_token}'} if self.webhook_token else {}
|
||||
data = {'phone': phone, 'content': content}
|
||||
resp = requests.post(self.webhook_url, json=data, headers=headers, timeout=5)
|
||||
return resp.status_code >= 200 and resp.status_code < 300
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
_sms: SMS = None
|
||||
|
||||
|
||||
def init(type: str = 'real'):
|
||||
global _sms
|
||||
if type == 'real':
|
||||
_sms = RealSMS()
|
||||
else:
|
||||
_sms = FakeSMS()
|
||||
|
||||
|
||||
def get_instance() -> SMS:
|
||||
return _sms
|
||||
241
src/utils/vsdb.py
Normal file
241
src/utils/vsdb.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import uuid
|
||||
import chromadb
|
||||
import logging
|
||||
from typing import Protocol
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils import embedding_functions
|
||||
from .config import get_instance as get_config
|
||||
|
||||
class VectorDB(Protocol):
|
||||
|
||||
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
|
||||
"""
|
||||
插入向量到数据库
|
||||
|
||||
Args:
|
||||
vector (list[float]): 向量
|
||||
metadata (dict): 元数据
|
||||
|
||||
Returns:
|
||||
bool: 是否插入成功
|
||||
"""
|
||||
...
|
||||
|
||||
def delete(self, ids: list[str]) -> bool:
|
||||
"""
|
||||
Delete documents from a collection.
|
||||
|
||||
Args:
|
||||
ids: List of IDs to delete
|
||||
|
||||
Returns:
|
||||
bool: Whether deletion was successful
|
||||
"""
|
||||
...
|
||||
|
||||
def query(self, metadatas: dict, ids: list[str], top_k: int = 5) -> list[dict]:
|
||||
"""
|
||||
查询向量数据库
|
||||
|
||||
Args:
|
||||
query_vector (list[float]): 查询向量
|
||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
list[dict]: 查询结果列表
|
||||
"""
|
||||
...
|
||||
|
||||
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
||||
"""
|
||||
搜索向量数据库
|
||||
|
||||
Args:
|
||||
document: Document to search
|
||||
metadatas: Metadata to filter by
|
||||
ids: List of IDs to filter by
|
||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
list[dict]: 查询结果列表
|
||||
"""
|
||||
...
|
||||
|
||||
class ChromaDB:
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Initialize the ChromaDB instance.
|
||||
"""
|
||||
config = get_config()
|
||||
self.embedding_functions = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_base=config.get("voc-engine_embedding", "api_url"),
|
||||
api_key=config.get("voc-engine_embedding", "api_key"),
|
||||
model_name=config.get("voc-engine_embedding", "endpoint"),
|
||||
)
|
||||
persist_directory = config.get("chroma_vsdb", "database_dir", fallback=None)
|
||||
logging.debug(f"persist_directory: {persist_directory}")
|
||||
if persist_directory:
|
||||
self.client = chromadb.PersistentClient(
|
||||
path=persist_directory,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
else:
|
||||
self.client = chromadb.Client(
|
||||
settings=Settings(anonymized_telemetry=False),
|
||||
)
|
||||
self.collection_name = config.get("chroma_vsdb", "collection_name", fallback="peoples")
|
||||
metadata: dict = kwargs.get('collection_metadata', {'hnsw:space': 'cosine'})
|
||||
metadata['hnsw:space'] = metadata.get('hnsw:space', 'cosine')
|
||||
self.collection = self.client.get_or_create_collection(
|
||||
name=self.collection_name,
|
||||
embedding_function=self.embedding_functions,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]:
|
||||
"""
|
||||
Insert documents into a collection.
|
||||
|
||||
Args:
|
||||
metadatas: List of metadata corresponding to each document
|
||||
documents: List of documents to insert
|
||||
ids: Optional list of unique IDs for each document. If None, IDs will be generated.
|
||||
|
||||
Returns:
|
||||
list[str]: List of inserted IDs
|
||||
"""
|
||||
|
||||
if not ids:
|
||||
# Generate unique IDs if not provided
|
||||
ids = [str(uuid.uuid4()) for _ in range(len(documents))]
|
||||
|
||||
self.collection.add(
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
ids=ids
|
||||
)
|
||||
|
||||
return ids
|
||||
|
||||
def delete(self, ids: list[str]) -> bool:
|
||||
"""
|
||||
Delete documents from a collection.
|
||||
|
||||
Args:
|
||||
ids: List of IDs to delete
|
||||
|
||||
Returns:
|
||||
bool: Whether deletion was successful
|
||||
"""
|
||||
try:
|
||||
self.collection.delete(ids)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"Error deleting documents: {e}")
|
||||
return False
|
||||
|
||||
def query(self, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
||||
"""
|
||||
查询向量数据库
|
||||
|
||||
Args:
|
||||
metadatas: Metadata to filter by
|
||||
ids: List of IDs to query
|
||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
list[dict]: 查询结果列表
|
||||
"""
|
||||
results = self.collection.query(
|
||||
query_embeddings=None,
|
||||
query_texts=None,
|
||||
n_results=top_k,
|
||||
where=metadatas,
|
||||
ids=ids,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
formatted_results = []
|
||||
for i in range(len(results['ids'][0])):
|
||||
result = {
|
||||
'id': results['ids'][0][i],
|
||||
'distance': results['distances'][0][i],
|
||||
'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
|
||||
'document': results['documents'][0][i] if results['documents'][0] else ''
|
||||
}
|
||||
formatted_results.append(result)
|
||||
return formatted_results
|
||||
|
||||
def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]:
|
||||
"""
|
||||
搜索向量数据库
|
||||
|
||||
Args:
|
||||
document: Document to search
|
||||
metadatas: Metadata to filter by
|
||||
ids: List of IDs to filter by
|
||||
top_k (int, optional): 返回Top K结果. Defaults to 5.
|
||||
|
||||
Returns:
|
||||
list[dict]: 查询结果列表
|
||||
"""
|
||||
results = self.collection.query(
|
||||
query_embeddings=None,
|
||||
query_texts=[document],
|
||||
n_results=top_k,
|
||||
where=metadatas if metadatas else None,
|
||||
ids=ids,
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
formatted_results = []
|
||||
for i in range(len(results['ids'][0])):
|
||||
logging.info(f"result id: {results['ids'][0][i]}, distance: {results['distances'][0][i]}")
|
||||
result = {
|
||||
'id': results['ids'][0][i],
|
||||
'distance': results['distances'][0][i],
|
||||
'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {},
|
||||
'document': results['documents'][0][i] if results['documents'][0] else ''
|
||||
}
|
||||
formatted_results.append(result)
|
||||
return formatted_results
|
||||
pass
|
||||
|
||||
_vsdb_instance: VectorDB = None
|
||||
|
||||
def init():
|
||||
global _vsdb_instance
|
||||
_vsdb_instance = ChromaDB()
|
||||
|
||||
def get_instance() -> VectorDB:
|
||||
global _vsdb_instance
|
||||
return _vsdb_instance
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
from logger import init as init_logger
|
||||
init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG)
|
||||
|
||||
from config import init as init_config
|
||||
config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini")
|
||||
init_config(config_file)
|
||||
|
||||
init()
|
||||
vsdb = get_instance()
|
||||
metadatas = [
|
||||
{'name': '丽丽'},
|
||||
{'name': '志刚'},
|
||||
{'name': '张三'},
|
||||
{'name': '李四'},
|
||||
]
|
||||
documents = [
|
||||
'姓名: 丽丽, 性别: 女, 年龄: 23, 爱好: 爬山、骑行、攀岩、跳伞、蹦极',
|
||||
"姓名: 志刚, 性别: 男, 年龄: 25, 爱好: 读书、游戏",
|
||||
"姓名: 张三, 性别: 男, 年龄: 30, 爱好: 画画、写作、阅读、逛展、旅行",
|
||||
"姓名: 李四, 性别: 男, 年龄: 35, 爱好: 做饭、美食、旅游"
|
||||
]
|
||||
search_text = '25岁以下的'
|
||||
ids = vsdb.insert(metadatas, documents)
|
||||
results = vsdb.search(search_text, None, None, top_k=4)
|
||||
for result in results:
|
||||
print(result['document'], ' ', result['distance'])
|
||||
@@ -1,60 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
from fastapi import FastAPI, UploadFile, File, APIRouter, Depends
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from web.auth import require_auth
|
||||
from utils import obs
|
||||
from web.schemas import BaseResponse
|
||||
from web.custom import router as custom_router
|
||||
from web.people import router as people_router
|
||||
from web.user import router as user_router
|
||||
from web.recognition import router as recognition_router
|
||||
|
||||
api = FastAPI(title="Single People Management and Searching", version="0.1")
|
||||
api.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["https://localhost:5173", "https://ifu.mamamiyear.site"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
authorized_router = APIRouter(dependencies=[Depends(require_auth)])
|
||||
|
||||
@api.post("/api/ping")
|
||||
async def ping():
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
@authorized_router.post("/api/upload/image")
|
||||
async def post_upload_image(image: UploadFile = File(...)):
|
||||
# 实现上传图片的处理
|
||||
# 保存上传的图片文件
|
||||
# 生成唯一的文件名
|
||||
file_extension = os.path.splitext(image.filename)[1]
|
||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# 保存文件到对象存储
|
||||
file_path = f"uploads/{unique_filename}"
|
||||
obs_util = obs.get_instance()
|
||||
await run_in_threadpool(obs_util.put, file_path, await image.read())
|
||||
|
||||
# 获取对象存储外链
|
||||
obs_url = obs_util.get_link(file_path)
|
||||
return BaseResponse(error_code=0, error_info="success", data=obs_url)
|
||||
|
||||
|
||||
api.include_router(authorized_router)
|
||||
|
||||
# Register custom router
|
||||
api.include_router(custom_router, dependencies=[Depends(require_auth)])
|
||||
|
||||
# Register people router
|
||||
api.include_router(people_router, dependencies=[Depends(require_auth)])
|
||||
|
||||
# Register user router
|
||||
api.include_router(user_router)
|
||||
|
||||
# Register recognition router
|
||||
api.include_router(recognition_router, dependencies=[Depends(require_auth)])
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
from typing import Optional
|
||||
from fastapi import Cookie, HTTPException, Request
|
||||
from utils import rldb as rldb_util
|
||||
from models.user import User, UserTokenRLDBModel, UserRLDBModel
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def require_auth(request: Request, token: Optional[str] = Cookie(None)):
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="unauthorized")
|
||||
db = rldb_util.get_instance()
|
||||
tokens = db.query(UserTokenRLDBModel, token=token, limit=1)
|
||||
if not tokens:
|
||||
raise HTTPException(status_code=401, detail="unauthorized")
|
||||
t = tokens[0]
|
||||
if getattr(t, 'revoked', False):
|
||||
raise HTTPException(status_code=401, detail="unauthorized")
|
||||
if getattr(t, 'expired_at', None) and t.expired_at < datetime.now():
|
||||
raise HTTPException(status_code=401, detail="unauthorized")
|
||||
user_orm = db.get(UserRLDBModel, t.user_id)
|
||||
if not user_orm:
|
||||
raise HTTPException(status_code=401, detail="unauthorized")
|
||||
user = User.from_rldb_model(user_orm)
|
||||
request.state.user_id = user.id
|
||||
request.state.user_nickname = user.nickname
|
||||
request.state.user_email = user.email
|
||||
request.state.user_phone = user.phone
|
||||
request.state.token = token
|
||||
@@ -1,151 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, Request, Query, UploadFile, File
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel
|
||||
from models.custom import Custom
|
||||
from services.custom import get_instance as get_custom_service
|
||||
from utils.error import ErrorCode
|
||||
from utils import obs
|
||||
from web.schemas import BaseResponse
|
||||
|
||||
router = APIRouter(tags=["custom"])
|
||||
|
||||
class PostCustomRequest(BaseModel):
|
||||
custom: dict
|
||||
|
||||
@router.post("/api/custom")
|
||||
def create_custom(request: Request, post_custom_request: PostCustomRequest):
|
||||
logging.debug(f"post_custom_request: {post_custom_request}")
|
||||
custom = Custom.from_dict(post_custom_request.custom)
|
||||
|
||||
# Validate custom data
|
||||
err = custom.validate()
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
|
||||
custom.user_id = getattr(request.state, 'user_id', '')
|
||||
|
||||
service = get_custom_service()
|
||||
custom.id, error = service.save(custom)
|
||||
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success", data=custom.id)
|
||||
|
||||
@router.put("/api/custom/{custom_id}")
|
||||
def update_custom(request: Request, custom_id: str, post_custom_request: PostCustomRequest):
|
||||
logging.debug(f"post_custom_request: {post_custom_request}")
|
||||
custom = Custom.from_dict(post_custom_request.custom)
|
||||
custom.id = custom_id
|
||||
|
||||
# Validate custom data
|
||||
err = custom.validate()
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
|
||||
service = get_custom_service()
|
||||
# Check permission
|
||||
res, error = service.get(custom_id)
|
||||
if not error.success or not res:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
|
||||
custom.user_id = res.user_id # Ensure user_id is not changed or is set correctly
|
||||
|
||||
_, error = service.save(custom)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
@router.delete("/api/custom/{custom_id}")
|
||||
def delete_custom(request: Request, custom_id: str):
|
||||
service = get_custom_service()
|
||||
res, error = service.get(custom_id)
|
||||
if not error.success or not res:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
error = service.delete(custom_id)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success", data=custom_id)
|
||||
|
||||
@router.get("/api/customs")
|
||||
def get_customs(request: Request, limit: int = Query(10, ge=1, le=1000), offset: int = Query(0, ge=0)):
|
||||
service = get_custom_service()
|
||||
res, error = service.list({'user_id': getattr(request.state, 'user_id', '')}, limit=limit, offset=offset)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
# custom对象转换为字典
|
||||
customs = [custom.to_dict() for custom in res]
|
||||
return BaseResponse(error_code=0, error_info="success", data=customs)
|
||||
|
||||
@router.get("/api/custom/{custom_id}")
|
||||
def get_custom(request: Request, custom_id: str):
|
||||
service = get_custom_service()
|
||||
res, error = service.get(custom_id)
|
||||
if not error.success or not res:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
return BaseResponse(error_code=0, error_info="success", data=res.to_dict())
|
||||
|
||||
|
||||
@router.post("/api/custom/{custom_id}/image")
|
||||
async def post_custom_image(request: Request, custom_id: str, image: UploadFile = File(...)):
|
||||
# 检查 custom id 是否存在
|
||||
service = get_custom_service()
|
||||
custom, err = service.get(custom_id)
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
|
||||
if custom.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
|
||||
# 实现上传图片的处理
|
||||
# 保存上传的图片文件
|
||||
# 生成唯一的文件名
|
||||
file_extension = os.path.splitext(image.filename)[1]
|
||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# 保存文件到对象存储
|
||||
file_path = f"customs/{custom_id}/images/{unique_filename}"
|
||||
obs_util = obs.get_instance()
|
||||
await run_in_threadpool(obs_util.put, file_path, await image.read())
|
||||
|
||||
# 获取对象存储外链
|
||||
obs_url = obs_util.get_link(file_path)
|
||||
logging.info(f"obs_url: {obs_url}")
|
||||
|
||||
return BaseResponse(error_code=0, error_info="success", data=obs_url)
|
||||
|
||||
|
||||
@router.delete("/api/custom/{custom_id}/image")
|
||||
async def delete_custom_image(request: Request, custom_id: str, image_url: str):
|
||||
# 检查 custom id 是否存在
|
||||
service = get_custom_service()
|
||||
custom, err = service.get(custom_id)
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
|
||||
if custom.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
|
||||
# 检查 image_url 是否是该 custom 名下的图片链接
|
||||
obs_util = obs.get_instance()
|
||||
obs_path, err = obs_util.get_obs_path_by_link(image_url)
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
if not obs_path.startswith(f"customs/{custom_id}/images/"):
|
||||
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {custom_id} 名下的图片链接")
|
||||
|
||||
# 实现删除图片的处理
|
||||
# 删除对象存储中的文件
|
||||
err = obs_util.delete_by_link(image_url)
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
@@ -1,192 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Request, UploadFile, File, Query
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.people import get_instance as get_people_service
|
||||
from models.people import People
|
||||
from utils import obs
|
||||
from utils.error import ErrorCode
|
||||
from web.schemas import BaseResponse
|
||||
|
||||
router = APIRouter(tags=["people"])
|
||||
|
||||
|
||||
class PostPeopleRequest(BaseModel):
|
||||
people: dict
|
||||
|
||||
@router.post("/api/people")
|
||||
async def post_people(request: Request, post_people_request: PostPeopleRequest):
|
||||
logging.debug(f"post_people_request: {post_people_request}")
|
||||
people = People.from_dict(post_people_request.people)
|
||||
people.user_id = getattr(request.state, 'user_id', '')
|
||||
service = get_people_service()
|
||||
people.id, error = service.save(people)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success", data=people.id)
|
||||
|
||||
@router.put("/api/people/{people_id}")
|
||||
async def update_people(request: Request, people_id: str, post_people_request: PostPeopleRequest):
|
||||
logging.debug(f"post_people_request: {post_people_request}")
|
||||
people = People.from_dict(post_people_request.people)
|
||||
people.id = people_id
|
||||
service = get_people_service()
|
||||
res, error = service.get(people_id)
|
||||
if not error.success or not res:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
people.user_id = res.user_id
|
||||
_, error = service.save(people)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
@router.delete("/api/people/{people_id}")
|
||||
async def delete_people(request: Request, people_id: str):
|
||||
service = get_people_service()
|
||||
res, err = service.get(people_id)
|
||||
if not err.success or not res:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
error = service.delete(people_id)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
class GetPeopleRequest(BaseModel):
|
||||
query: Optional[str] = None
|
||||
conds: Optional[dict] = None
|
||||
top_k: int = 5
|
||||
|
||||
@router.get("/api/peoples")
|
||||
async def get_peoples(
|
||||
request: Request,
|
||||
name: Optional[str] = Query(None, description="姓名"),
|
||||
gender: Optional[str] = Query(None, description="性别"),
|
||||
age: Optional[int] = Query(None, description="年龄"),
|
||||
height: Optional[int] = Query(None, description="身高"),
|
||||
marital_status: Optional[str] = Query(None, description="婚姻状态"),
|
||||
limit: int = Query(10, description="分页大小"),
|
||||
offset: int = Query(0, description="分页偏移量"),
|
||||
):
|
||||
|
||||
# 解析查询参数为字典
|
||||
conds = {}
|
||||
conds["user_id"] = getattr(request.state, 'user_id', '')
|
||||
if name:
|
||||
conds["name"] = name
|
||||
if gender:
|
||||
conds["gender"] = gender
|
||||
if age:
|
||||
conds["age"] = age
|
||||
if height:
|
||||
conds["height"] = height
|
||||
if marital_status:
|
||||
conds["marital_status"] = marital_status
|
||||
|
||||
logging.info(f"conds: , limit: {limit}, offset: {offset}")
|
||||
|
||||
results = []
|
||||
service = get_people_service()
|
||||
results, error = service.list(conds, limit=limit, offset=offset)
|
||||
logging.info(f"query results: {results}")
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
peoples = [people.to_dict() for people in results]
|
||||
return BaseResponse(error_code=0, error_info="success", data=peoples)
|
||||
|
||||
|
||||
class RemarkRequest(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
@router.post("/api/people/{people_id}/remark")
|
||||
async def post_people_remark(request: Request, people_id: str, body: RemarkRequest):
|
||||
service = get_people_service()
|
||||
res, err = service.get(people_id)
|
||||
if not err.success or not res:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
error = service.save_remark(people_id, body.content)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
|
||||
@router.delete("/api/people/{people_id}/remark")
|
||||
async def delete_people_remark(request: Request, people_id: str):
|
||||
service = get_people_service()
|
||||
res, err = service.get(people_id)
|
||||
if not err.success or not res:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
if res.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
error = service.delete_remark(people_id)
|
||||
if not error.success:
|
||||
return BaseResponse(error_code=error.code, error_info=error.info)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
|
||||
@router.post("/api/people/{people_id}/image")
|
||||
async def post_people_image(request: Request, people_id: str, image: UploadFile = File(...)):
|
||||
|
||||
# 检查 people id 是否存在
|
||||
service = get_people_service()
|
||||
people, err = service.get(people_id)
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
|
||||
if people.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
|
||||
# 实现上传图片的处理
|
||||
# 保存上传的图片文件
|
||||
# 生成唯一的文件名
|
||||
file_extension = os.path.splitext(image.filename)[1]
|
||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# 保存文件到对象存储
|
||||
file_path = f"peoples/{people_id}/images/{unique_filename}"
|
||||
obs_util = obs.get_instance()
|
||||
await run_in_threadpool(obs_util.put, file_path, await image.read())
|
||||
|
||||
# 获取对象存储外链
|
||||
obs_url = obs_util.get_link(file_path)
|
||||
logging.info(f"obs_url: {obs_url}")
|
||||
|
||||
return BaseResponse(error_code=0, error_info="success", data=obs_url)
|
||||
|
||||
|
||||
@router.delete("/api/people/{people_id}/image")
|
||||
async def delete_people_image(request: Request, people_id: str, image_url: str):
|
||||
# 检查 people id 是否存在
|
||||
service = get_people_service()
|
||||
people, err = service.get(people_id)
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
|
||||
if people.user_id != getattr(request.state, 'user_id', ''):
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied")
|
||||
|
||||
# 检查 image_url 是否是该 people 名下的图片链接
|
||||
obs_util = obs.get_instance()
|
||||
obs_path, err = obs_util.get_obs_path_by_link(image_url)
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
if not obs_path.startswith(f"peoples/{people_id}/images/"):
|
||||
return BaseResponse(error_code=ErrorCode.OBS_INPUT_ERROR, error_info=f"文件 {image_url} 不是 {people_id} 名下的图片链接")
|
||||
|
||||
# 实现删除图片的处理
|
||||
# 删除对象存储中的文件
|
||||
err = obs_util.delete_by_link(image_url)
|
||||
if not err.success:
|
||||
return BaseResponse(error_code=err.code, error_info=err.info)
|
||||
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
@@ -1,91 +0,0 @@
|
||||
|
||||
import os
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, UploadFile, File
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from models.people import People
|
||||
from models.custom import Custom
|
||||
from agents.extract_people_agent import ExtractPeopleAgent
|
||||
from agents.extract_custom_agent import ExtractCustomAgent
|
||||
from utils import obs, ocr
|
||||
from web.schemas import BaseResponse
|
||||
from utils.error import ErrorCode
|
||||
|
||||
router = APIRouter(tags=["recognition"])
|
||||
|
||||
def extract_people(text: str, cover_link: str = None) -> People:
|
||||
extra_agent = ExtractPeopleAgent()
|
||||
people = extra_agent.extract_people_info(text)
|
||||
if people:
|
||||
people.cover = cover_link
|
||||
logging.info(f"people: {people}")
|
||||
return people
|
||||
|
||||
def extract_custom(text: str, image_link: str = None) -> Custom:
|
||||
extra_agent = ExtractCustomAgent()
|
||||
custom = extra_agent.extract_custom_info(text)
|
||||
if custom:
|
||||
if image_link:
|
||||
custom.images = [image_link]
|
||||
logging.info(f"custom: {custom}")
|
||||
return custom
|
||||
|
||||
class PostInputRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
@router.post("/api/recognition/{model}/input")
|
||||
async def post_recognition_input(model: str, request: PostInputRequest):
|
||||
if model == "people":
|
||||
result = await run_in_threadpool(extract_people, request.text)
|
||||
elif model == "custom":
|
||||
result = await run_in_threadpool(extract_custom, request.text)
|
||||
else:
|
||||
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
|
||||
|
||||
if result is None:
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
|
||||
|
||||
resp = BaseResponse(error_code=0, error_info="success")
|
||||
resp.data = result.to_dict()
|
||||
return resp
|
||||
|
||||
@router.post("/api/recognition/{model}/image")
|
||||
async def post_recognition_image(model: str, image: UploadFile = File(...)):
|
||||
if model not in ["people", "custom"]:
|
||||
return BaseResponse(error_code=ErrorCode.INVALID_PARAMS.value, error_info=f"Unknown model: {model}")
|
||||
|
||||
# 实现上传图片的处理
|
||||
# 保存上传的图片文件
|
||||
# 生成唯一的文件名
|
||||
file_extension = os.path.splitext(image.filename)[1]
|
||||
unique_filename = f"{uuid.uuid4()}{file_extension}"
|
||||
|
||||
# 保存文件到对象存储
|
||||
file_path = f"uploads/{model}/{unique_filename}"
|
||||
obs_util = obs.get_instance()
|
||||
await run_in_threadpool(obs_util.put, file_path, await image.read())
|
||||
|
||||
# 获取对象存储外链
|
||||
obs_url = obs_util.get_link(file_path)
|
||||
logging.info(f"obs_url: {obs_url}")
|
||||
|
||||
# 调用OCR处理图片
|
||||
ocr_util = ocr.get_instance()
|
||||
ocr_result = await run_in_threadpool(ocr_util.recognize_image_text, obs_url)
|
||||
logging.info(f"ocr_result: {ocr_result}")
|
||||
|
||||
if model == "people":
|
||||
result = await run_in_threadpool(extract_people, ocr_result, obs_url)
|
||||
elif model == "custom":
|
||||
result = await run_in_threadpool(extract_custom, ocr_result, obs_url)
|
||||
|
||||
if result is None:
|
||||
return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="Extraction failed")
|
||||
|
||||
resp = BaseResponse(error_code=0, error_info="success")
|
||||
resp.data = result.to_dict()
|
||||
return resp
|
||||
@@ -1,7 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
class BaseResponse(BaseModel):
|
||||
error_code: int
|
||||
error_info: str
|
||||
data: Optional[Any] = None
|
||||
223
src/web/user.py
223
src/web/user.py
@@ -1,223 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Literal
|
||||
from fastapi import APIRouter, Depends, Request, HTTPException, Response, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
from services.user import get_instance as get_user_service
|
||||
from web.auth import require_auth
|
||||
from utils import obs
|
||||
from utils.config import get_instance as get_config
|
||||
from web.schemas import BaseResponse
|
||||
|
||||
router = APIRouter(tags=["user"])
|
||||
|
||||
class SendCodeRequest(BaseModel):
|
||||
target_type: str
|
||||
target: str
|
||||
scene: Literal['register', 'update']
|
||||
# scene: Literal['register', 'login']
|
||||
|
||||
|
||||
@router.post("/api/user/send_code")
|
||||
async def send_user_code(request: SendCodeRequest):
|
||||
service = get_user_service()
|
||||
err = service.send_code(request.target_type, request.target, request.scene)
|
||||
if not err.success:
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
nickname: Optional[str] = None
|
||||
avatar_link: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
password: str
|
||||
code: str
|
||||
|
||||
@router.post("/api/user")
|
||||
async def user_register(request: RegisterRequest):
|
||||
service = get_user_service()
|
||||
from models.user import User
|
||||
u = User(
|
||||
nickname=request.nickname or "",
|
||||
avatar_link=request.avatar_link or "",
|
||||
email=request.email or "",
|
||||
phone=request.phone or "",
|
||||
password_hash=request.password,
|
||||
)
|
||||
uid, err = service.register(u, request.code)
|
||||
if not err.success:
|
||||
logging.error(f"register failed: {err}")
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
return BaseResponse(error_code=0, error_info="success", data=uid)
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
password: str
|
||||
|
||||
@router.post("/api/user/login")
|
||||
async def user_login(request: LoginRequest, response: Response):
|
||||
service = get_user_service()
|
||||
data, err = service.login(request.email, request.phone, request.password)
|
||||
if not err.success:
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
conf = get_config()
|
||||
ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30)
|
||||
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
||||
cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False)
|
||||
cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None)
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=data.get('token', ''),
|
||||
max_age=ttl_days * 24 * 3600,
|
||||
httponly=True,
|
||||
secure=cookie_secure,
|
||||
samesite=cookie_samesite,
|
||||
domain=cookie_domain,
|
||||
path="/",
|
||||
)
|
||||
return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')})
|
||||
|
||||
|
||||
@router.delete("/api/user/me/login", dependencies=[Depends(require_auth)])
|
||||
async def user_logout(response: Response, request: Request):
|
||||
service = get_user_service()
|
||||
err = service.logout(getattr(request.state, 'token', None))
|
||||
if not err.success:
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
conf = get_config()
|
||||
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
||||
response.delete_cookie(key="token", domain=cookie_domain, path="/")
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
|
||||
@router.delete("/api/user/me", dependencies=[Depends(require_auth)])
|
||||
async def user_delete(response: Response, request: Request):
|
||||
service = get_user_service()
|
||||
err = service.delete_user(getattr(request.state, 'user_id', None))
|
||||
if not err.success:
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
conf = get_config()
|
||||
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
|
||||
response.delete_cookie(key="token", domain=cookie_domain, path="/")
|
||||
return BaseResponse(error_code=0, error_info="success")
|
||||
|
||||
@router.get("/api/user/me", dependencies=[Depends(require_auth)])
|
||||
async def user_me(request: Request):
|
||||
service = get_user_service()
|
||||
user, err = service.get(getattr(request.state, 'user_id', None))
|
||||
if not err.success or not user:
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
data = {
|
||||
'nickname': user.nickname,
|
||||
'avatar_link': user.avatar_link,
|
||||
'phone': user.phone,
|
||||
'email': user.email,
|
||||
}
|
||||
return BaseResponse(error_code=0, error_info="success", data=data)
|
||||
|
||||
|
||||
class UpdateMeRequest(BaseModel):
|
||||
nickname: Optional[str] = None
|
||||
avatar_link: Optional[str] = None
|
||||
phone: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
|
||||
|
||||
@router.put("/api/user/me", dependencies=[Depends(require_auth)])
|
||||
async def update_user_me(request: Request, body: UpdateMeRequest):
|
||||
service = get_user_service()
|
||||
user, err = service.update_profile(
|
||||
getattr(request.state, 'user_id', None),
|
||||
nickname=body.nickname,
|
||||
avatar_link=body.avatar_link,
|
||||
phone=body.phone,
|
||||
email=body.email,
|
||||
)
|
||||
if not err.success:
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
data = {
|
||||
'nickname': user.nickname,
|
||||
'avatar_link': user.avatar_link,
|
||||
'phone': user.phone,
|
||||
'email': user.email,
|
||||
}
|
||||
return BaseResponse(error_code=0, error_info="success", data=data)
|
||||
|
||||
|
||||
@router.put("/api/user/me/avatar", dependencies=[Depends(require_auth)])
|
||||
async def upload_avatar(request: Request, avatar: UploadFile = File(...)):
|
||||
user_id = getattr(request.state, 'user_id', None)
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="unauthorized")
|
||||
|
||||
file_extension = os.path.splitext(avatar.filename)[1]
|
||||
timestamp = int(time.time())
|
||||
avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}"
|
||||
|
||||
try:
|
||||
obs_util = obs.get_instance()
|
||||
obs_util.Put(avatar_path, await avatar.read())
|
||||
avatar_url = obs_util.Link(avatar_path)
|
||||
|
||||
user_service = get_user_service()
|
||||
_, err = user_service.update_profile(user_id, avatar_link=avatar_url)
|
||||
if not err.success:
|
||||
raise HTTPException(status_code=500, detail=err.info)
|
||||
|
||||
return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url})
|
||||
except Exception as e:
|
||||
logging.error(f"upload avatar failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="upload avatar failed")
|
||||
|
||||
|
||||
class UpdatePhoneRequest(BaseModel):
|
||||
phone: str
|
||||
code: str
|
||||
|
||||
|
||||
@router.put("/api/user/me/phone", dependencies=[Depends(require_auth)])
|
||||
async def update_user_phone(request: Request, body: UpdatePhoneRequest):
|
||||
service = get_user_service()
|
||||
user, err = service.update_phone_with_code(
|
||||
getattr(request.state, 'user_id', None),
|
||||
body.phone,
|
||||
body.code,
|
||||
)
|
||||
if not err.success:
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
data = {
|
||||
'nickname': user.nickname,
|
||||
'avatar_link': user.avatar_link,
|
||||
'phone': user.phone,
|
||||
'email': user.email,
|
||||
}
|
||||
return BaseResponse(error_code=0, error_info="success", data=data)
|
||||
|
||||
|
||||
class UpdateEmailRequest(BaseModel):
|
||||
email: str
|
||||
code: str
|
||||
|
||||
|
||||
@router.put("/api/user/me/email", dependencies=[Depends(require_auth)])
|
||||
async def update_user_email(request: Request, body: UpdateEmailRequest):
|
||||
service = get_user_service()
|
||||
user, err = service.update_email_with_code(
|
||||
getattr(request.state, 'user_id', None),
|
||||
body.email,
|
||||
body.code,
|
||||
)
|
||||
if not err.success:
|
||||
raise HTTPException(status_code=400, detail=err.info)
|
||||
data = {
|
||||
'nickname': user.nickname,
|
||||
'avatar_link': user.avatar_link,
|
||||
'phone': user.phone,
|
||||
'email': user.email,
|
||||
}
|
||||
return BaseResponse(error_code=0, error_info="success", data=data)
|
||||
16
test/test_logger.py
Normal file
16
test/test_logger.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../src'))
|
||||
|
||||
from utils.logger import init
|
||||
import logging
|
||||
|
||||
# 初始化日志
|
||||
init()
|
||||
|
||||
# 测试不同级别的日志
|
||||
logging.debug("这是一条调试信息")
|
||||
logging.info("这是一条普通信息")
|
||||
logging.warning("这是一条警告信息")
|
||||
logging.error("这是一条错误信息")
|
||||
logging.critical("这是一条严重错误信息")
|
||||
Reference in New Issue
Block a user