fix: agent recognize data type of age and height for people wrong

This commit is contained in:
2025-11-12 17:12:34 +08:00
parent e74279ca5e
commit 3dea2b10f2
3 changed files with 39 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
import datetime
import json import json
import logging import logging
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
@@ -12,6 +13,7 @@ class ExtractPeopleAgent(BaseAgent):
self.prompt = ChatPromptTemplate.from_messages([ self.prompt = ChatPromptTemplate.from_messages([
( (
"system", "system",
f"现在是{datetime.datetime.now().strftime('%Y-%m-%d')}"
"你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n" "你是一个专业的婚姻、交友助手,善于从一段文字描述中,精确获取用户的以下信息:\n"
"姓名 name\n" "姓名 name\n"
"性别 gender\n" "性别 gender\n"
@@ -20,6 +22,8 @@ class ExtractPeopleAgent(BaseAgent):
"婚姻状况 marital_status\n" "婚姻状况 marital_status\n"
"择偶要求 match_requirement\n" "择偶要求 match_requirement\n"
"以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n" "以上信息需要严格按照 JSON 格式输出 字段名与条目中英文保持一致。\n"
"其中,'年龄 age''身高(cm) height' 必须是一个整数,不能是一个字符串;\n"
"并且,'性别 gender' 根据识别结果,必须从 男,女,未知 三选一填写。\n"
"除了上述基本信息,还有一个字段\n" "除了上述基本信息,还有一个字段\n"
"个人介绍 introduction\n" "个人介绍 introduction\n"
"其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n" "其余的信息需要按照字典的方式进行提炼和总结,都放在个人介绍字段中\n"
@@ -34,8 +38,15 @@ class ExtractPeopleAgent(BaseAgent):
response = self.llm.invoke(prompt) response = self.llm.invoke(prompt)
logging.info(f"llm response: {response.content}") logging.info(f"llm response: {response.content}")
try: try:
return People.from_dict(json.loads(response.content)) people = People.from_dict(json.loads(response.content))
err = people.validate()
if not err.success:
raise ValueError(f"Failed to validate people info: {err.info}")
return people
except json.JSONDecodeError: except json.JSONDecodeError:
logging.error(f"Failed to parse JSON from LLM response: {response.content}") logging.error(f"Failed to parse JSON from LLM response: {response.content}")
return None return None
except ValueError as e:
logging.error(f"Failed to validate people info: {e}")
return None
pass pass

View File

@@ -2,9 +2,11 @@
# created by mmmy on 2025-09-30 # created by mmmy on 2025-09-30
import json import json
import logging
from typing import Dict from typing import Dict
from sqlalchemy import Column, Integer, String, Text, DateTime, func from sqlalchemy import Column, Integer, String, Text, DateTime, func
from utils.rldb import RLDBBaseModel from utils.rldb import RLDBBaseModel
from utils.error import ErrorCode, error
class PeopleRLDBModel(RLDBBaseModel): class PeopleRLDBModel(RLDBBaseModel):
__tablename__ = 'peoples' __tablename__ = 'peoples'
@@ -129,4 +131,20 @@ class People:
introduction=json.dumps(self.introduction, ensure_ascii=False), introduction=json.dumps(self.introduction, ensure_ascii=False),
comments=json.dumps(self.comments, ensure_ascii=False), comments=json.dumps(self.comments, ensure_ascii=False),
cover=self.cover, cover=self.cover,
) )
def validate(self) -> error:
err = error(ErrorCode.SUCCESS, "")
if not self.name:
logging.error("Name is required")
err = error(ErrorCode.MODEL_ERROR, "Name is required")
if not self.gender in ['', '', '未知']:
logging.error("Gender must be '', '', or '未知'")
err = error(ErrorCode.MODEL_ERROR, "Gender must be '', '', or '未知'")
if not isinstance(self.age, int) or self.age <= 0:
logging.error("Age must be an integer and greater than 0")
err = error(ErrorCode.MODEL_ERROR, "Age must be an integer and greater than 0")
if not isinstance(self.height, int) or self.height <= 0:
logging.error("Height must be an integer and greater than 0")
err = error(ErrorCode.MODEL_ERROR, "Height must be an integer and greater than 0")
return err

View File

@@ -1,14 +1,20 @@
from enum import Enum
import logging
from typing import Protocol from typing import Protocol
class ErrorCode(Enum):
SUCCESS = 0
MODEL_ERROR = 1000
class error(Protocol): class error(Protocol):
_error_code: int = 0 _error_code: int = 0
_error_info: str = "" _error_info: str = ""
def __init__(self, error_code: int, error_info: str): def __init__(self, error_code: ErrorCode, error_info: str):
self._error_code = error_code self._error_code = int(error_code.value)
self._error_info = error_info self._error_info = error_info
logging.info(f"errorcode: {type(self._error_code)}")
def __str__(self) -> str: def __str__(self) -> str:
return f"{self.__class__.__name__}({self._error_code}, {self._error_info})" return f"{self.__class__.__name__}({self._error_code}, {self._error_info})"