diff --git a/.gitignore b/.gitignore index b7faf40..e448300 100644 --- a/.gitignore +++ b/.gitignore @@ -205,3 +205,12 @@ cython_debug/ marimo/_static/ marimo/_lsp/ __marimo__/ + +# Other +uv.lock +configuration/ +logs/ +.DS_Store + +# Test +localstore/ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..e937ee0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +[project] +name = "service" +version = "0.1.0" +description = "This project is the web servcie sub-system for if.u projuect" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "alibabacloud-ocr-api20210707>=3.1.3", + "chromadb>=1.1.1", + "fastapi>=0.118.2", + "langchain>=0.3.27", + "langchain-openai>=0.3.35", + "numpy>=2.3.3", + "python-multipart>=0.0.20", + "qiniu>=7.17.0", + "requests>=2.32.5", +] diff --git a/src/ai/__init__.py b/src/ai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ai/agent.py b/src/ai/agent.py new file mode 100644 index 0000000..f10d61f --- /dev/null +++ b/src/ai/agent.py @@ -0,0 +1,55 @@ +import json +import logging +from langchain_openai import ChatOpenAI +from langchain.prompts import ChatPromptTemplate + +from models.people import People + +class BaseAgent: + def __init__(self): + self.llm = ChatOpenAI( + openai_api_key="56d82040-85c7-4701-8f87-734985e27909", + openai_api_base="https://ark.cn-beijing.volces.com/api/v3", + model_name="ep-20250722161445-n9lfq" + ) + pass + +class ExtractPeopleAgent(BaseAgent): + def __init__(self): + super().__init__() + self.prompt = ChatPromptTemplate.from_messages([ + ( + "system", + "你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n" + "姓名 name\n" + "性别 gender\n" + "年龄 age\n" + "身高(cm) height\n" + # "体重(kg) weight\n" + "婚姻状况 marital_status\n" + "择偶要求 match_requirement\n" + "以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n" + "除了上述基本信息,还有一个字段\n" + "个人介绍 introduction\n" + "其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n" + "个人介绍的字典的 key 需要使用提炼好的中文。\n" + ), + ("human", "{input}") + ]) + + def extract_people_info(self, text: str) -> People: + """从文本中提取个人信息""" + prompt = self.prompt.format_prompt(input=text) + response = self.llm.invoke(prompt) + logging.info(f"llm response: {response.content}") + try: + return People.from_dict(json.loads(response.content)) + except json.JSONDecodeError: + logging.error(f"Failed to parse JSON from LLM response: {response.content}") + return None + pass + +class SummaryPeopleAgent(BaseAgent): + def __init__(self): + super().__init__() + pass \ No newline at end of file diff --git a/src/app/__init__.py b/src/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/app/api.py b/src/app/api.py new file mode 100644 index 0000000..fb583bc --- /dev/null +++ b/src/app/api.py @@ -0,0 +1,130 @@ + +import json +import logging +import os +import uuid + +from typing import Any, Optional +from fastapi import FastAPI, File, UploadFile, Query +from pydantic import BaseModel + +from ai.agent import ExtractPeopleAgent +from models.people import People +from utils import obs, ocr, vsdb +from storage.people_store import get_instance as get_people_store +from fastapi.middleware.cors import CORSMiddleware + +api = FastAPI(title="Single People Management and Searching", version="1.0.0") +api.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +class BaseResponse(BaseModel): + error_code: int + error_info: str + data: Optional[Any] = None + +class PostInputRequest(BaseModel): + text: str + +@api.post("/input") +async def post_input(request: PostInputRequest): + extra_agent = ExtractPeopleAgent() + people = extra_agent.extract_people_info(request.text) + logging.info(f"people: {people}") + resp = BaseResponse(error_code=0, error_info="success") + resp.data = people.to_dict() + return resp + +@api.post("/input_image") +async def post_input_image(image: UploadFile = File(...)): + # 实现上传图片的处理 + # 保存上传的图片文件 + # 生成唯一的文件名 + file_extension = os.path.splitext(image.filename)[1] + unique_filename = f"{uuid.uuid4()}{file_extension}" + + # 确保uploads目录存在 + os.makedirs("uploads", exist_ok=True) + + # 保存文件到对象存储 + file_path = f"uploads/{unique_filename}" + obs_util = obs.get_instance() + obs_util.Put(file_path, await image.read()) + + # 获取对象存储外链 + obs_url = obs_util.Link(file_path) + logging.info(f"obs_url: {obs_url}") + + # 调用OCR处理图片 + ocr_util = ocr.get_instance() + ocr_result = ocr_util.recognize_image_text(obs_url) + logging.info(f"ocr_result: {ocr_result}") + + post_input_request = PostInputRequest(text=ocr_result) + return await post_input(post_input_request) + +class PostPeopleRequest(BaseModel): + people: dict + +@api.post("/peoples") +async def post_people(post_people_request: PostPeopleRequest): + logging.debug(f"post_people_request: {post_people_request}") + people = People.from_dict(post_people_request.people) + store = get_people_store() + people.id = store.save(people) + return BaseResponse(error_code=0, error_info="success", data=people.id) + +class GetPeopleRequest(BaseModel): + query: Optional[str] = None + conds: Optional[dict] = None + top_k: int = 5 + +@api.get("/peoples") +async def get_peoples( + name: Optional[str] = Query(None, description="姓名"), + gender: Optional[str] = Query(None, description="性别"), + age: Optional[int] = Query(None, description="年龄"), + height: Optional[int] = Query(None, description="身高"), + marital_status: Optional[str] = Query(None, description="婚姻状态"), + limit: int = Query(10, description="分页大小"), + offset: int = Query(0, description="分页偏移量"), + search: Optional[str] = Query(None, description="搜索内容"), + top_k: int = Query(5, description="搜索结果数量"), + ): + + # 解析查询参数为字典 + conds = {} + if name: + conds["name"] = name + if gender: + conds["gender"] = gender + if age: + conds["age"] = age + if height: + conds["height"] = height + if marital_status: + conds["marital_status"] = marital_status + + logging.info(f"conds: , limit: {limit}, offset: {offset}, search: {search}, top_k: {top_k}") + + results = [] + store = get_people_store() + if search: + results = store.search(search, conds, ids=None, top_k=top_k) + logging.info(f"search results: {results}") + else: + results = store.query(conds, limit=limit, offset=offset) + logging.info(f"query results: {results}") + peoples = [people.to_dict() for people in results] + return BaseResponse(error_code=0, error_info="success", data=peoples) + +@api.delete("/peoples/{people_id}") +async def delete_people(people_id: str): + store = get_people_store() + store.delete(people_id) + return BaseResponse(error_code=0, error_info="success") \ No newline at end of file diff --git a/src/app/app.py b/src/app/app.py new file mode 100644 index 0000000..5323819 --- /dev/null +++ b/src/app/app.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# created by mmmy on 2025-09-27 + +import logging +from ai.agent import ExtractPeopleAgent +from utils import ocr, vsdb, obs +from models.people import People +from fastapi import FastAPI + +api = FastAPI(title="Single People Management and Searching", version="1.0.0") + +class App: + def __init__(self): + self.extract_people_agent = ExtractPeopleAgent() + self.ocr = ocr.get_instance() + self.vedb = vsdb.get_instance(db_type='chromadb') + self.obs = obs.get_instance() + + def run(self): + pass + + class InputForPeople: + image: bytes + text: str + + def input_to_people(self, input: InputForPeople) -> People: + if not input.image and not input.text: + return None + if input.image: + content = self.ocr.recognize_image_text(input.image) + else: + content = input.text + print(content) + people = self._trans_text_to_people(content) + return people + + def _trans_text_to_people(self, text: str) -> People: + if not text: + return None + person = self.extract_people_agent.extract_people_info(text) + print(person) + return person + + def create_people(self, people: People) -> bool: + if not people: + return False + try: + people.save() + except Exception as e: + logging.error(f"保存人物到数据库失败: {e}") + return False + return True \ No newline at end of file diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..e361b3c --- /dev/null +++ b/src/main.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# created by mmmy on 2025-09-27 +import logging +import os +import argparse +from venv import logger +import uvicorn +from app.api import api +from utils import obs, ocr, vsdb +from utils.config import get_instance as get_config, init as init_config +from utils.logger import init as init_logger +from storage import people_store + +# 主函数 +def main(): + main_path = os.path.dirname(os.path.abspath(__file__)) + parser = argparse.ArgumentParser(description='IF.u 服务') + parser.add_argument('--config', type=str, default=os.path.join(main_path, '../configuration/test_conf.ini'), help='配置文件路径') + args = parser.parse_args() + init_logger(log_level=logging.DEBUG) + logger.info(f"args.config: {args.config}") + init_config(args.config) + config = get_config() + print(config.sections()) + obs.init() + ocr.init() + vsdb.init() + people_store.init() + port = config.getint('web_service', 'server_port', fallback=8099) + uvicorn.run(api, host="127.0.0.1", port=port) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/people.py b/src/models/people.py new file mode 100644 index 0000000..f30eb48 --- /dev/null +++ b/src/models/people.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# created by mmmy on 2025-09-30 + +import logging +from typing import Dict + + +class People: + # 数据库 ID + id: str + # 姓名 + name: str + # 性别 + gender: str + # 年龄 + age: int + # 身高(cm) + height: int + # 体重(kg) + # weight: int + # 婚姻状况 + # [ + # "未婚(single)", + # "已婚(married)", + # "离异(divorced)", + # "丧偶(widowed)" + # ] + marital_status: str + # 择偶要求 + match_requirement: str + # 个人介绍 + introduction: Dict[str, str] + # 总结评价 + comments: Dict[str, str] + + + def __init__(self, **kwargs): + # 初始化所有属性,从kwargs中获取值,如果不存在则设置默认值 + self.id = kwargs.get('id', '') if kwargs.get('id', '') is not None else '' + self.name = kwargs.get('name', '') if kwargs.get('name', '') is not None else '' + self.gender = kwargs.get('gender', '') if kwargs.get('gender', '') is not None else '' + self.age = kwargs.get('age', 0) if kwargs.get('age', 0) is not None else 0 + self.height = kwargs.get('height', 0) if kwargs.get('height', 0) is not None else 0 + # self.weight = kwargs.get('weight', 0) if kwargs.get('weight', 0) is not None else 0 + self.marital_status = kwargs.get('marital_status', '') if kwargs.get('marital_status', '') is not None else '' + self.match_requirement = kwargs.get('match_requirement', '') if kwargs.get('match_requirement', '') is not None else '' + self.introduction = kwargs.get('introduction', {}) if kwargs.get('introduction', {}) is not None else {} + self.comments = kwargs.get('comments', {}) if kwargs.get('comments', {}) is not None else {} + + def __str__(self) -> str: + return self.tonl() + + @classmethod + def from_dict(cls, data: dict): + if 'created_at' in data: + # 移除 created_at 字段,避免类型错误 + del data['created_at'] + if 'updated_at' in data: + # 移除 updated_at 字段,避免类型错误 + del data['updated_at'] + if 'deleted_at' in data: + # 移除 deleted_at 字段,避免类型错误 + del data['deleted_at'] + return cls(**data) + + def to_dict(self) -> dict: + # 将对象转换为字典格式 + return { + 'id': self.id, + 'name': self.name, + 'gender': self.gender, + 'age': self.age, + 'height': self.height, + # 'weight': self.weight, + 'marital_status': self.marital_status, + 'match_requirement': self.match_requirement, + 'introduction': self.introduction, + 'comments': self.comments, + } + + def meta(self) -> Dict[str, str]: + # 返回对象的元数据信息 + meta = { + 'id': self.id, + 'name': self.name, + 'gender': self.gender, + 'age': self.age, + 'height': self.height, + # 'weight': self.weight, + 'marital_status': self.marital_status, + # 'match_requirement': self.match_requirement, + } + logging.info(f"people meta: {meta}") + return meta + + def tonl(self) -> str: + # 将对象转换为文档格式的字符串 + doc = [] + doc.append(f"姓名: {self.name}") + doc.append(f"性别: {self.gender}") + if self.age: + doc.append(f"年龄: {self.age}") + if self.height: + doc.append(f"身高: {self.height}cm") + # if self.weight: + # doc.append(f"体重: {self.weight}kg") + if self.marital_status: + doc.append(f"婚姻状况: {self.marital_status}") + if self.match_requirement: + doc.append(f"择偶要求: {self.match_requirement}") + if self.introduction: + doc.append("个人介绍:") + for key, value in self.introduction.items(): + doc.append(f" - {key}: {value}") + if self.comments: + doc.append("总结评价:") + for key, value in self.comments.items(): + doc.append(f" - {key}: {value}") + return '\n'.join(doc) + + def comment(self, comment: Dict[str, str]): + # 添加总结评价 + self.comments.update(comment) diff --git a/src/storage/__init__.py b/src/storage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/storage/people_store.py b/src/storage/people_store.py new file mode 100644 index 0000000..67768ea --- /dev/null +++ b/src/storage/people_store.py @@ -0,0 +1,210 @@ +import json +import logging +import uuid +from datetime import datetime + +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, func +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +from utils.config import get_instance as get_config +from utils.vsdb import VectorDB, get_instance as get_vsdb +from utils.obs import OBS, get_instance as get_obs + +from models.people import People + +people_store = None +Base = declarative_base() +class PeopleORM(Base): + __tablename__ = 'peoples' + id = Column(String(36), primary_key=True) + name = Column(String(255), index=True) + gender = Column(String(10)) + age = Column(Integer) + height = Column(Integer) + marital_status = Column(String(20)) + match_requirement = Column(Text) + introduction = Column(Text) + comments = Column(Text) + created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) + updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False) + deleted_at = Column(DateTime(timezone=True), nullable=True, index=True) + def parse_from_people(self, people: People): + import json + self.id = people.id + self.name = people.name + self.gender = people.gender + self.age = people.age + self.height = people.height + self.marital_status = people.marital_status + self.match_requirement = people.match_requirement + # 将字典类型字段序列化为JSON字符串存储 + self.introduction = json.dumps(people.introduction, ensure_ascii=False) + self.comments = json.dumps(people.comments, ensure_ascii=False) + + def to_people(self) -> People: + import json + people = People() + people.id = self.id + people.name = self.name + people.gender = self.gender + people.age = self.age + people.height = self.height + people.marital_status = self.marital_status + people.match_requirement = self.match_requirement + # 将JSON字符串反序列化为字典类型字段 + try: + people.introduction = json.loads(self.introduction) if self.introduction else {} + except (json.JSONDecodeError, TypeError): + people.introduction = {} + + try: + people.comments = json.loads(self.comments) if self.comments else {} + except (json.JSONDecodeError, TypeError): + people.comments = {} + + return people + +class PeopleStore: + def __init__(self): + config = get_config() + self.sqldb_engine = create_engine(config.get("sqlalchemy", "database_dsn")) + Base.metadata.create_all(self.sqldb_engine) + self.session_maker = sessionmaker(bind=self.sqldb_engine) + self.vsdb: VectorDB = get_vsdb() + self.obs: OBS = get_obs() + + def save(self, people: People) -> str: + """ + 保存人物到数据库和向量数据库 + + :param people: 人物对象 + :return: 人物ID + """ + # 0. 生成 people id + people.id = people.id if people.id else uuid.uuid4().hex + + # 1. 转换模型,并保存到 SQL 数据库 + people_orm = PeopleORM() + people_orm.parse_from_people(people) + with self.session_maker() as session: + session.add(people_orm) + session.commit() + + # 2. 保存到向量数据库 + people_metadata = people.meta() + people_document = people.tonl() + logging.info(f"people: {people}") + logging.info(f"people_metadata: {people_metadata}") + logging.info(f"people_document: {people_document}") + results = self.vsdb.insert(metadatas=[people_metadata], documents=[people_document], ids=[people.id]) + logging.info(f"results: {results}") + if len(results) == 0: + raise Exception("insert failed") + + # 3. 保存到 OBS 存储 + people_dict = people.to_dict() + people_json = json.dumps(people_dict, ensure_ascii=False) + obs_url = self.obs.Put(f"peoples/{people.id}/detail.json", people_json.encode('utf-8')) + logging.info(f"obs_url: {obs_url}") + + return people.id + + def update(self, people: People) -> None: + raise Exception("update not implemented") + return None + + def find(self, people_id: str) -> People: + """ + 根据人物ID查询人物 + + :param people_id: 人物ID + :return: 人物对象 + """ + with self.session_maker() as session: + people_orm = session.query(PeopleORM).filter( + PeopleORM.id == people_id, + PeopleORM.deleted_at.is_(None) + ).first() + if not people_orm: + raise Exception(f"people not found, people_id: {people_id}") + return people_orm.to_people() + + def query(self, conds: dict = {}, limit: int = 10, offset: int = 0) -> list[People]: + """ + 根据查询条件查询人物 + + :param conds: 查询条件字典 + :param limit: 分页大小 + :param offset: 分页偏移量 + :return: 人物对象列表 + """ + if conds is None: + conds = {} + with self.session_maker() as session: + people_orms = session.query(PeopleORM).filter_by(**conds).filter( + PeopleORM.deleted_at.is_(None) + ).limit(limit).offset(offset).all() + return [people_orm.to_people() for people_orm in people_orms] + + def search(self, search: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[People]: + """ + 根据搜索内容和查询条件查询人物 + + :param search: 搜索内容 + :param metadatas: 查询条件字典 + :param ids: 可选的人物ID列表,用于过滤结果 + :param top_k: 返回结果数量 + :return: 人物对象列表 + """ + peoples = [] + results = self.vsdb.search(document=search, metadatas=metadatas, ids=ids, top_k=top_k) + logging.info(f"results: {results}") + people_id_list = [] + for result in results: + people_id = result.get('id', '') + if not people_id: + continue + people_id_list.append(people_id) + logging.info(f"people_id_list: {people_id_list}") + with self.session_maker() as session: + people_orms = session.query(PeopleORM).filter(PeopleORM.id.in_(people_id_list)).filter( + PeopleORM.deleted_at.is_(None) + ).all() + for people_orm in people_orms: + people = people_orm.to_people() + peoples.append(people) + return peoples + + def delete(self, people_id: str) -> None: + """ + 删除人物从数据库和向量数据库 + + :param people_id: 人物ID + """ + # 1. 从 SQL 数据库软删除人物 + with self.session_maker() as session: + session.query(PeopleORM).filter( + PeopleORM.id == people_id, + PeopleORM.deleted_at.is_(None) + ).update({PeopleORM.deleted_at: func.now()}, synchronize_session=False) + session.commit() + logging.info(f"人物 {people_id} 标记删除 SQL 成功") + + # 2. 删除向量数据库中的记录 + self.vsdb.delete(ids=[people_id]) + logging.info(f"人物 {people_id} 删除向量数据库成功") + + # 3. 删除 OBS 存储中的文件 + keys = self.obs.List(f"peoples/{people_id}/") + for key in keys: + self.obs.Del(key) + logging.info(f"文件 {key} 删除 OBS 成功") + +def init(): + global people_store + people_store = PeopleStore() + + +def get_instance() -> PeopleStore: + return people_store diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..f9f8136 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1,3 @@ +# 导出utils模块中的子模块 +from . import config, obs, ocr, vsdb, logger +__all__ = ['config', 'obs', 'ocr', 'vsdb', 'logger'] diff --git a/src/utils/config.py b/src/utils/config.py new file mode 100644 index 0000000..b67123c --- /dev/null +++ b/src/utils/config.py @@ -0,0 +1,25 @@ +import configparser + +config = None + + +def init(config_file: str): + global config + config = configparser.ConfigParser() + config.read(config_file) + +def get_instance() -> configparser.ConfigParser: + return config + + +if __name__ == "__main__": + # 本文件的绝对路径 + import os + config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini") + init(config_file) + conf = get_instance() + print(conf.sections()) + for section in conf.sections(): + print(conf.options(section)) + for option in conf.options(section): + print(f"{section}.{option}={conf.get(section, option)}") diff --git a/src/utils/logger.py b/src/utils/logger.py new file mode 100644 index 0000000..e88c882 --- /dev/null +++ b/src/utils/logger.py @@ -0,0 +1,84 @@ +import logging +import os +from datetime import datetime + +# 定义颜色代码 +class Colors: + RED = '\033[31m' + GREEN = '\033[32m' + YELLOW = '\033[33m' + BLUE = '\033[34m' + MAGENTA = '\033[35m' + CYAN = '\033[36m' + WHITE = '\033[37m' + RESET = '\033[0m' # 重置颜色 + +# 自定义控制台处理器,为不同日志级别添加颜色 +class ColoredConsoleHandler(logging.StreamHandler): + def emit(self, record): + # 为不同日志级别设置颜色 + colors = { + logging.DEBUG: Colors.CYAN, + logging.INFO: Colors.GREEN, + logging.WARNING: Colors.YELLOW, + logging.ERROR: Colors.RED, + logging.CRITICAL: Colors.MAGENTA + } + + # 获取对应级别的颜色,默认为白色 + color = colors.get(record.levelno, Colors.WHITE) + + # 获取原始消息 + message = self.format(record) + + # 添加颜色并输出 + self.stream.write(f"{color}{message}{Colors.RESET}\n") + self.flush() + +def init(log_dir="logs", log_file="log", log_level=logging.INFO, console_log_level=logging.DEBUG): + # 创建logs目录(如果不存在) + if not os.path.exists(log_dir): + os.makedirs(log_dir) + + # 设置日志格式 + log_format = "[%(asctime)s.%(msecs)03d][%(filename)s:%(lineno)d][%(levelname)s] %(message)s" + date_format = "%Y-%m-%d %H:%M:%S" + + # 创建格式化器 + formatter = logging.Formatter(log_format, datefmt=date_format) + + # 获取根日志记录器 + root_logger = logging.getLogger() + root_logger.setLevel(logging.NOTSET) + + # 清除现有的处理器 + root_logger.handlers.clear() + + # 创建控制台处理器并设置颜色 + console_handler = ColoredConsoleHandler() + console_handler.setFormatter(formatter) + console_handler.setLevel(console_log_level) + root_logger.addHandler(console_handler) + + # 创建文件处理器 + log_filename = os.path.join(log_dir, f"{log_file}_{datetime.now().strftime('%Y%m%d')}.log") + file_handler = logging.FileHandler(log_filename, encoding='utf-8') + file_handler.setFormatter(formatter) + file_handler.setLevel(log_level) + root_logger.addHandler(file_handler) + + # 确保日志消息被正确处理 + logging.addLevelName(logging.DEBUG, "D") + logging.addLevelName(logging.INFO, "I") + logging.addLevelName(logging.WARNING, "W") + logging.addLevelName(logging.ERROR, "E") + logging.addLevelName(logging.CRITICAL, "C") + + +if __name__ == "__main__": + init(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG) + logging.debug("debug log") + logging.info("info log") + logging.warning("warning log") + logging.error("error log") + logging.critical("critical log") \ No newline at end of file diff --git a/src/utils/obs.py b/src/utils/obs.py new file mode 100644 index 0000000..75c6ee0 --- /dev/null +++ b/src/utils/obs.py @@ -0,0 +1,220 @@ + +import json +import logging +from typing import Protocol +import qiniu +import requests +from .config import get_instance as get_config + + +class OBS(Protocol): + def Put(self, obs_path: str, content: bytes) -> str: + """ + 上传文件到OBS + + Args: + obs_path (str): OBS目标路径 + content (bytes): 文件内容 + + Returns: + str: OBS文件路径 + """ + ... + + def Get(self, obs_path: str) -> bytes: + """ + 从OBS下载文件 + + Args: + obs_path (str): OBS文件路径 + + Returns: + bytes: 文件内容 + """ + ... + + def List(self, obs_path: str) -> list: + """ + 列出OBS目录下的所有文件 + + Args: + obs_path (str): OBS目录路径 + + Returns: + list: 所有文件路径列表 + """ + ... + + def Del(self, obs_path: str) -> bool: + """ + 删除OBS文件 + + Args: + obs_path (str): OBS文件路径 + + Returns: + bool: 是否删除成功 + """ + ... + + def Link(self, obs_path: str) -> str: + """ + 获取OBS文件链接 + + Args: + obs_path (str): OBS文件路径 + + Returns: + str: OBS文件链接 + """ + ... + + +class Koodo: + def __init__(self): + config = get_config() + self.bucket_name = config.get('koodo_obs', 'bucket_name') + self.prefix_path = config.get('koodo_obs', 'prefix_path') + self.access_key = config.get('koodo_obs', 'access_key') + self.secret_key = config.get('koodo_obs', 'secret_key') + self.outer_domain = config.get('koodo_obs', 'outer_domain') + self.auth = qiniu.Auth(self.access_key, self.secret_key) + self.bucket = qiniu.BucketManager(self.auth) + pass + + def Put(self, obs_path: str, content: bytes) -> str: + """ + 上传文件到OBS + + Args: + obs_path (str): OBS目标路径 + content (bytes): 文件内容 + + Returns: + str: OBS文件路径 + """ + full_path = f"{self.prefix_path}{obs_path}" + token = self.auth.upload_token(self.bucket_name, full_path) + ret, info = qiniu.put_data(token, full_path, content) + logging.debug(f"文件 {obs_path} 上传到 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}") + if ret is None or info.status_code != 200: + logging.error(f"文件 {obs_path} 上传失败, 错误信息: {info.text_body}") + return "" + logging.info(f"文件 {obs_path} 上传成功, OBS路径: {full_path}") + return f"{self.outer_domain}/{full_path}" + + def Get(self, obs_path: str) -> bytes: + """ + 从OBS下载文件 + + Args: + obs_path (str): OBS文件路径 + + Returns: + bytes: 文件内容 + """ + link = f"{self.outer_domain}/{self.prefix_path}{obs_path}" + resp = requests.get(link) + data = json.loads(resp.text) + if 'error' in data and data['error']: + logging.error(f"从 OBS {obs_path} 下载文件失败, 错误信息: {data['error']}") + return None + return resp.content + + def List(self, prefix: str = "") -> list[str]: + """ + 列出OBS目录下的所有文件 + + Args: + prefix (str, optional): OBS目录路径前缀. Defaults to "". + + Returns: + list: 文件路径列表 + """ + prefix = f"{self.prefix_path}{prefix}" + ret, eof, info = self.bucket.list(self.bucket_name, prefix) + keys = [] + for item in ret['items']: + item['key'] = item['key'].replace(prefix, "") + keys.append(item['key']) + # logging.debug(f"文件 {item['key']} 路径: {item['key']}") + # logging.debug(f"ret: {ret}") + # logging.debug(f"eof: {eof}") + # logging.debug(f"info: {info}") + return keys + + def Del(self, obs_path: str) -> bool: + """ + 删除OBS文件 + + Args: + obs_path (str): OBS文件路径 + + Returns: + bool: 是否删除成功 + """ + ret, info = self.bucket.delete(self.bucket_name, f"{self.prefix_path}{obs_path}") + logging.debug(f"文件 {obs_path} 删除 OBS, 结果: {ret}, 状态码: {info.status_code}, 错误信息: {info.text_body}") + if ret is None or info.status_code != 200: + logging.error(f"文件 {obs_path} 删除 OBS 失败, 错误信息: {info.text_body}") + return False + logging.info(f"文件 {obs_path} 删除 OBS 成功") + return True + + def Link(self, obs_path: str) -> str: + """ + 获取OBS文件链接 + + Args: + obs_path (str): OBS文件路径 + + Returns: + str: OBS文件链接 + """ + return f"{self.outer_domain}/{self.prefix_path}{obs_path}" + + +_obs_instance: OBS = None + +def init(): + global _obs_instance + _obs_instance = Koodo() + +def get_instance() -> OBS: + global _obs_instance + return _obs_instance + + +if __name__ == "__main__": + import os + + from .logger import init as init_logger + init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG) + + from .config import init as init_config, get_instance as get_config + config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini") + init_config(config_file) + + init() + obs = get_instance() + + # 从OBS下载测试图片 + # obs_path = "test111.PNG" + # local_path = os.path.join(os.path.dirname(__file__), "../../test/9e03ad5eb8b1a51e752fb79cd8f98169.PNG") + # content = None + # with open(local_path, "rb") as f: + # content = f.read() + # if content is None: + # print(f"文件 {local_path} 读取失败") + # exit(1) + # obs.Put(obs_path, content) + + # link = obs.Link(obs_path) + # print(f"文件 {obs_path} 链接: {link}") + + # 列出OBS目录下的所有文件 + keys = obs.List("") + print(f"OBS 目录下的所有文件: {keys}") + for key in keys: + link = obs.Del(key) + print(f"文件 {key} 删除 OBS 成功: {link}") diff --git a/src/utils/ocr.py b/src/utils/ocr.py new file mode 100644 index 0000000..b524c0b --- /dev/null +++ b/src/utils/ocr.py @@ -0,0 +1,105 @@ + +import json +import logging +from typing import Protocol +from alibabacloud_ocr_api20210707.client import Client as OcrClient +from alibabacloud_tea_openapi import models as open_api_models +from alibabacloud_ocr_api20210707 import models as ocr_models +from alibabacloud_tea_util import models as util_models +from alibabacloud_tea_util.client import Client as UtilClient +from .config import get_instance as get_config + + +class OCR(Protocol): + def recognize_image_text(self, image_link: str) -> str: + """ + 从图片提取文本 + + Args: + image_link (str): 图片链接 + + Returns: + str: 提取到的文本 + """ + ... + +class AliOCR: + def __init__(self): + config = get_config() + self.access_key = config.get("ali_ocr", "access_key") + self.secret_key = config.get("ali_ocr", "secret_key") + self.endpoint = config.get("ali_ocr", "endpoint") + self.client = self._create_client() + + def _create_client(self): + config = open_api_models.Config( + access_key_id=self.access_key, + access_key_secret=self.secret_key, + ) + config.endpoint = self.endpoint + return OcrClient(config) + + def recognize_image_text(self, image_link: str) -> str: + """ + 使用阿里云OCR从图片链接提取文本 + + Args: + image_link (str): 图片链接 + + Returns: + str: 提取到的文本 + """ + # 创建OCR请求 + recognize_general_request = ocr_models.RecognizeGeneralRequest(url=image_link) + runtime = util_models.RuntimeOptions() + try: + resp = self.client.recognize_general_with_options(recognize_general_request, runtime) + logging.debug(resp.body.data) + except Exception as error: + # 此处仅做打印展示,请谨慎对待异常处理,在工程项目中切勿直接忽略异常。 + # 错误 message + logging.error(error.message) + # 诊断地址 + logging.error(error.data.get("Recommend")) + UtilClient.assert_as_string(error.message) + + response = self.client.recognize_general_with_options(recognize_general_request, runtime) + if response.status_code == 200 and response.body: + result_data = response.body.data + result_body = json.loads(result_data) + if result_body and 'content' in result_body: + return result_body['content'] + return "" + +# 全局OCR实例 +_ocr_instance = None + + +def init(): + """初始化OCR实例""" + global _ocr_instance + _ocr_instance = AliOCR() + + +def get_instance() -> OCR: + """获取OCR实例""" + global _ocr_instance + if _ocr_instance is None: + raise RuntimeError("OCR模块未初始化,请先调用init()函数") + return _ocr_instance + + +if __name__ == "__main__": + import os + + from logger import init as init_logger + init_logger(console_log_level=logging.DEBUG) + + from config import init as init_config + config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini") + init_config(config_file) + + init() + ocr = get_instance() + text = ocr.recognize_image_text(image_link="https://pic.mamamiyear.site/test.if.u/test111.PNG") + print(text) \ No newline at end of file diff --git a/src/utils/vsdb.py b/src/utils/vsdb.py new file mode 100644 index 0000000..e2107f8 --- /dev/null +++ b/src/utils/vsdb.py @@ -0,0 +1,245 @@ +import uuid +import chromadb +import logging +from typing import Protocol +from chromadb.config import Settings +from chromadb.utils import embedding_functions +from .config import get_instance as get_config + +class VectorDB(Protocol): + + def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]: + """ + 插入向量到数据库 + + Args: + vector (list[float]): 向量 + metadata (dict): 元数据 + + Returns: + bool: 是否插入成功 + """ + ... + + def delete(self, ids: list[str]) -> bool: + """ + Delete documents from a collection. + + Args: + ids: List of IDs to delete + + Returns: + bool: Whether deletion was successful + """ + ... + + def query(self, metadatas: dict, ids: list[str], top_k: int = 5) -> list[dict]: + """ + 查询向量数据库 + + Args: + query_vector (list[float]): 查询向量 + top_k (int, optional): 返回Top K结果. Defaults to 5. + + Returns: + list[dict]: 查询结果列表 + """ + ... + + def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]: + """ + 搜索向量数据库 + + Args: + document: Document to search + metadatas: Metadata to filter by + ids: List of IDs to filter by + top_k (int, optional): 返回Top K结果. Defaults to 5. + + Returns: + list[dict]: 查询结果列表 + """ + ... + +class ChromaDB: + def __init__(self, **kwargs): + """ + Initialize the ChromaDB instance. + + Args: + persist_directory: Optional directory to persist the database. + If None, the database will be in-memory only. + """ + config = get_config() + self.embedding_functions = embedding_functions.OpenAIEmbeddingFunction( + api_base=config.get("voc-engine_embedding", "api_url"), + api_key=config.get("voc-engine_embedding", "api_key"), + model_name=config.get("voc-engine_embedding", "endpoint"), + ) + persist_directory = config.get("chroma_vsdb", "database_dir", fallback=None) + if persist_directory: + self.client = chromadb.PersistentClient( + path=persist_directory, + settings=Settings(anonymized_telemetry=False) + ) + else: + self.client = chromadb.Client( + settings=Settings(anonymized_telemetry=False), + ) + self.collection_name = config.get("chroma_vsdb", "collection_name", fallback="peoples") + metadata: dict = kwargs.get('collection_metadata', {'hnsw:space': 'cosine'}) + metadata['hnsw:space'] = metadata.get('hnsw:space', 'cosine') + self.collection = self.client.get_or_create_collection( + name=self.collection_name, + embedding_function=self.embedding_functions, + metadata=metadata, + ) + + def insert(self, metadatas: list[dict], documents: list[str], ids: list[str] = None) -> list[str]: + """ + Insert documents into a collection. + + Args: + metadatas: List of metadata corresponding to each document + documents: List of documents to insert + ids: Optional list of unique IDs for each document. If None, IDs will be generated. + + Returns: + list[str]: List of inserted IDs + """ + + if not ids: + # Generate unique IDs if not provided + ids = [str(uuid.uuid4()) for _ in range(len(documents))] + + self.collection.add( + documents=documents, + metadatas=metadatas, + ids=ids + ) + + return ids + + def delete(self, ids: list[str]) -> bool: + """ + Delete documents from a collection. + + Args: + ids: List of IDs to delete + + Returns: + bool: Whether deletion was successful + """ + try: + self.collection.delete(ids) + return True + except Exception as e: + print(f"Error deleting documents: {e}") + return False + + def query(self, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]: + """ + 查询向量数据库 + + Args: + metadatas: Metadata to filter by + ids: List of IDs to query + top_k (int, optional): 返回Top K结果. Defaults to 5. + + Returns: + list[dict]: 查询结果列表 + """ + results = self.collection.query( + query_embeddings=None, + query_texts=None, + n_results=top_k, + where=metadatas, + ids=ids, + include=["documents", "metadatas", "distances"], + ) + formatted_results = [] + for i in range(len(results['ids'][0])): + result = { + 'id': results['ids'][0][i], + 'distance': results['distances'][0][i], + 'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {}, + 'document': results['documents'][0][i] if results['documents'][0] else '' + } + formatted_results.append(result) + return formatted_results + + def search(self, document: str, metadatas: dict, ids: list[str] = None, top_k: int = 5) -> list[dict]: + """ + 搜索向量数据库 + + Args: + document: Document to search + metadatas: Metadata to filter by + ids: List of IDs to filter by + top_k (int, optional): 返回Top K结果. Defaults to 5. + + Returns: + list[dict]: 查询结果列表 + """ + results = self.collection.query( + query_embeddings=None, + query_texts=[document], + n_results=top_k, + where=metadatas if metadatas else None, + ids=ids, + include=["documents", "metadatas", "distances"], + ) + print("log: results keys: ", results.keys()) + print("log: results ids: ", results['ids']) + formatted_results = [] + for i in range(len(results['ids'][0])): + result = { + 'id': results['ids'][0][i], + 'distance': results['distances'][0][i], + 'metadata': results['metadatas'][0][i] if results['metadatas'][0] else {}, + 'document': results['documents'][0][i] if results['documents'][0] else '' + } + formatted_results.append(result) + return formatted_results + pass + +_vsdb_instance: VectorDB = None + +def init(): + global _vsdb_instance + _vsdb_instance = ChromaDB() + +def get_instance() -> VectorDB: + global _vsdb_instance + return _vsdb_instance + + +if __name__ == "__main__": + import os + + from logger import init as init_logger + init_logger(log_dir="logs", log_file="test", log_level=logging.INFO, console_log_level=logging.DEBUG) + + from config import init as init_config + config_file = os.path.join(os.path.dirname(__file__), "../../configuration/test_conf.ini") + init_config(config_file) + + init() + vsdb = get_instance() + metadatas = [ + {'name': '丽丽'}, + {'name': '志刚'}, + {'name': '张三'}, + {'name': '李四'}, + ] + documents = [ + '姓名: 丽丽, 性别: 女, 年龄: 23, 爱好: 爬山、骑行、攀岩、跳伞、蹦极', + "姓名: 志刚, 性别: 男, 年龄: 25, 爱好: 读书、游戏", + "姓名: 张三, 性别: 男, 年龄: 30, 爱好: 画画、写作、阅读、逛展、旅行", + "姓名: 李四, 性别: 男, 年龄: 35, 爱好: 做饭、美食、旅游" + ] + search_text = '25岁以下的' + ids = vsdb.insert(metadatas, documents) + results = vsdb.search(search_text, None, None, top_k=4) + for result in results: + print(result['document'], ' ', result['distance']) diff --git a/test/test_logger.py b/test/test_logger.py new file mode 100644 index 0000000..6e99c98 --- /dev/null +++ b/test/test_logger.py @@ -0,0 +1,16 @@ +import sys +import os +sys.path.append(os.path.join(os.path.dirname(__file__), '../src')) + +from utils.logger import init +import logging + +# 初始化日志 +init() + +# 测试不同级别的日志 +logging.debug("这是一条调试信息") +logging.info("这是一条普通信息") +logging.warning("这是一条警告信息") +logging.error("这是一条错误信息") +logging.critical("这是一条严重错误信息") \ No newline at end of file