feat: support multi tenant

This commit is contained in:
2025-11-18 16:08:02 +08:00
parent 18f0083827
commit af8fa03e59
8 changed files with 783 additions and 18 deletions

View File

@@ -4,7 +4,8 @@ import os
import argparse
import uvicorn
from services import people as people_service
from utils import config, logger, obs, ocr, rldb
from services import user as user_service
from utils import config, logger, obs, ocr, rldb, sms, mailer
from web.api import api
@@ -16,16 +17,19 @@ def main():
args = parser.parse_args()
config.init(args.config)
conf = config.get_instance()
logger.init()
rldb.init()
ocr.init()
obs.init()
mailer.init(conf.get('mailer', 'type', fallback='real'))
sms.init(conf.get('sms', 'type', fallback='real'))
people_service.init()
user_service.init()
conf = config.get_instance()
host = conf.get('web_service', 'server_host', fallback='0.0.0.0')
port = conf.getint('web_service', 'server_port', fallback=8099)

184
src/models/user.py Normal file
View File

@@ -0,0 +1,184 @@
from typing import Optional
import json
from datetime import datetime, timedelta
from sqlalchemy import Column, String, Text, DateTime, Integer, Boolean, func, UniqueConstraint
from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class UserRLDBModel(RLDBBaseModel):
__tablename__ = 'users'
id = Column(String(36), primary_key=True)
nickname = Column(String(255))
avatar_link = Column(String(255))
email = Column(String(127), unique=True, index=True)
phone = Column(String(32), unique=True, index=True)
password_hash = Column(String(255))
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False)
deleted_at = Column(DateTime(timezone=True), nullable=True, index=True)
class VerificationCodeRLDBModel(RLDBBaseModel):
__tablename__ = 'verification_codes'
id = Column(String(36), primary_key=True)
target_type = Column(String(16))
target = Column(String(255), index=True)
code = Column(String(16))
scene = Column(String(32))
expires_at = Column(DateTime(timezone=True))
used_at = Column(DateTime(timezone=True), nullable=True)
class UserTokenRLDBModel(RLDBBaseModel):
__tablename__ = 'user_tokens'
id = Column(String(36), primary_key=True)
user_id = Column(String(36), index=True)
token = Column(Text)
expired_at = Column(DateTime(timezone=True))
revoked = Column(Boolean, default=False)
class User:
id: str
nickname: str
avatar_link: str
email: str
phone: str
password_hash: str
created_at: datetime = None
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.nickname = kwargs.get('nickname', '') if kwargs.get('nickname', '') is not None else ''
self.avatar_link = kwargs.get('avatar_link', '') if kwargs.get('avatar_link', '') is not None else ''
self.email = kwargs.get('email', '') if kwargs.get('email', '') is not None else ''
self.phone = kwargs.get('phone', '') if kwargs.get('phone', '') is not None else ''
self.password_hash = kwargs.get('password_hash', '') if kwargs.get('password_hash', '') is not None else ''
self.created_at = kwargs.get('created_at', None)
def __str__(self) -> str:
return (f"User(id={self.id}, nickname={self.nickname}, avatar_link={self.avatar_link}, "
f"email={self.email}, phone={self.phone}, created_at={self.created_at})")
@classmethod
def from_dict(cls, data: dict):
if 'updated_at' in data:
del data['updated_at']
if 'deleted_at' in data:
del data['deleted_at']
return cls(**data)
@classmethod
def from_rldb_model(cls, data: UserRLDBModel):
return cls(
id=data.id,
nickname=data.nickname,
avatar_link=data.avatar_link,
email=data.email,
phone=data.phone,
password_hash=data.password_hash,
created_at=data.created_at,
)
def to_dict(self) -> dict:
return {
'id': self.id,
'nickname': self.nickname,
'avatar_link': self.avatar_link,
'email': self.email,
'phone': self.phone,
'created_at': int(self.created_at.timestamp()) if self.created_at else None,
}
def to_rldb_model(self) -> UserRLDBModel:
return UserRLDBModel(
id=self.id,
nickname=self.nickname,
avatar_link=self.avatar_link,
email=self.email,
phone=self.phone,
password_hash=self.password_hash,
)
def validate(self) -> error:
err = error(ErrorCode.SUCCESS, "")
if not self.email and not self.phone:
return error(ErrorCode.MODEL_ERROR, "email or phone required")
return err
class VerificationCode:
id: str
target_type: str
target: str
code: str
scene: str
expires_at: datetime
used_at: Optional[datetime] = None
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.target_type = kwargs.get('target_type', '')
self.target = kwargs.get('target', '')
self.code = kwargs.get('code', '')
self.scene = kwargs.get('scene', '')
self.expires_at = kwargs.get('expires_at')
self.used_at = kwargs.get('used_at', None)
@classmethod
def from_rldb_model(cls, data: VerificationCodeRLDBModel):
return cls(
id=data.id,
target_type=data.target_type,
target=data.target,
code=data.code,
scene=data.scene,
expires_at=data.expires_at,
used_at=data.used_at,
)
def to_rldb_model(self) -> VerificationCodeRLDBModel:
return VerificationCodeRLDBModel(
id=self.id,
target_type=self.target_type,
target=self.target,
code=self.code,
scene=self.scene,
expires_at=self.expires_at,
used_at=self.used_at,
)
class UserToken:
id: str
user_id: str
token: str
expired_at: datetime
revoked: bool
def __init__(self, **kwargs):
self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else ''
self.user_id = kwargs.get('user_id', '')
self.token = kwargs.get('token', '')
self.expired_at = kwargs.get('expired_at')
self.revoked = kwargs.get('revoked', False)
@classmethod
def from_rldb_model(cls, data: UserTokenRLDBModel):
return cls(
id=data.id,
user_id=data.user_id,
token=data.token,
expired_at=data.expired_at,
revoked=data.revoked,
)
def to_rldb_model(self) -> UserTokenRLDBModel:
return UserTokenRLDBModel(
id=self.id,
user_id=self.user_id,
token=self.token,
expired_at=self.expired_at,
revoked=self.revoked,
)

220
src/services/user.py Normal file
View File

@@ -0,0 +1,220 @@
import uuid
import hmac
import base64
import os
from datetime import datetime, timedelta
from typing import Optional
from utils.error import ErrorCode, error
from utils import rldb, mailer, sms, config
from models.user import (
User,
UserRLDBModel,
VerificationCode,
VerificationCodeRLDBModel,
UserToken,
UserTokenRLDBModel,
)
class UserService:
def __init__(self):
self.rldb = rldb.get_instance()
self.mailer = mailer.get_instance()
self.sms = sms.get_instance()
self.conf = config.get_instance()
def _hash_password(self, password: str, salt: Optional[str] = None) -> str:
salt = salt if salt else base64.urlsafe_b64encode(os.urandom(16)).decode('utf-8')
digest = hmac.new(salt.encode('utf-8'), password.encode('utf-8'), 'sha256').digest()
return f"{salt}:{base64.urlsafe_b64encode(digest).decode('utf-8')}"
def _verify_password(self, password: str, password_hash: str) -> bool:
parts = password_hash.split(':')
if len(parts) != 2:
return False
salt = parts[0]
return self._hash_password(password, salt) == password_hash
def send_code(self, target_type: str, target: str, scene: str) -> error:
scens = {
"register": "注册",
"update": "信息更新",
# "login": "登录",
}
if scene not in scens:
return error(ErrorCode.MODEL_ERROR, f'scene {scene} not supported')
scene_name = scens.get(scene, scene)
code = f"{uuid.uuid4().int % 1000000:06d}"
expires = datetime.now() + timedelta(minutes=10)
vc = VerificationCode(
id=uuid.uuid4().hex,
target_type=target_type,
target=target,
code=code,
scene=scene,
expires_at=expires,
)
self.rldb.upsert(vc.to_rldb_model())
content = f"IF.U服务{scene_name}验证码: {code}, 10分钟内有效"
sent = True
if target_type == 'email':
sent = self.mailer.send(target, f'IF.U服务{scene_name}验证码', content) if self.mailer else False
elif target_type == 'phone':
sent = self.sms.send(target, content) if self.sms else False
if not sent:
return error(ErrorCode.RLDB_ERROR, 'send code failed')
return error(ErrorCode.SUCCESS, '')
def _get_user_by_identifier(self, email: Optional[str], phone: Optional[str]) -> Optional[User]:
if email:
users = self.rldb.query(UserRLDBModel, email=email, limit=1)
if users:
return User.from_rldb_model(users[0])
if phone:
users = self.rldb.query(UserRLDBModel, phone=phone, limit=1)
if users:
return User.from_rldb_model(users[0])
return None
def register(self, user: User, code: str) -> (str, error):
if not user.email and not user.phone:
return '', error(ErrorCode.MODEL_ERROR, 'email or phone required')
existed = self._get_user_by_identifier(user.email, user.phone)
if existed:
return '', error(ErrorCode.MODEL_ERROR, 'user existed')
target_type = 'phone' if user.phone else 'email'
target = user.phone if user.phone else user.email
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type=target_type,
target=target,
scene='register',
limit=1,
)
if not vc_list:
return '', error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return '', error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
user.id = uuid.uuid4().hex
hashed = self._hash_password(user.password_hash)
user.password_hash = hashed
self.rldb.upsert(user.to_rldb_model())
return user.id, error(ErrorCode.SUCCESS, '')
def login(self, email: Optional[str], phone: Optional[str], password: str) -> (dict, error):
u = self._get_user_by_identifier(email, phone)
if not u:
return {}, error(ErrorCode.MODEL_ERROR, 'user not found')
if not self._verify_password(password, u.password_hash):
return {}, error(ErrorCode.MODEL_ERROR, 'invalid password')
ttl_days = self.conf.getint('auth', 'token_ttl_days', fallback=30)
expired_at = datetime.now() + timedelta(days=ttl_days)
token_raw = f"{u.id}.{uuid.uuid4().hex}.{int(expired_at.timestamp())}"
secret = self.conf.get('auth', 'jwt_secret', fallback='dev-secret')
signature = hmac.new(secret.encode('utf-8'), token_raw.encode('utf-8'), 'sha256').digest()
token = base64.urlsafe_b64encode(token_raw.encode('utf-8')).decode('utf-8') + '.' + base64.urlsafe_b64encode(signature).decode('utf-8')
ut = UserToken(id=uuid.uuid4().hex, user_id=u.id, token=token, expired_at=expired_at, revoked=False)
self.rldb.upsert(ut.to_rldb_model())
return {'token': token, 'expired_at': int(expired_at.timestamp())}, error(ErrorCode.SUCCESS, '')
def logout(self, token: str) -> error:
tokens = self.rldb.query(UserTokenRLDBModel, token=token, limit=1)
if not tokens:
return error(ErrorCode.MODEL_ERROR, 'token not found')
t = tokens[0]
t.revoked = True
self.rldb.upsert(t)
return error(ErrorCode.SUCCESS, '')
def delete_user(self, user_id: str) -> error:
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return error(ErrorCode.MODEL_ERROR, 'user not found')
self.rldb.delete(u)
return error(ErrorCode.SUCCESS, '')
def get(self, user_id: str) -> (User, error):
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return None, error(ErrorCode.MODEL_ERROR, 'user not found')
return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '')
def update_profile(self, user_id: str, nickname: str = None, avatar_link: str = None, phone: str = None, email: str = None) -> (User, error):
u = self.rldb.get(UserRLDBModel, user_id)
if not u:
return None, error(ErrorCode.MODEL_ERROR, 'user not found')
has_email = bool(u.email)
has_phone = bool(u.phone)
if nickname is not None:
u.nickname = nickname
if avatar_link is not None:
u.avatar_link = avatar_link
if email is not None:
new_email = email
if has_email:
if not has_phone:
return None, error(ErrorCode.MODEL_ERROR, 'email update requires phone exists')
conflicts = self.rldb.query(UserRLDBModel, email=new_email, limit=1)
if conflicts and conflicts[0].id != user_id:
return None, error(ErrorCode.MODEL_ERROR, 'email existed')
u.email = new_email
if phone is not None:
new_phone = phone
if has_phone:
if not has_email:
return None, error(ErrorCode.MODEL_ERROR, 'phone update requires email exists')
conflicts = self.rldb.query(UserRLDBModel, phone=new_phone, limit=1)
if conflicts and conflicts[0].id != user_id:
return None, error(ErrorCode.MODEL_ERROR, 'phone existed')
u.phone = new_phone
self.rldb.upsert(u)
return User.from_rldb_model(u), error(ErrorCode.SUCCESS, '')
def update_phone_with_code(self, user_id: str, new_phone: str, code: str) -> (User, error):
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type='phone',
target=new_phone,
scene='update',
limit=1,
)
if not vc_list:
return None, error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return None, error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
return self.update_profile(user_id, phone=new_phone)
def update_email_with_code(self, user_id: str, new_email: str, code: str) -> (User, error):
vc_list = self.rldb.query(
VerificationCodeRLDBModel,
target_type='email',
target=new_email,
scene='update',
limit=1,
)
if not vc_list:
return None, error(ErrorCode.MODEL_ERROR, 'code not found')
vc = vc_list[0]
if vc.code != code or vc.expires_at < datetime.now() or vc.used_at is not None:
return None, error(ErrorCode.MODEL_ERROR, 'invalid code')
vc.used_at = datetime.now()
self.rldb.upsert(vc)
return self.update_profile(user_id, email=new_email)
user_service = None
def init():
global user_service
user_service = UserService()
def get_instance() -> UserService:
return user_service

View File

@@ -15,7 +15,6 @@ class error(Protocol):
def __init__(self, error_code: ErrorCode, error_info: str):
self._error_code = int(error_code.value)
self._error_info = error_info
logging.info(f"errorcode: {type(self._error_code)}")
def __str__(self) -> str:
return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"

60
src/utils/mailer.py Normal file
View File

@@ -0,0 +1,60 @@
import logging
import smtplib
from email.mime.text import MIMEText
from typing import Protocol
from .config import get_instance as get_config
class Mailer(Protocol):
def send(self, to_email: str, subject: str, content: str) -> bool:
...
class FakeMailer:
def __init__(self) -> None:
conf = get_config()
self.fake_message = conf.get('fake_mailer', 'message', fallback="FakeEmail")
def send(self, to_email: str, subject: str, content: str) -> bool:
logging.info(f"{self.fake_message}: to_email={to_email}, subject={subject}, content={content}")
return True
class RealMailer:
def __init__(self):
conf = get_config()
self.smtp_host = conf.get('real_mailer', 'smtp_host', fallback=None)
self.smtp_port = conf.getint('real_mailer', 'smtp_port', fallback=587)
self.smtp_user = conf.get('real_mailer', 'smtp_user', fallback=None)
self.smtp_pass = conf.get('real_mailer', 'smtp_pass', fallback=None)
self.from_email = conf.get('real_mailer', 'from_email', fallback=self.smtp_user)
def send(self, to_email: str, subject: str, content: str) -> bool:
if not self.smtp_host or not self.smtp_user or not self.smtp_pass:
return False
msg = MIMEText(content, 'plain', 'utf-8')
msg['Subject'] = subject
msg['From'] = self.from_email
msg['To'] = to_email
try:
server = smtplib.SMTP(self.smtp_host, self.smtp_port)
server.starttls()
server.login(self.smtp_user, self.smtp_pass)
server.sendmail(self.from_email, [to_email], msg.as_string())
server.quit()
return True
except Exception:
return False
_mailer: Mailer = None
def init(type: str = 'real'):
global _mailer
if type == 'real':
_mailer = RealMailer()
else:
_mailer = FakeMailer()
def get_instance() -> Mailer:
return _mailer

51
src/utils/sms.py Normal file
View File

@@ -0,0 +1,51 @@
import logging
from typing import Protocol
import requests
from .config import get_instance as get_config
class SMS(Protocol):
def send(self, phone: str, content: str) -> bool:
...
class FakeSMS:
def __init__(self) -> None:
conf = get_config()
self.fake_message = conf.get('fake_sms', 'message', fallback="FakeSMS")
def send(self, phone: str, content: str) -> bool:
logging.info(f"{self.fake_message}: phone={phone}, content={content}")
return True
class RealSMS:
def __init__(self):
conf = get_config()
self.webhook_url = conf.get('real_sms', 'webhook_url', fallback=None)
self.webhook_token = conf.get('real_sms', 'webhook_token', fallback=None)
def send(self, phone: str, content: str) -> bool:
if not self.webhook_url:
return False
try:
headers = {'Authorization': f'Bearer {self.webhook_token}'} if self.webhook_token else {}
data = {'phone': phone, 'content': content}
resp = requests.post(self.webhook_url, json=data, headers=headers, timeout=5)
return resp.status_code >= 200 and resp.status_code < 300
except Exception:
return False
_sms: SMS = None
def init(type: str = 'real'):
global _sms
if type == 'real':
_sms = RealSMS()
else:
_sms = FakeSMS()
def get_instance() -> SMS:
return _sms

View File

@@ -1,44 +1,50 @@
import os
import time
import uuid
import logging
from typing import Any, Optional
from fastapi import FastAPI, UploadFile, File, Query
from typing import Any, Optional, Literal
from fastapi import FastAPI, HTTPException, UploadFile, File, Query, Response, APIRouter, Depends, Request
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from services.people import get_instance as get_people_service
from services.user import get_instance as get_user_service
from web.auth import require_auth
from models.people import People
from agents.extract_people_agent import ExtractPeopleAgent
from utils import obs, ocr
from utils.config import get_instance as get_config
api = FastAPI(title="Single People Management and Searching", version="0.1")
api.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=["https://localhost:5173", "https://ifu.mamamiyear.site"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
authorized_router = APIRouter(dependencies=[Depends(require_auth)])
class BaseResponse(BaseModel):
error_code: int
error_info: str
data: Optional[Any] = None
@api.post("/ping")
@api.post("/api/ping")
async def ping():
return BaseResponse(error_code=0, error_info="success")
class PostInputRequest(BaseModel):
text: str
@api.post("/recognition/input")
@api.post("/api/recognition/input")
async def post_input(request: PostInputRequest):
people = extract_people(request.text)
resp = BaseResponse(error_code=0, error_info="success")
resp.data = people.to_dict()
return resp
@api.post("/recognition/image")
@api.post("/api/recognition/image")
async def post_input_image(image: UploadFile = File(...)):
# 实现上传图片的处理
# 保存上传的图片文件
@@ -78,7 +84,7 @@ def extract_people(text: str, cover_link: str = None) -> People:
class PostPeopleRequest(BaseModel):
people: dict
@api.post("/people")
@api.post("/api/people")
async def post_people(post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
@@ -88,7 +94,7 @@ async def post_people(post_people_request: PostPeopleRequest):
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success", data=people.id)
@api.put("/people/{people_id}")
@api.put("/api/people/{people_id}")
async def update_people(people_id: str, post_people_request: PostPeopleRequest):
logging.debug(f"post_people_request: {post_people_request}")
people = People.from_dict(post_people_request.people)
@@ -102,7 +108,7 @@ async def update_people(people_id: str, post_people_request: PostPeopleRequest):
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
@api.delete("/people/{people_id}")
@api.delete("/api/people/{people_id}")
async def delete_people(people_id: str):
service = get_people_service()
error = service.delete(people_id)
@@ -115,7 +121,7 @@ class GetPeopleRequest(BaseModel):
conds: Optional[dict] = None
top_k: int = 5
@api.get("/peoples")
@api.get("/api/peoples")
async def get_peoples(
name: Optional[str] = Query(None, description="姓名"),
gender: Optional[str] = Query(None, description="性别"),
@@ -155,7 +161,7 @@ class RemarkRequest(BaseModel):
content: str
@api.post("/people/{people_id}/remark")
@api.post("/api/people/{people_id}/remark")
async def post_remark(people_id: str, request: RemarkRequest):
service = get_people_service()
error = service.save_remark(people_id, request.content)
@@ -164,10 +170,223 @@ async def post_remark(people_id: str, request: RemarkRequest):
return BaseResponse(error_code=0, error_info="success")
@api.delete("/people/{people_id}/remark")
@api.delete("/api/people/{people_id}/remark")
async def delete_remark(people_id: str):
service = get_people_service()
error = service.delete_remark(people_id)
if not error.success:
return BaseResponse(error_code=error.code, error_info=error.info)
return BaseResponse(error_code=0, error_info="success")
class SendCodeRequest(BaseModel):
target_type: str
target: str
scene: Literal['register', 'update']
# scene: Literal['register', 'login']
@api.post("/api/user/send_code")
async def send_user_code(request: SendCodeRequest):
service = get_user_service()
err = service.send_code(request.target_type, request.target, request.scene)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success")
class RegisterRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
email: Optional[str] = None
phone: Optional[str] = None
password: str
code: str
@api.post("/api/user")
async def user_register(request: RegisterRequest):
service = get_user_service()
from models.user import User
u = User(
nickname=request.nickname or "",
avatar_link=request.avatar_link or "",
email=request.email or "",
phone=request.phone or "",
password_hash=request.password,
)
uid, err = service.register(u, request.code)
if not err.success:
logging.error(f"register failed: {err}")
raise HTTPException(status_code=400, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data=uid)
class LoginRequest(BaseModel):
email: Optional[str] = None
phone: Optional[str] = None
password: str
@api.post("/api/user/login")
async def user_login(request: LoginRequest, response: Response):
service = get_user_service()
data, err = service.login(request.email, request.phone, request.password)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
ttl_days = conf.getint('auth', 'token_ttl_days', fallback=30)
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
cookie_secure = conf.getboolean('auth', 'cookie_secure', fallback=False)
cookie_samesite = conf.get('auth', 'cookie_samesite', fallback=None)
response.set_cookie(
key="token",
value=data.get('token', ''),
max_age=ttl_days * 24 * 3600,
httponly=True,
secure=cookie_secure,
samesite=cookie_samesite,
domain=cookie_domain,
path="/",
)
return BaseResponse(error_code=0, error_info="success", data={"expired_at": data.get('expired_at')})
@authorized_router.delete("/api/user/me/login")
async def user_logout(response: Response, request: Request):
service = get_user_service()
err = service.logout(getattr(request.state, 'token', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@authorized_router.delete("/api/user/me")
async def user_delete(response: Response, request: Request):
service = get_user_service()
err = service.delete_user(getattr(request.state, 'user_id', None))
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
conf = get_config()
cookie_domain = conf.get('auth', 'cookie_domain', fallback=None)
response.delete_cookie(key="token", domain=cookie_domain, path="/")
return BaseResponse(error_code=0, error_info="success")
@authorized_router.get("/api/user/me")
async def user_me(request: Request):
service = get_user_service()
user, err = service.get(getattr(request.state, 'user_id', None))
if not err.success or not user:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateMeRequest(BaseModel):
nickname: Optional[str] = None
avatar_link: Optional[str] = None
phone: Optional[str] = None
email: Optional[str] = None
@authorized_router.put("/api/user/me")
async def update_user_me(request: Request, body: UpdateMeRequest):
service = get_user_service()
user, err = service.update_profile(
getattr(request.state, 'user_id', None),
nickname=body.nickname,
avatar_link=body.avatar_link,
phone=body.phone,
email=body.email,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
@authorized_router.put("/api/user/me/avatar")
async def upload_avatar(request: Request, avatar: UploadFile = File(...)):
user_id = getattr(request.state, 'user_id', None)
if not user_id:
raise HTTPException(status_code=401, detail="unauthorized")
file_extension = os.path.splitext(avatar.filename)[1]
timestamp = int(time.time())
avatar_path = f"users/{user_id}/avatar-{timestamp}{file_extension}"
try:
obs_util = obs.get_instance()
obs_util.Put(avatar_path, await avatar.read())
avatar_url = obs_util.Link(avatar_path)
user_service = get_user_service()
_, err = user_service.update_profile(user_id, avatar_link=avatar_url)
if not err.success:
raise HTTPException(status_code=500, detail=err.info)
return BaseResponse(error_code=0, error_info="success", data={"avatar_link": avatar_url})
except Exception as e:
logging.error(f"upload avatar failed: {e}")
raise HTTPException(status_code=500, detail="upload avatar failed")
class UpdatePhoneRequest(BaseModel):
phone: str
code: str
@authorized_router.put("/api/user/me/phone")
async def update_user_phone(request: Request, body: UpdatePhoneRequest):
service = get_user_service()
user, err = service.update_phone_with_code(
getattr(request.state, 'user_id', None),
body.phone,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
class UpdateEmailRequest(BaseModel):
email: str
code: str
@authorized_router.put("/api/user/me/email")
async def update_user_email(request: Request, body: UpdateEmailRequest):
service = get_user_service()
user, err = service.update_email_with_code(
getattr(request.state, 'user_id', None),
body.email,
body.code,
)
if not err.success:
raise HTTPException(status_code=400, detail=err.info)
data = {
'nickname': user.nickname,
'avatar_link': user.avatar_link,
'phone': user.phone,
'email': user.email,
}
return BaseResponse(error_code=0, error_info="success", data=data)
api.include_router(authorized_router)

28
src/web/auth.py Normal file
View File

@@ -0,0 +1,28 @@
from typing import Optional
from fastapi import Cookie, HTTPException, Request
from utils import rldb as rldb_util
from models.user import User, UserTokenRLDBModel, UserRLDBModel
from datetime import datetime
def require_auth(request: Request, token: Optional[str] = Cookie(None)):
if not token:
raise HTTPException(status_code=401, detail="unauthorized")
db = rldb_util.get_instance()
tokens = db.query(UserTokenRLDBModel, token=token, limit=1)
if not tokens:
raise HTTPException(status_code=401, detail="unauthorized")
t = tokens[0]
if getattr(t, 'revoked', False):
raise HTTPException(status_code=401, detail="unauthorized")
if getattr(t, 'expired_at', None) and t.expired_at < datetime.now():
raise HTTPException(status_code=401, detail="unauthorized")
user_orm = db.get(UserRLDBModel, t.user_id)
if not user_orm:
raise HTTPException(status_code=401, detail="unauthorized")
user = User.from_rldb_model(user_orm)
request.state.user_id = user.id
request.state.user_nickname = user.nickname
request.state.user_email = user.email
request.state.user_phone = user.phone
request.state.token = token