feat: add ai agent http agent
This commit is contained in:
@@ -0,0 +1,150 @@
|
||||
import io
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from texas_holdem.ai_client import (
|
||||
AIAgentConsole,
|
||||
AIAgentService,
|
||||
LLMClient,
|
||||
LLMConfig,
|
||||
PromptLibrary,
|
||||
_iter_sse_payloads,
|
||||
)
|
||||
|
||||
|
||||
class FakeLLM(LLMClient):
|
||||
def __init__(self, reply: str) -> None:
|
||||
super().__init__(
|
||||
LLMConfig(
|
||||
base_url="http://example.test/v1",
|
||||
api_key="test",
|
||||
model="fake",
|
||||
)
|
||||
)
|
||||
self.reply = reply
|
||||
self.calls: list[list[dict[str, Any]]] = []
|
||||
|
||||
def chat(self, messages, on_delta=None): # type: ignore[no-untyped-def]
|
||||
self.calls.append(messages)
|
||||
if on_delta:
|
||||
on_delta("reasoning", "counting outs... ")
|
||||
on_delta("content", self.reply)
|
||||
return self.reply
|
||||
|
||||
|
||||
def prompt_library(path: Path) -> PromptLibrary:
|
||||
(path / "system.md").write_text("system", encoding="utf-8")
|
||||
(path / "game_start.md").write_text(
|
||||
"GAME {game_id} {hand_number} {status} {players_block} {history_block}",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(path / "observation.md").write_text(
|
||||
"ACT {hand_number} {street} {player_id} {legal_actions_block}",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return PromptLibrary(path)
|
||||
|
||||
|
||||
def game_state() -> dict[str, Any]:
|
||||
return {
|
||||
"game_id": "g1",
|
||||
"status": "running",
|
||||
"hand_number": 1,
|
||||
"small_blind": 5,
|
||||
"big_blind": 10,
|
||||
"button_seat": 0,
|
||||
"starting_stack": 100,
|
||||
"players": [
|
||||
{
|
||||
"player_id": "ai",
|
||||
"name": "AI",
|
||||
"seat": 0,
|
||||
"stack": 100,
|
||||
"in_hand": True,
|
||||
}
|
||||
],
|
||||
"hands": [],
|
||||
}
|
||||
|
||||
|
||||
def observation() -> dict[str, Any]:
|
||||
return {
|
||||
"game_id": "g1",
|
||||
"hand_number": 1,
|
||||
"street": "preflop",
|
||||
"player_id": "ai",
|
||||
"seat": 0,
|
||||
"button_seat": 0,
|
||||
"small_blind": 5,
|
||||
"big_blind": 10,
|
||||
"board": [],
|
||||
"hole_cards": ["As", "Ah"],
|
||||
"players": game_state()["players"],
|
||||
"pot": 15,
|
||||
"to_call": 10,
|
||||
"min_raise_to": 20,
|
||||
"legal_actions": [
|
||||
{"action": "fold", "amount": 0},
|
||||
{"action": "call", "amount": 10},
|
||||
{
|
||||
"action": "raise",
|
||||
"min_amount": 20,
|
||||
"max_amount": 100,
|
||||
"amount_mode": "street_total",
|
||||
},
|
||||
],
|
||||
"action_history": [],
|
||||
}
|
||||
|
||||
|
||||
class LineResponse:
|
||||
def __init__(self, lines: list[bytes]) -> None:
|
||||
self.lines = lines
|
||||
|
||||
def __iter__(self): # type: ignore[no-untyped-def]
|
||||
return iter(self.lines)
|
||||
|
||||
|
||||
class AIClientTests(unittest.TestCase):
|
||||
def test_iter_sse_payloads_handles_done_and_crlf(self) -> None:
|
||||
response = LineResponse(
|
||||
[
|
||||
b"data: {\"a\": 1}\r\n",
|
||||
b"\r\n",
|
||||
b"data: [DONE]\n",
|
||||
b"\n",
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(list(_iter_sse_payloads(response)), ['{"a": 1}', "[DONE]"])
|
||||
|
||||
def test_service_logs_game_act_stream_and_action(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
output = io.StringIO()
|
||||
console = AIAgentConsole(
|
||||
output_stream=output,
|
||||
keep_history=True,
|
||||
use_color=False,
|
||||
)
|
||||
service = AIAgentService(
|
||||
FakeLLM('{"action": "call", "amount": 10}'),
|
||||
prompt_library(Path(temp_dir)),
|
||||
console=console,
|
||||
)
|
||||
|
||||
service.handle_game(game_state(), player_id="ai")
|
||||
action = service.handle_act(observation())
|
||||
|
||||
text = output.getvalue()
|
||||
self.assertEqual(action, {"action": "call", "amount": 10})
|
||||
self.assertIn("GAME UPDATE", text)
|
||||
self.assertIn("Game g1 | Hand #1 | Street: preflop", text)
|
||||
self.assertIn("AI MODEL STREAM", text)
|
||||
self.assertIn("counting outs...", text)
|
||||
self.assertIn('AI ACTION (model) -> {"action": "call", "amount": 10}', text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user