diff --git a/src/models/people.py b/src/models/people.py index 68bbadc..e6c1d27 100644 --- a/src/models/people.py +++ b/src/models/people.py @@ -12,6 +12,7 @@ from utils.error import ErrorCode, error class PeopleRLDBModel(RLDBBaseModel): __tablename__ = 'peoples' id = Column(String(36), primary_key=True) + user_id = Column(String(36), index=True) name = Column(String(255), index=True) contact = Column(String(255), index=True) gender = Column(String(10)) @@ -61,6 +62,8 @@ class Comment: class People: # 数据库 ID id: str + # 所属用户 ID + user_id: str # 姓名 name: str # 联系人 @@ -87,6 +90,7 @@ class People: def __init__(self, **kwargs): # 初始化所有属性,从kwargs中获取值,如果不存在则设置默认值 self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else '' + self.user_id = kwargs.get('user_id', '') if kwargs.get('user_id', '') is not None else '' self.name = kwargs.get('name', '') if kwargs.get('name', '') is not None else '' self.contact = kwargs.get('contact', '') if kwargs.get('contact', '') is not None else '' self.gender = kwargs.get('gender', '') if kwargs.get('gender', '') is not None else '' @@ -121,6 +125,7 @@ class People: # 将关系数据库模型转换为对象 return cls( id=data.id, + user_id=data.user_id, name=data.name, contact=data.contact, gender=data.gender, @@ -138,6 +143,7 @@ class People: # 将对象转换为字典格式 return { 'id': self.id, + 'user_id': self.user_id, 'name': self.name, 'contact': self.contact, 'gender': self.gender, @@ -155,6 +161,7 @@ class People: # 将对象转换为关系数据库模型 return PeopleRLDBModel( id=self.id, + user_id=self.user_id, name=self.name, contact=self.contact, gender=self.gender, diff --git a/src/utils/rldb.py b/src/utils/rldb.py index 6371016..9b4f5b8 100644 --- a/src/utils/rldb.py +++ b/src/utils/rldb.py @@ -1,4 +1,5 @@ +from re import S from typing import Protocol import uuid from sqlalchemy import Column, DateTime, String, create_engine, func diff --git a/src/web/api.py b/src/web/api.py index d329661..3242637 100644 --- a/src/web/api.py +++ b/src/web/api.py @@ -13,6 +13,7 @@ from models.people import People from agents.extract_people_agent import ExtractPeopleAgent from utils import obs, ocr from utils.config import get_instance as get_config +from utils.error import ErrorCode api = FastAPI(title="Single People Management and Searching", version="0.1") api.add_middleware( @@ -84,18 +85,19 @@ def extract_people(text: str, cover_link: str = None) -> People: class PostPeopleRequest(BaseModel): people: dict -@api.post("/api/people") -async def post_people(post_people_request: PostPeopleRequest): +@authorized_router.post("/api/people") +async def post_people(request: Request, post_people_request: PostPeopleRequest): logging.debug(f"post_people_request: {post_people_request}") people = People.from_dict(post_people_request.people) + people.user_id = getattr(request.state, 'user_id', '') service = get_people_service() people.id, error = service.save(people) if not error.success: return BaseResponse(error_code=error.code, error_info=error.info) return BaseResponse(error_code=0, error_info="success", data=people.id) -@api.put("/api/people/{people_id}") -async def update_people(people_id: str, post_people_request: PostPeopleRequest): +@authorized_router.put("/api/people/{people_id}") +async def update_people(request: Request, people_id: str, post_people_request: PostPeopleRequest): logging.debug(f"post_people_request: {post_people_request}") people = People.from_dict(post_people_request.people) people.id = people_id @@ -103,14 +105,22 @@ async def update_people(people_id: str, post_people_request: PostPeopleRequest): res, error = service.get(people_id) if not error.success or not res: return BaseResponse(error_code=error.code, error_info=error.info) + if res.user_id != getattr(request.state, 'user_id', ''): + return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied") + people.user_id = res.user_id _, error = service.save(people) if not error.success: return BaseResponse(error_code=error.code, error_info=error.info) return BaseResponse(error_code=0, error_info="success") -@api.delete("/api/people/{people_id}") -async def delete_people(people_id: str): +@authorized_router.delete("/api/people/{people_id}") +async def delete_people(request: Request, people_id: str): service = get_people_service() + res, err = service.get(people_id) + if not err.success or not res: + return BaseResponse(error_code=err.code, error_info=err.info) + if res.user_id != getattr(request.state, 'user_id', ''): + return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied") error = service.delete(people_id) if not error.success: return BaseResponse(error_code=error.code, error_info=error.info) @@ -121,8 +131,9 @@ class GetPeopleRequest(BaseModel): conds: Optional[dict] = None top_k: int = 5 -@api.get("/api/peoples") +@authorized_router.get("/api/peoples") async def get_peoples( + request: Request, name: Optional[str] = Query(None, description="姓名"), gender: Optional[str] = Query(None, description="性别"), age: Optional[int] = Query(None, description="年龄"), @@ -134,6 +145,7 @@ async def get_peoples( # 解析查询参数为字典 conds = {} + conds["user_id"] = getattr(request.state, 'user_id', '') if name: conds["name"] = name if gender: @@ -161,18 +173,28 @@ class RemarkRequest(BaseModel): content: str -@api.post("/api/people/{people_id}/remark") -async def post_remark(people_id: str, request: RemarkRequest): +@authorized_router.post("/api/people/{people_id}/remark") +async def post_remark(request: Request, people_id: str, body: RemarkRequest): service = get_people_service() - error = service.save_remark(people_id, request.content) + res, err = service.get(people_id) + if not err.success or not res: + return BaseResponse(error_code=err.code, error_info=err.info) + if res.user_id != getattr(request.state, 'user_id', ''): + return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied") + error = service.save_remark(people_id, body.content) if not error.success: return BaseResponse(error_code=error.code, error_info=error.info) return BaseResponse(error_code=0, error_info="success") -@api.delete("/api/people/{people_id}/remark") -async def delete_remark(people_id: str): +@authorized_router.delete("/api/people/{people_id}/remark") +async def delete_remark(request: Request, people_id: str): service = get_people_service() + res, err = service.get(people_id) + if not err.success or not res: + return BaseResponse(error_code=err.code, error_info=err.info) + if res.user_id != getattr(request.state, 'user_id', ''): + return BaseResponse(error_code=ErrorCode.MODEL_ERROR.value, error_info="permission denied") error = service.delete_remark(people_id) if not error.success: return BaseResponse(error_code=error.code, error_info=error.info)