Files
texas_hold_x/tests/test_ai_client.py
2026-05-12 00:56:49 +08:00

151 lines
4.1 KiB
Python

import io
import tempfile
import unittest
from pathlib import Path
from typing import Any
from texas_holdem.ai_client import (
AIAgentConsole,
AIAgentService,
LLMClient,
LLMConfig,
PromptLibrary,
_iter_sse_payloads,
)
class FakeLLM(LLMClient):
    """Scripted stand-in for ``LLMClient`` that returns a fixed reply.

    Every ``chat`` call is recorded in ``calls`` and answered with the
    ``reply`` given at construction. When a streaming callback is supplied,
    it receives one ``"reasoning"`` chunk followed by the whole reply as a
    single ``"content"`` chunk.
    """

    def __init__(self, reply: str) -> None:
        # Placeholder config for the base client. NOTE(review): presumably
        # unused because ``chat`` is overridden below -- confirm against
        # LLMClient.__init__.
        super().__init__(
            LLMConfig(
                base_url="http://example.test/v1",
                api_key="test",
                model="fake",
            )
        )
        self.reply = reply
        # One entry per ``chat`` call, in call order.
        self.calls: list[list[dict[str, Any]]] = []

    def chat(self, messages, on_delta=None):  # type: ignore[no-untyped-def]
        """Record *messages* and return the canned reply.

        Args:
            messages: The chat messages the service sent for this turn.
            on_delta: Optional ``(kind, text)`` callback. If given, it is
                called with ``("reasoning", "counting outs... ")`` and then
                ``("content", reply)``.

        Returns:
            The ``reply`` string passed to the constructor.
        """
        # Keep a shallow snapshot rather than the caller's list object. If
        # the caller later appends to or reuses the same list, earlier
        # entries in ``calls`` would otherwise change after the fact.
        self.calls.append(list(messages))
        if on_delta is not None:
            on_delta("reasoning", "counting outs... ")
            on_delta("content", self.reply)
        return self.reply
def prompt_library(path: Path) -> PromptLibrary:
    """Write minimal prompt templates into *path* and wrap them in a library.

    Three UTF-8 files are created: ``system.md``, ``game_start.md`` and
    ``observation.md``. Each template is a short keyword followed by
    ``{placeholder}`` fields, presumably filled in by ``PromptLibrary`` /
    ``AIAgentService`` when rendering (confirm against ai_client).

    Args:
        path: An existing, writable directory (e.g. a temporary directory).

    Returns:
        A ``PromptLibrary`` rooted at *path*.
    """
    # Insertion order is preserved, so files are written in the same order
    # as before: system, game_start, observation.
    templates = {
        "system.md": "system",
        "game_start.md": (
            "GAME {game_id} {hand_number} {status} {players_block} {history_block}"
        ),
        "observation.md": (
            "ACT {hand_number} {street} {player_id} {legal_actions_block}"
        ),
    }
    for file_name, body in templates.items():
        (path / file_name).write_text(body, encoding="utf-8")
    return PromptLibrary(path)
def game_state() -> dict[str, Any]:
    """Return a fresh game-state fixture: game ``g1``, hand 1, one AI seat.

    Blinds are 5/10, the button is on seat 0, and the single player ``"ai"``
    has a 100-chip stack and is still in the hand. A new dict (with a new
    ``players`` list) is built on every call, so callers may mutate it.
    """
    ai_seat: dict[str, Any] = dict(
        player_id="ai",
        name="AI",
        seat=0,
        stack=100,
        in_hand=True,
    )
    return dict(
        game_id="g1",
        status="running",
        hand_number=1,
        small_blind=5,
        big_blind=10,
        button_seat=0,
        starting_stack=100,
        players=[ai_seat],
        hands=[],
    )
def observation() -> dict[str, Any]:
    """Return the observation the AI receives when asked to act in hand 1.

    Player ``"ai"`` (seat 0, also the button) holds ``As Ah`` preflop with an
    empty board. The pot is 15, and the player has 10 to call with a minimum
    raise-to of 20. The legal actions are fold, call 10, or raise to a
    street total between 20 and 100. The ``players`` list comes from a fresh
    ``game_state()`` call, so both fixtures describe the same table.
    """
    legal_actions = [
        dict(action="fold", amount=0),
        dict(action="call", amount=10),
        dict(
            action="raise",
            min_amount=20,
            max_amount=100,
            amount_mode="street_total",
        ),
    ]
    return dict(
        game_id="g1",
        hand_number=1,
        street="preflop",
        player_id="ai",
        seat=0,
        button_seat=0,
        small_blind=5,
        big_blind=10,
        board=[],
        hole_cards=["As", "Ah"],
        players=game_state()["players"],
        pot=15,
        to_call=10,
        min_raise_to=20,
        legal_actions=legal_actions,
        action_history=[],
    )
class LineResponse:
    """Fake streaming HTTP response that iterates over pre-baked byte lines.

    Only iteration is provided. NOTE(review): this is presumably the only
    part of a response object that ``_iter_sse_payloads`` touches -- confirm
    against ai_client.
    """

    def __init__(self, lines: list[bytes]) -> None:
        # Raw lines, including their original terminators (e.g. ``\r\n``).
        self.lines = lines

    def __iter__(self):  # type: ignore[no-untyped-def]
        # Re-iterable: each iteration walks the stored list from the start.
        yield from self.lines
class AIClientTests(unittest.TestCase):
    """Tests for the SSE line parser and the agent service's console output."""

    def test_iter_sse_payloads_handles_done_and_crlf(self) -> None:
        # Mixed CRLF/LF terminators and blank separator lines. Expected:
        # blanks are dropped, the ``data: `` prefix and terminators are
        # stripped, and ``[DONE]`` is passed through as an ordinary payload.
        raw_lines = [
            b"data: {\"a\": 1}\r\n",
            b"\r\n",
            b"data: [DONE]\n",
            b"\n",
        ]
        payloads = list(_iter_sse_payloads(LineResponse(raw_lines)))
        self.assertEqual(payloads, ['{"a": 1}', "[DONE]"])

    def test_service_logs_game_act_stream_and_action(self) -> None:
        # Drive one game update and one action request through the service,
        # then check that the console log shows each stage.
        log = io.StringIO()
        with tempfile.TemporaryDirectory() as temp_dir:
            agent_console = AIAgentConsole(
                output_stream=log,
                keep_history=True,
                use_color=False,
            )
            service = AIAgentService(
                FakeLLM('{"action": "call", "amount": 10}'),
                prompt_library(Path(temp_dir)),
                console=agent_console,
            )
            service.handle_game(game_state(), player_id="ai")
            chosen = service.handle_act(observation())
        logged = log.getvalue()

        self.assertEqual(chosen, {"action": "call", "amount": 10})
        expected_fragments = (
            "GAME UPDATE",
            "Game g1 | Hand #1 | Street: preflop",
            "AI MODEL STREAM",
            "counting outs...",
            'AI ACTION (model) -> {"action": "call", "amount": 10}',
        )
        for fragment in expected_fragments:
            self.assertIn(fragment, logged)
if __name__ == "__main__":
    # Allow running this file directly as a script; unittest.main() collects
    # and runs the TestCase classes defined in this module.
    unittest.main()