diff --git a/src/main.py b/src/main.py index c0b1460..365d18f 100644 --- a/src/main.py +++ b/src/main.py @@ -4,7 +4,8 @@ import os import argparse import uvicorn from services import people as people_service -from utils import config, logger, obs, ocr, rldb +from services import user as user_service +from utils import config, logger, obs, ocr, rldb, sms, mailer from web.api import api @@ -16,16 +17,19 @@ def main(): args = parser.parse_args() config.init(args.config) + conf = config.get_instance() + logger.init() - rldb.init() - ocr.init() obs.init() - + mailer.init(conf.get('mailer', 'type', fallback='real')) + sms.init(conf.get('sms', 'type', fallback='real')) + people_service.init() + user_service.init() + - conf = config.get_instance() host = conf.get('web_service', 'server_host', fallback='0.0.0.0') port = conf.getint('web_service', 'server_port', fallback=8099) diff --git a/src/models/user.py b/src/models/user.py new file mode 100644 index 0000000..bf840a6 --- /dev/null +++ b/src/models/user.py @@ -0,0 +1,184 @@ +from typing import Optional +import json +from datetime import datetime, timedelta +from sqlalchemy import Column, String, Text, DateTime, Integer, Boolean, func, UniqueConstraint +from utils.rldb import RLDBBaseModel +from utils.error import ErrorCode, error + + +class UserRLDBModel(RLDBBaseModel): + __tablename__ = 'users' + id = Column(String(36), primary_key=True) + nickname = Column(String(255)) + avatar_link = Column(String(255)) + email = Column(String(127), unique=True, index=True) + phone = Column(String(32), unique=True, index=True) + password_hash = Column(String(255)) + created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) + + +class VerificationCodeRLDBModel(RLDBBaseModel): + __tablename__ = 'verification_codes' + id = Column(String(36), primary_key=True) + target_type = Column(String(16)) + target = Column(String(255), index=True) + code = Column(String(16)) + scene = Column(String(32)) + expires_at = Column(DateTime(timezone=True)) + used_at = Column(DateTime(timezone=True), nullable=True) + + +class UserTokenRLDBModel(RLDBBaseModel): + __tablename__ = 'user_tokens' + id = Column(String(36), primary_key=True) + user_id = Column(String(36), index=True) + token = Column(Text) + expired_at = Column(DateTime(timezone=True)) + revoked = Column(Boolean, default=False) + + +class User: + id: str + nickname: str + avatar_link: str + email: str + phone: str + password_hash: str + created_at: datetime = None + + def __init__(self, **kwargs): + self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else '' + self.nickname = kwargs.get('nickname', '') if kwargs.get('nickname', '') is not None else '' + self.avatar_link = kwargs.get('avatar_link', '') if kwargs.get('avatar_link', '') is not None else '' + self.email = kwargs.get('email', '') if kwargs.get('email', '') is not None else '' + self.phone = kwargs.get('phone', '') if kwargs.get('phone', '') is not None else '' + self.password_hash = kwargs.get('password_hash', '') if kwargs.get('password_hash', '') is not None else '' + self.created_at = kwargs.get('created_at', None) + + def __str__(self) -> str: + return (f"User(id={self.id}, nickname={self.nickname}, avatar_link={self.avatar_link}, " + f"email={self.email}, phone={self.phone}, created_at={self.created_at})") + + @classmethod + def from_dict(cls, data: dict): + if 'updated_at' in data: + del data['updated_at'] + if 'deleted_at' in data: + del data['deleted_at'] + return cls(**data) + + @classmethod + def from_rldb_model(cls, data: UserRLDBModel): + return cls( + id=data.id, + nickname=data.nickname, + avatar_link=data.avatar_link, + email=data.email, + phone=data.phone, + password_hash=data.password_hash, + created_at=data.created_at, + ) + + def to_dict(self) -> dict: + return { + 'id': self.id, + 'nickname': self.nickname, + 'avatar_link': self.avatar_link, + 'email': self.email, + 'phone': self.phone, + 'created_at': int(self.created_at.timestamp()) if self.created_at else None, + } + + def to_rldb_model(self) -> UserRLDBModel: + return UserRLDBModel( + id=self.id, + nickname=self.nickname, + avatar_link=self.avatar_link, + email=self.email, + phone=self.phone, + password_hash=self.password_hash, + ) + + def validate(self) -> error: + err = error(ErrorCode.SUCCESS, "") + if not self.email and not self.phone: + return error(ErrorCode.MODEL_ERROR, "email or phone required") + return err + + +class VerificationCode: + id: str + target_type: str + target: str + code: str + scene: str + expires_at: datetime + used_at: Optional[datetime] = None + + def __init__(self, **kwargs): + self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else '' + self.target_type = kwargs.get('target_type', '') + self.target = kwargs.get('target', '') + self.code = kwargs.get('code', '') + self.scene = kwargs.get('scene', '') + self.expires_at = kwargs.get('expires_at') + self.used_at = kwargs.get('used_at', None) + + @classmethod + def from_rldb_model(cls, data: VerificationCodeRLDBModel): + return cls( + id=data.id, + target_type=data.target_type, + target=data.target, + code=data.code, + scene=data.scene, + expires_at=data.expires_at, + used_at=data.used_at, + ) + + def to_rldb_model(self) -> VerificationCodeRLDBModel: + return VerificationCodeRLDBModel( + id=self.id, + target_type=self.target_type, + target=self.target, + code=self.code, + scene=self.scene, + expires_at=self.expires_at, + used_at=self.used_at, + ) + + +class UserToken: + id: str + user_id: str + token: str + expired_at: datetime + revoked: bool + + def __init__(self, **kwargs): + self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else '' + self.user_id = kwargs.get('user_id', '') + self.token = kwargs.get('token', '') + self.expired_at = kwargs.get('expired_at') + self.revoked = kwargs.get('revoked', False) + + @classmethod + def from_rldb_model(cls, data: UserTokenRLDBModel): + return cls( + id=data.id, + user_id=data.user_id, + token=data.token, + expired_at=data.expired_at, + revoked=data.revoked, + ) + + def to_rldb_model(self) -> UserTokenRLDBModel: + return UserTokenRLDBModel( + id=self.id, + user_id=self.user_id, + token=self.token, + expired_at=self.expired_at, + revoked=self.revoked, + ) \ No newline at end of file diff --git a/src/services/user.py b/src/services/user.py new file mode 100644 index 0000000..8a43ed0 --- /dev/null +++ b/src/services/user.py @@ -0,0 +1,220 @@ +import uuid +import hmac +import base64 +import os +from datetime import datetime, timedelta +from typing import Optional +from utils.error import ErrorCode, error +from utils import rldb, mailer, sms, config +from models.user import ( + User, + UserRLDBModel, + VerificationCode, + VerificationCodeRLDBModel, + UserToken, + UserTokenRLDBModel, +) + + +class UserService: + def __init__(self): + self.rldb = rldb.get_instance() + self.mailer = mailer.get_instance() + self.sms = sms.get_instance() + self.conf = config.get_instance() + + def _hash_password(self, password: str, salt: Optional[str] = None) -> str: + salt = salt if salt else base64.urlsafe_b64encode(os.urandom(16)).decode('utf-8') + digest = hmac.new(salt.encode('utf-8'), password.encode('utf-8'), 'sha256').digest() + return f"{salt}:{base64.urlsafe_b64encode(digest).decode('utf-8')}" + + def _verify_password(self, password: str, password_hash: str) -> bool: + parts = password_hash.split(':') + if len(parts) != 2: + return False + salt = parts[0] + return self._hash_password(password, salt) == password_hash + + def send_code(self, target_type: str, target: str, scene: str) -> error: + scens = { + "register": "注册", + "update": "信息更新", + # "login": "登录", + } + if scene not in scens: + return error(ErrorCode.MODEL_ERROR, f'scene {scene} not supported') + scene_name = scens.get(scene, scene) + code = f"{uuid.uuid4().int % 1000000:06d}" + expires = datetime.now() + timedelta(minutes=10) + vc = VerificationCode( + id=uuid.uuid4().hex, + target_type=target_type, + target=target, + code=code, + scene=scene, + expires_at=expires, + ) + self.rldb.upsert(vc.to_rldb_model()) + content = f"IF.U服务{scene_name}验证码: {code}, 10分钟内有效" + sent = True + if target_type == 'email': + sent = self.mailer.send(target, f'IF.U服务{scene_name}验证码', content) if self.mailer else False + elif target_type == 'phone': + sent = self.sms.send(target, content) if self.sms else False + if not sent: + return error(ErrorCode.RLDB_ERROR, 'send code failed') + return error(ErrorCode.SUCCESS, '') + + def _get_user_by_identifier(self, email: Optional[str], phone: Optional[str]) -> Optional[User]: + if email: + users = self.rldb.query(UserRLDBModel, email=email, limit=1) + if users: + return User.from_rldb_model(users[0]) + if phone: + users = self.rldb.query(UserRLDBModel, phone=phone, limit=1) + if users: + return User.from_rldb_model(users[0]) + return None + + def register(self, user: User, code: str) -> (str, error): + if not user.email and not user.phone: + return '', error(ErrorCode.MODEL_ERROR, 'email or phone required') + existed = self._get_user_by_identifier(user.email, user.phone) + if existed: + return '', error(ErrorCode.MODEL_ERROR, 'user existed') + target_type = 'phone' if user.phone else 'email' + target = user.phone if user.phone else user.email + vc_list = self.rldb.query( + VerificationCodeRLDBModel, + target_type=target_type, + target=target, + scene='register', + limit=1, + ) + if not vc_list: + return '', error(ErrorCode.MODEL_ERROR, 'code not found') + vc = vc_list[0] + if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None: + return '', error(ErrorCode.MODEL_ERROR, 'invalid code') + vc.used_at = datetime.now() + self.rldb.upsert(vc) + user.id = uuid.uuid4().hex + hashed = self._hash_password(user.password_hash) + user.password_hash = hashed + self.rldb.upsert(user.to_rldb_model()) + return user.id, error(ErrorCode.SUCCESS, '') + + def login(self, email: Optional[str], phone: Optional[str], password: str) -> (dict, error): + u = self._get_user_by_identifier(email, phone) + if not u: + return {}, error(ErrorCode.MODEL_ERROR, 'user not found') + if not self._verify_password(password, u.password_hash): + return {}, error(ErrorCode.MODEL_ERROR, 'invalid password') + ttl_days = self.conf.getint('auth', 'token_ttl_days', fallback=30) + expired_at = datetime.now() + timedelta(days=ttl_days) + token_raw = f"{u.id}.{uuid.uuid4().hex}.{int(expired_at.timestamp())}" + secret = self.conf.get('auth', 'jwt_secret', fallback='dev-secret') + signature = hmac.new(secret.encode('utf-8'), token_raw.encode('utf-8'), 'sha256').digest() + token = base64.urlsafe_b64encode(token_raw.encode('utf-8')).decode('utf-8') + '.' + base64.urlsafe_b64encode(signature).decode('utf-8') + ut = UserToken(id=uuid.uuid4().hex, user_id=u.id, token=token, expired_at=expired_at, revoked=False) + self.rldb.upsert(ut.to_rldb_model()) + return {'token': token, 'expired_at': int(expired_at.timestamp())}, error(ErrorCode.SUCCESS, '') + + def logout(self, token: str) -> error: + tokens = self.rldb.query(UserTokenRLDBModel, token=token, limit=1) + if not tokens: + return error(ErrorCode.MODEL_ERROR, 'token not found') + t = tokens[0] + t.revoked = True + self.rldb.upsert(t) + return error(ErrorCode.SUCCESS, '') + + def delete_user(self, user_id: str) -> error: + u = self.rldb.get(UserRLDBModel, user_id) + if not u: + return error(ErrorCode.MODEL_ERROR, 'user not found') + self.rldb.delete(u) + return error(ErrorCode.SUCCESS, '') + + def get(self, user_id: str) -> (User, error): + u = self.rldb.get(UserRLDBModel, user_id) + if not u: + return None, error(ErrorCode.MODEL_ERROR, 'user not found') + return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '') + + def update_profile(self, user_id: str, nickname: str = None, avatar_link: str = None, phone: str = None, email: str = None) -> (User, error): + u = self.rldb.get(UserRLDBModel, user_id) + if not u: + return None, error(ErrorCode.MODEL_ERROR, 'user not found') + has_email = bool(u.email) + has_phone = bool(u.phone) + if nickname is not None: + u.nickname = nickname + if avatar_link is not None: + u.avatar_link = avatar_link + if email is not None: + new_email = email + if has_email: + if not has_phone: + return None, error(ErrorCode.MODEL_ERROR, 'email update requires phone exists') + conflicts = self.rldb.query(UserRLDBModel, email=new_email, limit=1) + if conflicts and conflicts[0].id != user_id: + return None, error(ErrorCode.MODEL_ERROR, 'email existed') + u.email = new_email + if phone is not None: + new_phone = phone + if has_phone: + if not has_email: + return None, error(ErrorCode.MODEL_ERROR, 'phone update requires email exists') + conflicts = self.rldb.query(UserRLDBModel, phone=new_phone, limit=1) + if conflicts and conflicts[0].id != user_id: + return None, error(ErrorCode.MODEL_ERROR, 'phone existed') + u.phone = new_phone + self.rldb.upsert(u) + return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '') + + def update_phone_with_code(self, user_id: str, new_phone: str, code: str) -> (User, error): + vc_list = self.rldb.query( + VerificationCodeRLDBModel, + target_type='phone', + target=new_phone, + scene='update', + limit=1, + ) + if not vc_list: + return None, error(ErrorCode.MODEL_ERROR, 'code not found') + vc = vc_list[0] + if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None: + return None, error(ErrorCode.MODEL_ERROR, 'invalid code') + vc.used_at = datetime.now() + self.rldb.upsert(vc) + return self.update_profile(user_id, phone=new_phone) + + def update_email_with_code(self, user_id: str, new_email: str, code: str) -> (User, error): + vc_list = self.rldb.query( + VerificationCodeRLDBModel, + target_type='email', + target=new_email, + scene='update', + limit=1, + ) + if not vc_list: + return None, error(ErrorCode.MODEL_ERROR, 'code not found') + vc = vc_list[0] + if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None: + return None, error(ErrorCode.MODEL_ERROR, 'invalid code') + vc.used_at = datetime.now() + self.rldb.upsert(vc) + return self.update_profile(user_id, email=new_email) + + +user_service = None + + +def init(): + global user_service + user_service = UserService() + + +def get_instance() -> UserService: + return user_service \ No newline at end of file diff --git a/src/utils/error.py b/src/utils/error.py index 4cccb1c..022bb64 100644 --- a/src/utils/error.py +++ b/src/utils/error.py @@ -15,7 +15,6 @@ class error(Protocol): def __init__(self, error_code: ErrorCode, error_info: str): self._error_code = int(error_code.value) self._error_info = error_info - logging.info(f"errorcode: {type(self._error_code)}") def __str__(self) -> str: return f"{self.__class__.__name__}({self._error_code}, {self._error_info})" diff --git a/src/utils/mailer.py b/src/utils/mailer.py new file mode 100644 index 0000000..c0d6b37 --- /dev/null +++ b/src/utils/mailer.py @@ -0,0 +1,60 @@ +import logging +import smtplib +from email.mime.text import MIMEText +from typing import Protocol +from .config import get_instance as get_config + +class Mailer(Protocol): + def send(self, to_email: str, subject: str, content: str) -> bool: + ... + + +class FakeMailer: + def __init__(self) -> None: + conf = get_config() + self.fake_message = conf.get('fake_mailer', 'message', fallback="FakeEmail") + def send(self, to_email: str, subject: str, content: str) -> bool: + logging.info(f"{self.fake_message}: to_email={to_email}, subject={subject}, content={content}") + return True + + +class RealMailer: + def __init__(self): + conf = get_config() + self.smtp_host = conf.get('real_mailer', 'smtp_host', fallback=None) + self.smtp_port = conf.getint('real_mailer', 'smtp_port', fallback=587) + self.smtp_user = conf.get('real_mailer', 'smtp_user', fallback=None) + self.smtp_pass = conf.get('real_mailer', 'smtp_pass', fallback=None) + self.from_email = conf.get('real_mailer', 'from_email', fallback=self.smtp_user) + + def send(self, to_email: str, subject: str, content: str) -> bool: + if not self.smtp_host or not self.smtp_user or not self.smtp_pass: + return False + msg = MIMEText(content, 'plain', 'utf-8') + msg['Subject'] = subject + msg['From'] = self.from_email + msg['To'] = to_email + try: + server = smtplib.SMTP(self.smtp_host, self.smtp_port) + server.starttls() + server.login(self.smtp_user, self.smtp_pass) + server.sendmail(self.from_email, [to_email], msg.as_string()) + server.quit() + return True + except Exception: + return False + + +_mailer: Mailer = None + + +def init(type: str = 'real'): + global _mailer + if type == 'real': + _mailer = RealMailer() + else: + _mailer = FakeMailer() + + +def get_instance() -> Mailer: + return _mailer \ No newline at end of file diff --git a/src/utils/sms.py b/src/utils/sms.py new file mode 100644 index 0000000..38b2ea4 --- /dev/null +++ b/src/utils/sms.py @@ -0,0 +1,51 @@ +import logging +from typing import Protocol +import requests +from .config import get_instance as get_config + + +class SMS(Protocol): + def send(self, phone: str, content: str) -> bool: + ... + + +class FakeSMS: + def __init__(self) -> None: + conf = get_config() + self.fake_message = conf.get('fake_sms', 'message', fallback="FakeSMS") + def send(self, phone: str, content: str) -> bool: + logging.info(f"{self.fake_message}: phone={phone}, content={content}") + return True + + +class RealSMS: + def __init__(self): + conf = get_config() + self.webhook_url = conf.get('real_sms', 'webhook_url', fallback=None) + self.webhook_token = conf.get('real_sms', 'webhook_token', fallback=None) + + def send(self, phone: str, content: str) -> bool: + if not self.webhook_url: + return False + try: + headers = {'Authorization': f'Bearer {self.webhook_token}'} if self.webhook_token else {} + data = {'phone': phone, 'content': content} + resp = requests.post(self.webhook_url, json=data, headers=headers, timeout=5) + return resp.status_code >= 200 and resp.status_code < 300 + except Exception: + return False + + +_sms: SMS = None + + +def init(type: str = 'real'): + global _sms + if type == 'real': + _sms = RealSMS() + else: + _sms = FakeSMS() + + +def get_instance() -> SMS: + return _sms \ No newline at end of file diff --git a/src/web/api.py b/src/web/api.py index 511571f..d329661 100644 --- a/src/web/api.py +++ b/src/web/api.py @@ -1,44 +1,50 @@ import os +import time import uuid import logging -from typing import Any, Optional -from fastapi import FastAPI, UploadFile, File, Query +from typing import Any, Optional, Literal +from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Response, APIRouter, Depends, Request from pydantic import BaseModel from fastapi.middleware.cors import CORSMiddleware from services.people import get_instance as get_people_service +from services.user import get_instance as get_user_service +from web.auth import require_auth from models.people import People from agents.extract_people_agent import ExtractPeopleAgent from utils import obs, ocr +from utils.config import get_instance as get_config api = FastAPI(title="Single People Management and Searching", version="0.1") api.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=["https://localhost:5173", "https://ifu.mamamiyear.site"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) +authorized_router = APIRouter(dependencies=[Depends(require_auth)]) + class BaseResponse(BaseModel): error_code: int error_info: str data: Optional[Any] = None -@api.post("/ping") +@api.post("/api/ping") async def ping(): return BaseResponse(error_code=0, error_info="success") class PostInputRequest(BaseModel): text: str -@api.post("/recognition/input") +@api.post("/api/recognition/input") async def post_input(request: PostInputRequest): people = extract_people(request.text) resp = BaseResponse(error_code=0, error_info="success") resp.data = people.to_dict() return resp -@api.post("/recognition/image") +@api.post("/api/recognition/image") async def post_input_image(image: UploadFile = File(...)): # 实现上传图片的处理 # 保存上传的图片文件 @@ -78,7 +84,7 @@ def extract_people(text: str, cover_link: str = None) -> People: class PostPeopleRequest(BaseModel): people: dict -@api.post("/people") +@api.post("/api/people") async def post_people(post_people_request: PostPeopleRequest): logging.debug(f"post_people_request: {post_people_request}") people = People.from_dict(post_people_request.people) @@ -88,7 +94,7 @@ async def post_people(post_people_request: PostPeopleRequest): return BaseResponse(error_code=error.code, error_info=error.info) return BaseResponse(error_code=0, error_info="success", data=people.id) -@api.put("/people/{people_id}") +@api.put("/api/people/{people_id}") async def update_people(people_id: str, post_people_request: PostPeopleRequest): logging.debug(f"post_people_request: {post_people_request}") people = People.from_dict(post_people_request.people) @@ -102,7 +108,7 @@ async def update_people(people_id: str, post_people_request: PostPeopleRequest): return BaseResponse(error_code=error.code, error_info=error.info) return BaseResponse(error_code=0, error_info="success") -@api.delete("/people/{people_id}") +@api.delete("/api/people/{people_id}") async def delete_people(people_id: str): service = get_people_service() error = service.delete(people_id) @@ -115,7 +121,7 @@ class GetPeopleRequest(BaseModel): conds: Optional[dict] = None top_k: int = 5 -@api.get("/peoples") +@api.get("/api/peoples") async def get_peoples( name: Optional[str] = Query(None, description="姓名"), gender: Optional[str] = Query(None, description="性别"), @@ -155,7 +161,7 @@ class RemarkRequest(BaseModel): content: str -@api.post("/people/{people_id}/remark") +@api.post("/api/people/{people_id}/remark") async def post_remark(people_id: str, request: RemarkRequest): service = get_people_service() error = service.save_remark(people_id, request.content) @@ -164,10 +170,223 @@ async def post_remark(people_id: str, request: RemarkRequest): return BaseResponse(error_code=0, error_info="success") -@api.delete("/people/{people_id}/remark") +@api.delete("/api/people/{people_id}/remark") async def delete_remark(people_id: str): service = get_people_service() error = service.delete_remark(people_id) if not error.success: return BaseResponse(error_code=error.code, error_info=error.info) return BaseResponse(error_code=0, error_info="success") + + +class SendCodeRequest(BaseModel): + target_type: str + target: str + scene: Literal['register', 'update'] + # scene: Literal['register', 'login'] + + +@api.post("/api/user/send_code") +async def send_user_code(request: SendCodeRequest): + service = get_user_service() + err = service.send_code(request.target_type, request.target, request.scene) + if not err.success: + raise HTTPException(status_code=400, detail=err.info) + return BaseResponse(error_code=0, error_info="success") + + +class RegisterRequest(BaseModel): + nickname: Optional[str] = None + avatar_link: Optional[str] = None + email: Optional[str] = None + phone: Optional[str] = None + password: str + code: str + +@api.post("/api/user") +async def user_register(request: RegisterRequest): + service = get_user_service() + from models.user import User + u = User( + nickname=request.nickname or "", + avatar_link=request.avatar_link or "", + email=request.email or "", + phone=request.phone or "", + password_hash=request.password, + ) + uid, err = service.register(u, request.code) + if not err.success: + logging.error(f"register failed: {err}") + raise HTTPException(status_code=400, detail=err.info) + return BaseResponse(error_code=0, error_info="success", data=uid) + + +class LoginRequest(BaseModel): + email: Optional[str] = None + phone: Optional[str] = None + password: str + +@api.post("/api/user/login") +async def user_login(request: LoginRequest, response: Response): + service = get_user_service() + data, err = service.login(request.email, request.phone, request.password) + if not err.success: + raise HTTPException(status_code=400, detail=err.info) + conf = get_config() + ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30) + cookie_domain = conf.get('auth', 'cookie_domain', fallback=None) + cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False) + cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None) + response.set_cookie( + key="token", + value=data.get('token', ''), + max_age=ttl_days * 24 * 3600, + httponly=True, + secure=cookie_secure, + samesite=cookie_samesite, + domain=cookie_domain, + path="/", + ) + return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')}) + + +@authorized_router.delete("/api/user/me/login") +async def user_logout(response: Response, request: Request): + service = get_user_service() + err = service.logout(getattr(request.state, 'token', None)) + if not err.success: + raise HTTPException(status_code=400, detail=err.info) + conf = get_config() + cookie_domain = conf.get('auth', 'cookie_domain', fallback=None) + response.delete_cookie(key="token", domain=cookie_domain, path="/") + return BaseResponse(error_code=0, error_info="success") + + +@authorized_router.delete("/api/user/me") +async def user_delete(response: Response, request: Request): + service = get_user_service() + err = service.delete_user(getattr(request.state, 'user_id', None)) + if not err.success: + raise HTTPException(status_code=400, detail=err.info) + conf = get_config() + cookie_domain = conf.get('auth', 'cookie_domain', fallback=None) + response.delete_cookie(key="token", domain=cookie_domain, path="/") + return BaseResponse(error_code=0, error_info="success") + +@authorized_router.get("/api/user/me") +async def user_me(request: Request): + service = get_user_service() + user, err = service.get(getattr(request.state, 'user_id', None)) + if not err.success or not user: + raise HTTPException(status_code=400, detail=err.info) + data = { + 'nickname': user.nickname, + 'avatar_link': user.avatar_link, + 'phone': user.phone, + 'email': user.email, + } + return BaseResponse(error_code=0, error_info="success", data=data) + + +class UpdateMeRequest(BaseModel): + nickname: Optional[str] = None + avatar_link: Optional[str] = None + phone: Optional[str] = None + email: Optional[str] = None + + +@authorized_router.put("/api/user/me") +async def update_user_me(request: Request, body: UpdateMeRequest): + service = get_user_service() + user, err = service.update_profile( + getattr(request.state, 'user_id', None), + nickname=body.nickname, + avatar_link=body.avatar_link, + phone=body.phone, + email=body.email, + ) + if not err.success: + raise HTTPException(status_code=400, detail=err.info) + data = { + 'nickname': user.nickname, + 'avatar_link': user.avatar_link, + 'phone': user.phone, + 'email': user.email, + } + return BaseResponse(error_code=0, error_info="success", data=data) + + +@authorized_router.put("/api/user/me/avatar") +async def upload_avatar(request: Request, avatar: UploadFile = File(...)): + user_id = getattr(request.state, 'user_id', None) + if not user_id: + raise HTTPException(status_code=401, detail="unauthorized") + + file_extension = os.path.splitext(avatar.filename)[1] + timestamp = int(time.time()) + avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}" + + try: + obs_util = obs.get_instance() + obs_util.Put(avatar_path, await avatar.read()) + avatar_url = obs_util.Link(avatar_path) + + user_service = get_user_service() + _, err = user_service.update_profile(user_id, avatar_link=avatar_url) + if not err.success: + raise HTTPException(status_code=500, detail=err.info) + + return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url}) + except Exception as e: + logging.error(f"upload avatar failed: {e}") + raise HTTPException(status_code=500, detail="upload avatar failed") + + +class UpdatePhoneRequest(BaseModel): + phone: str + code: str + + +@authorized_router.put("/api/user/me/phone") +async def update_user_phone(request: Request, body: UpdatePhoneRequest): + service = get_user_service() + user, err = service.update_phone_with_code( + getattr(request.state, 'user_id', None), + body.phone, + body.code, + ) + if not err.success: + raise HTTPException(status_code=400, detail=err.info) + data = { + 'nickname': user.nickname, + 'avatar_link': user.avatar_link, + 'phone': user.phone, + 'email': user.email, + } + return BaseResponse(error_code=0, error_info="success", data=data) + + +class UpdateEmailRequest(BaseModel): + email: str + code: str + + +@authorized_router.put("/api/user/me/email") +async def update_user_email(request: Request, body: UpdateEmailRequest): + service = get_user_service() + user, err = service.update_email_with_code( + getattr(request.state, 'user_id', None), + body.email, + body.code, + ) + if not err.success: + raise HTTPException(status_code=400, detail=err.info) + data = { + 'nickname': user.nickname, + 'avatar_link': user.avatar_link, + 'phone': user.phone, + 'email': user.email, + } + return BaseResponse(error_code=0, error_info="success", data=data) + +api.include_router(authorized_router) diff --git a/src/web/auth.py b/src/web/auth.py new file mode 100644 index 0000000..e05ba9c --- /dev/null +++ b/src/web/auth.py @@ -0,0 +1,28 @@ +from typing import Optional +from fastapi import Cookie, HTTPException, Request +from utils import rldb as rldb_util +from models.user import User, UserTokenRLDBModel, UserRLDBModel +from datetime import datetime + + +def require_auth(request: Request, token: Optional[str] = Cookie(None)): + if not token: + raise HTTPException(status_code=401, detail="unauthorized") + db = rldb_util.get_instance() + tokens = db.query(UserTokenRLDBModel, token=token, limit=1) + if not tokens: + raise HTTPException(status_code=401, detail="unauthorized") + t = tokens[0] + if getattr(t, 'revoked', False): + raise HTTPException(status_code=401, detail="unauthorized") + if getattr(t, 'expired_at', None) and t.expired_at < datetime.now(): + raise HTTPException(status_code=401, detail="unauthorized") + user_orm = db.get(UserRLDBModel, t.user_id) + if not user_orm: + raise HTTPException(status_code=401, detail="unauthorized") + user = User.from_rldb_model(user_orm) + request.state.user_id = user.id + request.state.user_nickname = user.nickname + request.state.user_email = user.email + request.state.user_phone = user.phone + request.state.token = token \ No newline at end of file