151 lines
4.1 KiB
Python
151 lines
4.1 KiB
Python
import io
|
|
import tempfile
|
|
import unittest
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from texas_holdem.ai_client import (
|
|
AIAgentConsole,
|
|
AIAgentService,
|
|
LLMClient,
|
|
LLMConfig,
|
|
PromptLibrary,
|
|
_iter_sse_payloads,
|
|
)
|
|
|
|
|
|
class FakeLLM(LLMClient):
    """Deterministic stand-in for ``LLMClient`` used by the service tests.

    Every call to :meth:`chat` is recorded in ``calls`` and answered with the
    fixed ``reply`` string supplied at construction time.
    """

    def __init__(self, reply: str) -> None:
        # Dummy settings to satisfy the base-class constructor; ``chat`` is
        # overridden below, so no request is built from them here.
        super().__init__(
            LLMConfig(
                base_url="http://example.test/v1",
                api_key="test",
                model="fake",
            )
        )
        self.reply = reply
        # One entry per chat() call: a snapshot of the messages passed in.
        self.calls: list[list[dict[str, Any]]] = []

    def chat(self, messages, on_delta=None):  # type: ignore[no-untyped-def]
        """Record *messages* and return the canned reply.

        When ``on_delta`` is given it is invoked first with a ``"reasoning"``
        chunk and then with the whole reply as a single ``"content"`` chunk,
        imitating a streamed model response.
        """
        # Store a copy: if the caller reuses one conversation list and appends
        # to it later, a stored reference would silently rewrite the history
        # recorded for earlier calls.
        self.calls.append(list(messages))
        if on_delta is not None:
            on_delta("reasoning", "counting outs... ")
            on_delta("content", self.reply)
        return self.reply
|
|
|
|
|
|
def prompt_library(path: Path) -> PromptLibrary:
    """Populate *path* with minimal prompt templates and wrap it in a library.

    Each template is a short fixed prefix followed by the placeholders it is
    expected to receive, so rendered prompts stay small and predictable.
    """
    templates = {
        "system.md": "system",
        "game_start.md": (
            "GAME {game_id} {hand_number} {status} "
            "{players_block} {history_block}"
        ),
        "observation.md": (
            "ACT {hand_number} {street} {player_id} {legal_actions_block}"
        ),
    }
    # Files are written in the same order as the mapping above.
    for filename, template in templates.items():
        (path / filename).write_text(template, encoding="utf-8")
    return PromptLibrary(path)
|
|
|
|
|
|
def game_state() -> dict[str, Any]:
    """Return a fresh game-state snapshot: game ``g1``, hand 1, one AI player.

    A new dict (and new nested containers) is built on every call, so callers
    may mutate the result freely.
    """
    ai_player: dict[str, Any] = dict(
        player_id="ai",
        name="AI",
        seat=0,
        stack=100,
        in_hand=True,
    )
    return dict(
        game_id="g1",
        status="running",
        hand_number=1,
        small_blind=5,
        big_blind=10,
        button_seat=0,
        starting_stack=100,
        players=[ai_player],
        hands=[],
    )
|
|
|
|
|
|
def observation() -> dict[str, Any]:
    """Return the preflop decision point for player ``ai`` in game ``g1``.

    The player list is taken from :func:`game_state` so both fixtures describe
    the same table. ``legal_actions`` offers fold, a 10-chip call, or a raise
    bounded by ``min_amount``/``max_amount`` (``amount_mode="street_total"``).
    """
    legal_actions: list[dict[str, Any]] = [
        dict(action="fold", amount=0),
        dict(action="call", amount=10),
        dict(
            action="raise",
            min_amount=20,
            max_amount=100,
            amount_mode="street_total",
        ),
    ]
    return dict(
        game_id="g1",
        hand_number=1,
        street="preflop",
        player_id="ai",
        seat=0,
        button_seat=0,
        small_blind=5,
        big_blind=10,
        board=[],
        hole_cards=["As", "Ah"],
        players=game_state()["players"],
        pot=15,
        to_call=10,
        min_raise_to=20,
        legal_actions=legal_actions,
        action_history=[],
    )
|
|
|
|
|
|
class LineResponse:
    """Minimal stand-in for a streaming HTTP response.

    Iterating over an instance yields the raw ``bytes`` lines it was built
    with, in order. The SSE test passes this object to ``_iter_sse_payloads``.
    """

    def __init__(self, lines: list[bytes]) -> None:
        # Stored as given (no copy); every iteration reads from this list.
        self.lines: list[bytes] = lines

    def __iter__(self):  # type: ignore[no-untyped-def]
        # Each call hands out a new iterator, so the object can be consumed
        # more than once.
        return iter(self.lines)
|
|
|
|
|
|
class AIClientTests(unittest.TestCase):
    """Tests for SSE payload parsing and the console output of AIAgentService."""

    def test_iter_sse_payloads_handles_done_and_crlf(self) -> None:
        # CRLF and LF terminators are mixed on purpose, and a blank separator
        # line follows each event. The "[DONE]" sentinel is expected back as an
        # ordinary payload string.
        raw_lines = [
            b"data: {\"a\": 1}\r\n",
            b"\r\n",
            b"data: [DONE]\n",
            b"\n",
        ]
        payloads = list(_iter_sse_payloads(LineResponse(raw_lines)))
        self.assertEqual(payloads, ['{"a": 1}', "[DONE]"])

    def test_service_logs_game_act_stream_and_action(self) -> None:
        # Capture console output in memory with colours off so the plain-text
        # assertions below can match it.
        buffer = io.StringIO()
        console = AIAgentConsole(
            output_stream=buffer,
            keep_history=True,
            use_color=False,
        )
        with tempfile.TemporaryDirectory() as temp_dir:
            llm = FakeLLM('{"action": "call", "amount": 10}')
            service = AIAgentService(
                llm,
                prompt_library(Path(temp_dir)),
                console=console,
            )

            service.handle_game(game_state(), player_id="ai")
            action = service.handle_act(observation())

            logged = buffer.getvalue()
            self.assertEqual(action, {"action": "call", "amount": 10})
            # Checked in order: game header, act header, the fake model's
            # streamed reasoning chunk, and the final chosen action.
            for expected in (
                "GAME UPDATE",
                "Game g1 | Hand #1 | Street: preflop",
                "AI MODEL STREAM",
                "counting outs...",
                'AI ACTION (model) -> {"action": "call", "amount": 10}',
            ):
                self.assertIn(expected, logged)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|