From bc372c5ba1ddda857d101e6cb805ad887f3259df Mon Sep 17 00:00:00 2001 From: "qianrui.mmmy" Date: Mon, 11 May 2026 21:09:55 +0800 Subject: [PATCH] feat: add ai agent http agent --- README.md | 28 +- pyproject.toml | 1 + tests/test_ai_client.py | 150 +++++ texas_holdem/ai_client.py | 972 ++++++++++++++++++++++++++++ texas_holdem/engine.py | 9 +- texas_holdem/prompts/game_start.md | 22 + texas_holdem/prompts/observation.md | 36 ++ texas_holdem/prompts/system.md | 63 ++ 8 files changed, 1276 insertions(+), 5 deletions(-) create mode 100644 tests/test_ai_client.py create mode 100644 texas_holdem/ai_client.py create mode 100644 texas_holdem/prompts/game_start.md create mode 100644 texas_holdem/prompts/observation.md create mode 100644 texas_holdem/prompts/system.md diff --git a/README.md b/README.md index da0ce20..210fac5 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ - 观察信息包含玩家筹码、公共牌、当前玩家手牌、底池、历史动作、可用动作和跟注/加注边界。 - 支持盲注、四条街下注、弃牌、过牌、跟注、下注、加注、全下、边池和摊牌结算。 - 支持本地 Agent 和 HTTP Agent。 +- 支持 Human Agent 和 OpenAI-compatible AI Agent 的终端过程输出。 ## 运行服务 @@ -60,7 +61,7 @@ curl http://127.0.0.1:8000/games/demo "name": "LLM Agent", "agent": { "type": "http", - "endpoint": "http://127.0.0.1:9001/act", + "endpoint": "http://127.0.0.1:9101", "timeout_seconds": 10 } } @@ -83,8 +84,31 @@ curl http://127.0.0.1:8000/games/demo `bet` 和 `raise` 的 `amount` 表示当前下注轮中该玩家希望达到的总下注额,也就是观察中 `amount_mode: "street_total"` 的含义。 +## AI Agent + +启动一个可接入 OpenAI-compatible Chat Completions API 的 AI Agent: + +```bash +python -m texas_holdem.ai_client \ + --host 127.0.0.1 \ + --port 9101 \ + --base-url https://api.openai.com/v1 \ + --api-key "$OPENAI_API_KEY" \ + --model gpt-4o-mini \ + --keep-history +``` + +AI Agent 会在终端输出: + +- 收到的 `/game` 游戏快照; +- 收到的 `/act` 行动请求; +- 大模型流式返回内容,默认灰色显示; +- 最终解析出的 action,或失败时的 fallback action。 + +默认每次 `/act` 会清屏,和 Human Agent 一致;加 `--keep-history` 后保留历史滚动输出。可用 `--no-stream` 关闭流式请求,用 `--no-color` 关闭灰色 ANSI 输出。 + ## 测试 ```bash -python -m unittest discover -s tests -v +python -m unittest discover -v ``` diff --git a/pyproject.toml b/pyproject.toml index b74acae..73d3f87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [] [project.scripts] texas-holdem-server = "texas_holdem.server:main" texas-holdem-human = "texas_holdem.human_client:main" +texas-holdem-ai = "texas_holdem.ai_client:main" [tool.pytest.ini_options] testpaths = ["tests"] diff --git a/tests/test_ai_client.py b/tests/test_ai_client.py new file mode 100644 index 0000000..683dc95 --- /dev/null +++ b/tests/test_ai_client.py @@ -0,0 +1,150 @@ +import io +import tempfile +import unittest +from pathlib import Path +from typing import Any + +from texas_holdem.ai_client import ( + AIAgentConsole, + AIAgentService, + LLMClient, + LLMConfig, + PromptLibrary, + _iter_sse_payloads, +) + + +class FakeLLM(LLMClient): + def __init__(self, reply: str) -> None: + super().__init__( + LLMConfig( + base_url="http://example.test/v1", + api_key="test", + model="fake", + ) + ) + self.reply = reply + self.calls: list[list[dict[str, Any]]] = [] + + def chat(self, messages, on_delta=None): # type: ignore[no-untyped-def] + self.calls.append(messages) + if on_delta: + on_delta("reasoning", "counting outs... ") + on_delta("content", self.reply) + return self.reply + + +def prompt_library(path: Path) -> PromptLibrary: + (path / "system.md").write_text("system", encoding="utf-8") + (path / "game_start.md").write_text( + "GAME {game_id} {hand_number} {status} {players_block} {history_block}", + encoding="utf-8", + ) + (path / "observation.md").write_text( + "ACT {hand_number} {street} {player_id} {legal_actions_block}", + encoding="utf-8", + ) + return PromptLibrary(path) + + +def game_state() -> dict[str, Any]: + return { + "game_id": "g1", + "status": "running", + "hand_number": 1, + "small_blind": 5, + "big_blind": 10, + "button_seat": 0, + "starting_stack": 100, + "players": [ + { + "player_id": "ai", + "name": "AI", + "seat": 0, + "stack": 100, + "in_hand": True, + } + ], + "hands": [], + } + + +def observation() -> dict[str, Any]: + return { + "game_id": "g1", + "hand_number": 1, + "street": "preflop", + "player_id": "ai", + "seat": 0, + "button_seat": 0, + "small_blind": 5, + "big_blind": 10, + "board": [], + "hole_cards": ["As", "Ah"], + "players": game_state()["players"], + "pot": 15, + "to_call": 10, + "min_raise_to": 20, + "legal_actions": [ + {"action": "fold", "amount": 0}, + {"action": "call", "amount": 10}, + { + "action": "raise", + "min_amount": 20, + "max_amount": 100, + "amount_mode": "street_total", + }, + ], + "action_history": [], + } + + +class LineResponse: + def __init__(self, lines: list[bytes]) -> None: + self.lines = lines + + def __iter__(self): # type: ignore[no-untyped-def] + return iter(self.lines) + + +class AIClientTests(unittest.TestCase): + def test_iter_sse_payloads_handles_done_and_crlf(self) -> None: + response = LineResponse( + [ + b"data: {\"a\": 1}\r\n", + b"\r\n", + b"data: [DONE]\n", + b"\n", + ] + ) + + self.assertEqual(list(_iter_sse_payloads(response)), ['{"a": 1}', "[DONE]"]) + + def test_service_logs_game_act_stream_and_action(self) -> None: + with tempfile.TemporaryDirectory() as temp_dir: + output = io.StringIO() + console = AIAgentConsole( + output_stream=output, + keep_history=True, + use_color=False, + ) + service = AIAgentService( + FakeLLM('{"action": "call", "amount": 10}'), + prompt_library(Path(temp_dir)), + console=console, + ) + + service.handle_game(game_state(), player_id="ai") + action = service.handle_act(observation()) + + text = output.getvalue() + self.assertEqual(action, {"action": "call", "amount": 10}) + self.assertIn("GAME UPDATE", text) + self.assertIn("Game g1 | Hand #1 | Street: preflop", text) + self.assertIn("AI MODEL STREAM", text) + self.assertIn("counting outs...", text) + self.assertIn('AI ACTION (model) -> {"action": "call", "amount": 10}', text) + + +if __name__ == "__main__": + unittest.main() diff --git a/texas_holdem/ai_client.py b/texas_holdem/ai_client.py new file mode 100644 index 0000000..08b4ac8 --- /dev/null +++ b/texas_holdem/ai_client.py @@ -0,0 +1,972 @@ +"""Standalone HTTP AI poker agent backed by an OpenAI-compatible LLM. + +Run as a process exposing two endpoints that the Texas Hold'em service +calls: + +* ``POST /game`` - delivered at the start of every new hand. We use it as + the boundary that opens a fresh chat session for that hand and seeds it + with a human-readable rendering of the table snapshot (history of past + hands included). +* ``POST /act`` - the per-decision request. We render the observation + with a templated user prompt, ask the configured LLM, parse the JSON + reply and return it to the server. + +Run:: + + python -m texas_holdem.ai_client \\ + --host 127.0.0.1 --port 9101 \\ + --base-url https://api.openai.com/v1 \\ + --api-key $OPENAI_API_KEY \\ + --model gpt-4o-mini + +Hook it up by passing the *base* URL when creating the game:: + + {"id": "ai", "name": "AI", + "agent": {"type": "http", "endpoint": "http://127.0.0.1:9101"}} + +Design notes: +- Prompts live in ``texas_holdem/prompts/*.md`` so non-engineers can edit + them without touching code; we read them once at boot via + :class:`PromptLibrary`. +- :class:`LLMSession` owns the chat history for a single (game_id, hand) + scope and is reset by every ``/game`` push, matching the user's mental + model that one hand == one session. +- :class:`SessionRegistry` keeps a tiny LRU per ``(game_id, player_id)`` + so the same process can serve multiple parallel games / seats. +- The LLM client speaks OpenAI's ``/v1/chat/completions`` schema, which + is what virtually every OpenAI-compatible provider implements. +""" + +from __future__ import annotations + +import argparse +import json +import os +import re +import sys +import threading +from collections import OrderedDict +from contextlib import contextmanager +from dataclasses import dataclass, field +from http import HTTPStatus +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path +from typing import IO, Any, Callable, Iterator +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +from texas_holdem.human_io import clear_screen, render_game_state, render_observation + +# Default location of the prompt templates. Living next to this module lets +# operators edit them in place without re-installing the package. +PROMPTS_DIR = Path(__file__).resolve().parent / "prompts" +ANSI_GRAY = "\x1b[90m" +ANSI_RESET = "\x1b[0m" + + +# --------------------------------------------------------------------------- +# Terminal diagnostics +# --------------------------------------------------------------------------- + + +class AIAgentConsole: + """Serialised terminal output for the standalone AI agent. + + The behaviour mirrors :class:`HumanClientConsole`: by default each + decision clears the terminal first, while ``keep_history=True`` leaves + previous game/act logs in scrollback. The LLM stream is printed in gray + so model output stays visually separate from game state and final + actions. + """ + + def __init__( + self, + output_stream: IO[str] | None = None, + keep_history: bool = False, + use_color: bool = True, + ) -> None: + self._output = output_stream if output_stream is not None else sys.stdout + self._keep_history = keep_history + self._use_color = use_color + self._lock = threading.Lock() + + @contextmanager + def act_log(self, observation: dict[str, Any]) -> Iterator[None]: + """Render one received ``/act`` payload and hold the console lock.""" + with self._lock: + if not self._keep_history: + clear_screen(self._write) + self._write(render_observation(observation)) + yield + + def announce_game(self, game_state: dict[str, Any]) -> None: + """Render one received ``/game`` payload.""" + with self._lock: + self._write(render_game_state(game_state)) + + def begin_llm_stream(self) -> None: + self._write(self._gray("AI MODEL STREAM\n")) + + def write_llm_delta(self, kind: str, text: str) -> None: + if not text: + return + self._write(self._gray(text)) + + def end_llm_stream(self) -> None: + self._write(self._gray("\n")) + + def announce_action( + self, + action: dict[str, Any], + source: str = "model", + ) -> None: + body = json.dumps(action, ensure_ascii=False) + self._write(f"\nAI ACTION ({source}) -> {body}\n") + self._write("~" * 60 + "\n\n") + + def announce_warning(self, message: str) -> None: + self._write(f"\nAI WARNING -> {message}\n") + + def _gray(self, text: str) -> str: + if not self._use_color: + return text + return f"{ANSI_GRAY}{text}{ANSI_RESET}" + + def _write(self, text: str) -> None: + self._output.write(text) + self._output.flush() + + +# --------------------------------------------------------------------------- +# Prompt rendering +# --------------------------------------------------------------------------- + + +class PromptLibrary: + """Loads and renders prompt templates from disk. + + Templates use Python ``str.format`` placeholders. Centralising the + render step here keeps the LLM-facing wording out of the code base and + avoids ad-hoc string concatenation in the agent logic. + """ + + def __init__(self, directory: Path = PROMPTS_DIR) -> None: + self.directory = directory + # Cache prompts by name; they are a couple of KB and read-only at + # runtime, so we trade a few bytes of memory for zero disk IO per + # request. + self._cache: dict[str, str] = {} + + def load(self, name: str) -> str: + if name not in self._cache: + path = self.directory / f"{name}.md" + if not path.exists(): + raise FileNotFoundError(f"missing prompt template: {path}") + self._cache[name] = path.read_text(encoding="utf-8") + return self._cache[name] + + def render(self, name: str, **fields: Any) -> str: + """Return the named template with ``fields`` substituted via format().""" + template = self.load(name) + return template.format(**fields) + + +def render_game_start_prompt(library: PromptLibrary, game_state: dict[str, Any]) -> str: + """Build the opening user-message describing a new hand. + + Pulls out the heavy formatting (player rows, hand history) into helpers + so the template itself stays declarative. + """ + return library.render( + "game_start", + game_id=game_state.get("game_id"), + hand_number=game_state.get("hand_number"), + status=game_state.get("status"), + small_blind=game_state.get("small_blind"), + big_blind=game_state.get("big_blind"), + button_seat=game_state.get("button_seat"), + starting_stack=game_state.get("starting_stack"), + players_block=_format_players_block(game_state.get("players") or []), + hand_count=len(game_state.get("hands") or []), + history_block=_format_history_block(game_state.get("hands") or []), + ) + + +def render_observation_prompt( + library: PromptLibrary, observation: dict[str, Any] +) -> str: + """Build the per-decision user-message from a server observation dict.""" + legal_actions = list(observation.get("legal_actions") or []) + you = _find_self_player(observation) + return library.render( + "observation", + hand_number=observation.get("hand_number"), + street=observation.get("street"), + player_id=observation.get("player_id"), + player_name=you.get("name") if you else observation.get("player_id"), + seat=observation.get("seat"), + button_seat=observation.get("button_seat"), + pot=observation.get("pot"), + to_call=observation.get("to_call"), + min_raise_to=observation.get("min_raise_to"), + amount_mode=observation.get("amount_mode") or "street_total", + hole_cards=_format_card_list(observation.get("hole_cards")), + board=_format_card_list(observation.get("board")), + players_block=_format_players_block(observation.get("players") or []), + action_history_block=_format_action_history( + observation.get("action_history") or [] + ), + legal_actions_block=_format_legal_actions(legal_actions), + ) + + +def _find_self_player(observation: dict[str, Any]) -> dict[str, Any] | None: + """Locate the acting player's row inside the observation snapshot.""" + pid = observation.get("player_id") + for player in observation.get("players") or []: + if player.get("player_id") == pid: + return player + return None + + +def _format_card_list(cards: Any) -> str: + """Render a list of card labels for prompts; never returns an empty string.""" + if not cards: + return "(none)" + return " ".join(str(card) for card in cards) + + +def _format_players_block(players: list[dict[str, Any]]) -> str: + """Render the per-seat status table used in both prompt templates.""" + if not players: + return "(no players)" + rows: list[str] = [] + for player in players: + flags = [] + if player.get("folded"): + flags.append("folded") + if player.get("all_in"): + flags.append("all_in") + if not player.get("in_hand"): + flags.append("out") + flag_text = ",".join(flags) if flags else "active" + rows.append( + f"- seat {player.get('seat')}: id={player.get('player_id')}, " + f"name={player.get('name')}, stack={player.get('stack')}, " + f"street_bet={player.get('street_bet', 0)}, " + f"total_bet={player.get('total_bet', 0)}, status={flag_text}" + ) + return "\n".join(rows) + + +def _format_history_block(hands: list[dict[str, Any]]) -> str: + """Render a compact digest of finished hands for the GAME_START message.""" + if not hands: + return "(no hands played yet)" + digests: list[str] = [] + for hand in hands: + awards = hand.get("awards") or [] + winner_lines = [ + f" pot {a.get('amount')}: -> {','.join(a.get('winners') or [])} " + f"({(a.get('hand_value') or {}).get('name', '-')})" + for a in awards + ] + showdown = hand.get("showdown_hands") or {} + showdown_lines = [ + f" showdown {pid}: {' '.join(cards)}" + for pid, cards in showdown.items() + ] + digests.append( + "\n".join( + [ + f"- Hand #{hand.get('hand_number')} " + f"(button seat {hand.get('button_seat')}), " + f"board: {_format_card_list(hand.get('board'))}", + *winner_lines, + *showdown_lines, + ] + ) + ) + return "\n".join(digests) + + +def _format_action_history(history: list[dict[str, Any]]) -> str: + """Render the per-action log; trims very old entries to keep prompts cheap.""" + if not history: + return "(no actions yet)" + # The engine never produces unbounded history within a single hand, but + # we cap defensively so a malformed payload cannot blow up token usage. + rows = [] + for record in history[-32:]: + rows.append( + f"- [{record.get('street')}] {record.get('player_id')} -> " + f"{record.get('action')} amount={record.get('amount', 0)}" + ) + return "\n".join(rows) + + +def _format_legal_actions(legal: list[dict[str, Any]]) -> str: + """Render a numbered list of legal actions including amount ranges.""" + if not legal: + return "(no legal actions)" + rows: list[str] = [] + for index, action in enumerate(legal, start=1): + amount = action.get("amount") + if action.get("action") in {"bet", "raise"}: + rows.append( + f"{index}. {action['action']}: street_total in " + f"[{action.get('min_amount')}, {action.get('max_amount')}]" + ) + else: + rows.append(f"{index}. {action['action']} (amount={amount})") + return "\n".join(rows) + + +# --------------------------------------------------------------------------- +# OpenAI-compatible chat client +# --------------------------------------------------------------------------- + + +@dataclass +class LLMConfig: + """Static configuration for the OpenAI-compatible provider. + + A small dataclass keeps the constructor wiring readable and lets us add + fields (e.g. organisation id) without touching every call site. + """ + + base_url: str + api_key: str + model: str + timeout_seconds: float = 60.0 + temperature: float = 0.4 + stream: bool = True + + def chat_completions_url(self) -> str: + """Return the canonical chat completions URL for the configured base.""" + base = self.base_url.rstrip("/") + # Tolerate users passing the full ``/chat/completions`` path; only + # the OpenAI-style base URL is documented but mistakes are common. + if base.endswith("/chat/completions"): + return base + return f"{base}/chat/completions" + + +class LLMClient: + """Thin wrapper around the OpenAI-compatible Chat Completions API. + + Implemented with ``urllib`` to honour the project's "no third-party + dependency" constraint; swapping in ``httpx`` later would only touch + this class. + """ + + def __init__(self, config: LLMConfig) -> None: + self.config = config + + def chat( + self, + messages: list[dict[str, Any]], + on_delta: Callable[[str, str], None] | None = None, + ) -> str: + """Send a chat completion request and return the assistant text.""" + if self.config.stream: + return self._chat_stream(messages, on_delta) + return self._chat_once(messages, on_delta) + + def _request(self, body: dict[str, Any]) -> Request: + body = json.dumps( + body + ).encode("utf-8") + return Request( + self.config.chat_completions_url(), + data=body, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {self.config.api_key}", + }, + method="POST", + ) + + def _chat_once( + self, + messages: list[dict[str, Any]], + on_delta: Callable[[str, str], None] | None = None, + ) -> str: + request = self._request( + { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "stream": False, + } + ) + try: + with urlopen(request, timeout=self.config.timeout_seconds) as resp: + payload = json.loads(resp.read().decode("utf-8")) + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise RuntimeError( + f"LLM HTTP {exc.code} from {self.config.chat_completions_url()}: {detail}" + ) from exc + except (OSError, URLError) as exc: + raise RuntimeError( + f"LLM request failed: {self.config.chat_completions_url()}" + ) from exc + try: + message = payload["choices"][0]["message"] + reasoning = _reasoning_text(message) + content = _message_text(message.get("content")) + if on_delta and reasoning: + on_delta("reasoning", reasoning) + if on_delta and content: + on_delta("content", content) + return content + except (KeyError, IndexError, TypeError) as exc: + raise RuntimeError(f"LLM returned unexpected payload: {payload}") from exc + + def _chat_stream( + self, + messages: list[dict[str, Any]], + on_delta: Callable[[str, str], None] | None = None, + ) -> str: + request = self._request( + { + "model": self.config.model, + "messages": messages, + "temperature": self.config.temperature, + "stream": True, + } + ) + parts: list[str] = [] + try: + with urlopen(request, timeout=self.config.timeout_seconds) as resp: + for event in _iter_sse_payloads(resp): + if event == "[DONE]": + break + try: + payload = json.loads(event) + except json.JSONDecodeError as exc: + raise RuntimeError( + f"LLM returned invalid stream event: {event!r}" + ) from exc + delta = _stream_delta(payload) + reasoning = _reasoning_text(delta) + content = _message_text(delta.get("content")) + if reasoning and on_delta: + on_delta("reasoning", reasoning) + if content: + parts.append(content) + if on_delta: + on_delta("content", content) + except HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise RuntimeError( + f"LLM HTTP {exc.code} from {self.config.chat_completions_url()}: {detail}" + ) from exc + except (OSError, URLError) as exc: + raise RuntimeError( + f"LLM request failed: {self.config.chat_completions_url()}" + ) from exc + return "".join(parts) + + +def _message_text(value: Any) -> str: + if value is None: + return "" + if isinstance(value, str): + return value + return str(value) + + +def _reasoning_text(message_or_delta: dict[str, Any]) -> str: + for key in ("reasoning_content", "reasoning", "reasoning_text"): + text = _message_text(message_or_delta.get(key)) + if text: + return text + return "" + + +def _stream_delta(payload: dict[str, Any]) -> dict[str, Any]: + try: + delta = payload["choices"][0].get("delta") or {} + except (KeyError, IndexError, TypeError) as exc: + raise RuntimeError(f"LLM returned unexpected stream payload: {payload}") from exc + if not isinstance(delta, dict): + raise RuntimeError(f"LLM returned unexpected stream delta: {payload}") + return delta + + +def _iter_sse_payloads(response: Any) -> Iterator[str]: + """Yield ``data:`` payloads from an OpenAI-compatible SSE response.""" + data_lines: list[str] = [] + for raw in response: + line = raw.decode("utf-8", errors="replace").rstrip("\r\n") + if line == "": + if data_lines: + yield "\n".join(data_lines) + data_lines = [] + continue + if line.startswith("data:"): + data_lines.append(line[5:].lstrip()) + if data_lines: + yield "\n".join(data_lines) + + +# --------------------------------------------------------------------------- +# Session lifecycle +# --------------------------------------------------------------------------- + + +@dataclass +class LLMSession: + """Chat session bound to a single hand. + + The system prompt is fixed for the whole match while the user/assistant + exchange is reset every time a new hand begins (i.e. when ``/game`` is + received). Storing recent assistant turns lets the model maintain + intra-hand continuity without re-paying for the long table snapshot. + """ + + system_prompt: str + game_id: str + player_id: str + messages: list[dict[str, Any]] = field(default_factory=list) + hand_number: int = 0 + + def reset_with_game(self, hand_number: int, game_user_message: str) -> None: + """Start a fresh exchange for a new hand.""" + self.hand_number = hand_number + self.messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": game_user_message}, + ] + + def append_user(self, content: str) -> None: + self.messages.append({"role": "user", "content": content}) + + def append_assistant(self, content: str) -> None: + self.messages.append({"role": "assistant", "content": content}) + + def chat_messages(self) -> list[dict[str, Any]]: + # Always include the system prompt; if reset_with_game has not been + # called yet (e.g. /act arrives before /game), we still want a + # legal request to go through. + if not self.messages: + return [{"role": "system", "content": self.system_prompt}] + return list(self.messages) + + +class SessionRegistry: + """Tiny LRU-style registry keyed by ``(game_id, player_id)``. + + Multiple parallel games or multiple seats served by the same process + each need an isolated chat history; the registry provides exactly that + while bounding memory. + """ + + def __init__(self, system_prompt: str, max_sessions: int = 64) -> None: + self._system_prompt = system_prompt + self._sessions: OrderedDict[tuple[str, str], LLMSession] = OrderedDict() + self._max = max_sessions + self._lock = threading.Lock() + + def get_or_create(self, game_id: str, player_id: str) -> LLMSession: + with self._lock: + key = (game_id, player_id) + session = self._sessions.get(key) + if session is None: + session = LLMSession( + system_prompt=self._system_prompt, + game_id=game_id, + player_id=player_id, + ) + self._sessions[key] = session + # Drop the oldest if we exceed the cap; LLM context is + # expensive but we never need stale game histories. + while len(self._sessions) > self._max: + self._sessions.popitem(last=False) + else: + self._sessions.move_to_end(key) + return session + + +# --------------------------------------------------------------------------- +# Action parsing +# --------------------------------------------------------------------------- + + +_JSON_OBJECT_RE = re.compile(r"\{[\s\S]*\}") + + +def parse_action_reply(reply: str) -> dict[str, Any]: + """Extract the action JSON from a possibly chatty LLM response. + + LLMs occasionally wrap JSON in markdown fences or add a sentence of + chatter despite explicit instructions. We pluck the first ``{...}`` + block and parse it; downstream code (engine ``_coerce_action``) will + sanitise illegal values, so we do not need to validate ranges here. + """ + if not isinstance(reply, str) or not reply.strip(): + raise ValueError("empty LLM reply") + match = _JSON_OBJECT_RE.search(reply) + if match is None: + raise ValueError(f"no JSON object found in LLM reply: {reply!r}") + try: + payload = json.loads(match.group(0)) + except json.JSONDecodeError as exc: + raise ValueError(f"invalid JSON in LLM reply: {reply!r}") from exc + if not isinstance(payload, dict): + raise ValueError(f"LLM reply was not a JSON object: {reply!r}") + return { + "action": str(payload.get("action") or "fold"), + "amount": int(payload.get("amount") or 0), + } + + +def fallback_action(observation: dict[str, Any]) -> dict[str, Any]: + """Pick a safe legal action when the LLM call fails entirely. + + Order of preference: ``check`` > ``call`` (cheapest) > ``fold``. + """ + legal = observation.get("legal_actions") or [] + by_name = {item.get("action"): item for item in legal} + if "check" in by_name: + return {"action": "check", "amount": 0} + if "call" in by_name: + call = by_name["call"] + return {"action": "call", "amount": int(call.get("amount") or 0)} + if "fold" in by_name: + return {"action": "fold", "amount": 0} + # Last resort: echo the first legal action as-is. + if legal: + first = legal[0] + return { + "action": str(first.get("action")), + "amount": int(first.get("amount") or 0), + } + return {"action": "fold", "amount": 0} + + +# --------------------------------------------------------------------------- +# HTTP service +# --------------------------------------------------------------------------- + + +class AIAgentService: + """Glues the LLM client, prompt library and session registry together. + + Exposed as a single object so that the HTTP handler stays thin and + purely concerned with request parsing and response framing. + """ + + def __init__( + self, + llm: LLMClient, + prompts: PromptLibrary, + registry: SessionRegistry | None = None, + console: AIAgentConsole | None = None, + ) -> None: + self.llm = llm + self.prompts = prompts + self.console = console + # The system prompt is read once and shared across all sessions to + # avoid stale copies if the operator hot-edits the markdown file + # mid-game (intentional: restart the agent to pick up changes). + system_prompt = self.prompts.load("system") + self.registry = registry or SessionRegistry(system_prompt=system_prompt) + + def handle_game(self, game_state: dict[str, Any], player_id: str) -> None: + """Open or refresh the per-hand session.""" + if self.console: + self.console.announce_game(game_state) + game_id = str(game_state.get("game_id") or "") + hand_number = int(game_state.get("hand_number") or 0) + session = self.registry.get_or_create(game_id, player_id) + session.reset_with_game( + hand_number=hand_number, + game_user_message=render_game_start_prompt(self.prompts, game_state), + ) + + def handle_act(self, observation: dict[str, Any]) -> dict[str, Any]: + """Render the prompt, call the LLM, parse the reply.""" + if self.console: + with self.console.act_log(observation): + return self._handle_act_locked(observation) + return self._handle_act_locked(observation) + + def _handle_act_locked(self, observation: dict[str, Any]) -> dict[str, Any]: + game_id = str(observation.get("game_id") or "") + player_id = str(observation.get("player_id") or "") + session = self.registry.get_or_create(game_id, player_id) + # Ensure we never silently drop the system prompt even if /game + # arrives after /act (e.g. process started mid-hand on a restart). + if not session.messages: + session.messages = [ + {"role": "system", "content": session.system_prompt} + ] + user_msg = render_observation_prompt(self.prompts, observation) + session.append_user(user_msg) + + try: + if self.console: + self.console.begin_llm_stream() + assistant_text = self.llm.chat( + session.chat_messages(), + on_delta=self.console.write_llm_delta if self.console else None, + ) + if self.console: + self.console.end_llm_stream() + except RuntimeError as exc: + # On any LLM error, use the safe fallback. We also drop the + # last user message to avoid contaminating the next turn with + # a request that produced no assistant reply. + session.messages.pop() + action = fallback_action(observation) + if self.console: + self.console.announce_warning(str(exc)) + self.console.announce_action(action, source="fallback") + return action + + session.append_assistant(assistant_text) + try: + action = parse_action_reply(assistant_text) + if self.console: + self.console.announce_action(action, source="model") + return action + except ValueError as exc: + # Reply was unparseable; keep it in the history so the LLM can + # see what it did wrong on the next turn (no extra prompt + # needed - the next observation will fill that role) and + # answer with a safe action this turn. + action = fallback_action(observation) + if self.console: + self.console.announce_warning(str(exc)) + self.console.announce_action(action, source="fallback") + return action + + +def _bind_player_id(handler: BaseHTTPRequestHandler) -> str: + """Resolve which seat a request belongs to. + + The standalone process serves *one* AI seat by default. Multi-seat + deployments can pass ``X-Player-Id`` so the registry can keep the + sessions isolated. + """ + explicit = handler.headers.get("X-Player-Id") + if explicit: + return explicit + return getattr(handler.server, "default_player_id", "ai") # type: ignore[attr-defined] + + +class AIRequestHandler(BaseHTTPRequestHandler): + """HTTP entry point for the AI agent. + + Routes: + - ``GET /health`` - liveness probe. + - ``POST /game`` - new hand boundary; opens a fresh session. + - ``POST /act`` - returns the AI-decided action. + """ + + server_version = "TexasHoldemAIAgent/0.1" + service: AIAgentService # injected by ``create_server`` + + def do_GET(self) -> None: + if self.path == "/health": + self._json({"ok": True}) + return + self._json({"error": "not found"}, HTTPStatus.NOT_FOUND) + + def do_POST(self) -> None: + routes = { + "/game": self._handle_game, + "/act": self._handle_act, + } + handler = routes.get(self.path) + if handler is None: + self._json({"error": "not found"}, HTTPStatus.NOT_FOUND) + return + + try: + payload = self._read_json() + except ValueError as exc: + self._json({"error": str(exc)}, HTTPStatus.BAD_REQUEST) + return + + try: + handler(payload) + except Exception as exc: # pragma: no cover - last-resort guard + self._json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR) + + def _handle_game(self, payload: dict[str, Any]) -> None: + self.service.handle_game(payload, player_id=_bind_player_id(self)) + self._empty(HTTPStatus.NO_CONTENT) + + def _handle_act(self, payload: dict[str, Any]) -> None: + action = self.service.handle_act(payload) + self._json(action) + + def log_message(self, format: str, *args: Any) -> None: # noqa: A002 + return + + def _read_json(self) -> dict[str, Any]: + length = int(self.headers.get("Content-Length", "0")) + if length <= 0: + raise ValueError("request body is required") + try: + payload = json.loads(self.rfile.read(length).decode("utf-8")) + except json.JSONDecodeError as exc: + raise ValueError("request body must be valid JSON") from exc + if not isinstance(payload, dict): + raise ValueError("request body must be a JSON object") + return payload + + def _json( + self, payload: dict[str, Any], status: HTTPStatus = HTTPStatus.OK + ) -> None: + body = json.dumps(payload, ensure_ascii=True).encode("utf-8") + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def _empty(self, status: HTTPStatus) -> None: + self.send_response(status) + self.send_header("Content-Length", "0") + self.end_headers() + + +def create_server( + host: str, + port: int, + service: AIAgentService, + default_player_id: str = "ai", +) -> ThreadingHTTPServer: + """Build and configure the HTTP server. + + The ``service`` and ``default_player_id`` are attached to the server / + handler classes so that all worker threads share a single instance, + which in turn means a single registry of sessions. + """ + server = ThreadingHTTPServer((host, port), AIRequestHandler) + AIRequestHandler.service = service + server.default_player_id = default_player_id # type: ignore[attr-defined] + return server + + +# --------------------------------------------------------------------------- +# CLI entry point +# --------------------------------------------------------------------------- + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Run an OpenAI-compatible AI poker agent that exposes " + "POST /act and POST /game." + ) + ) + parser.add_argument("--host", default="127.0.0.1") + parser.add_argument("--port", default=9101, type=int) + parser.add_argument( + "--base-url", + default=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"), + help="OpenAI-compatible base URL (default: $OPENAI_BASE_URL or " + "https://api.openai.com/v1).", + ) + parser.add_argument( + "--api-key", + default=os.environ.get("OPENAI_API_KEY", ""), + help="API key (default: $OPENAI_API_KEY).", + ) + parser.add_argument( + "--model", + default=os.environ.get("OPENAI_MODEL", "gpt-4o-mini"), + help="Model identifier (default: $OPENAI_MODEL or gpt-4o-mini).", + ) + parser.add_argument( + "--timeout", + default=60.0, + type=float, + help="LLM request timeout in seconds.", + ) + parser.add_argument( + "--temperature", + default=0.4, + type=float, + help="LLM sampling temperature.", + ) + parser.add_argument( + "--player-id", + default="ai", + help=( + "Default player_id used to key sessions. Override per-request " + "via the X-Player-Id header for multi-seat setups." + ), + ) + parser.add_argument( + "--prompts-dir", + default=str(PROMPTS_DIR), + help="Directory containing the prompt markdown templates.", + ) + parser.add_argument( + "--keep-history", + action="store_true", + help=( + "Keep previous terminal output when a new /act request arrives " + "instead of clearing the screen." + ), + ) + parser.add_argument( + "--no-stream", + action="store_true", + help="Disable streaming Chat Completions requests.", + ) + parser.add_argument( + "--no-color", + action="store_true", + help="Disable ANSI gray coloring for streamed LLM output.", + ) + args = parser.parse_args() + + if not args.api_key: + parser.error("--api-key (or OPENAI_API_KEY) is required") + + config = LLMConfig( + base_url=args.base_url, + api_key=args.api_key, + model=args.model, + timeout_seconds=args.timeout, + temperature=args.temperature, + stream=not args.no_stream, + ) + prompts = PromptLibrary(directory=Path(args.prompts_dir)) + console = AIAgentConsole( + keep_history=args.keep_history, + use_color=not args.no_color, + ) + service = AIAgentService(LLMClient(config), prompts, console=console) + server = create_server(args.host, args.port, service, default_player_id=args.player_id) + + print( + f"AI HTTP agent listening on http://{args.host}:{args.port}\n" + f" POST /act - decision request\n" + f" POST /game - new-hand snapshot (opens a fresh session)\n" + f" model : {config.model}\n" + f" base_url : {config.base_url}\n" + f" player_id : {args.player_id}\n" + f" stream : {'on' if config.stream else 'off'}\n" + f" clear-screen: {'off (keep history)' if args.keep_history else 'on'}", + file=sys.stderr, + flush=True, + ) + try: + server.serve_forever() + except KeyboardInterrupt: + pass + finally: + server.server_close() + + +if __name__ == "__main__": + main() diff --git a/texas_holdem/engine.py b/texas_holdem/engine.py index 16eb7ef..0a409f9 100644 --- a/texas_holdem/engine.py +++ b/texas_holdem/engine.py @@ -79,6 +79,12 @@ class TableGame: self._advance_button() assert self.button_index is not None + # Notify every agent that a new hand is starting. Pushing here (as + # opposed to after ``_award_pots``) lets HTTP agents seed a fresh + # session with the latest table state and per-hand history before + # any decision is asked of them. + self._broadcast_game_update() + self._deal_hole_cards(deck) small_blind_index, big_blind_index = self._blind_indexes() self._post_blind(small_blind_index, "small_blind", self.small_blind) @@ -115,9 +121,6 @@ class TableGame: finished_at=time(), ) self.hand_summaries.append(summary) - # Notify every agent so HTTP-backed clients can render the just - # finished hand. Failures here must never abort the table. - self._broadcast_game_update() return summary def run_hands(self, max_hands: int, until_one_left: bool = False) -> list[HandSummary]: diff --git a/texas_holdem/prompts/game_start.md b/texas_holdem/prompts/game_start.md new file mode 100644 index 0000000..7b1903c --- /dev/null +++ b/texas_holdem/prompts/game_start.md @@ -0,0 +1,22 @@ +# GAME_START + +A new hand of Texas Hold'em is about to begin. Use this snapshot as the +fresh context for every decision in the upcoming hand. Hole cards and +betting state from prior hands are NOT carried over. + +## Table + +- Game id: {game_id} +- Hand number: {hand_number} +- Status: {status} +- Blinds: small={small_blind}, big={big_blind} +- Button seat: {button_seat} +- Starting stack: {starting_stack} + +## Players (current stacks) + +{players_block} + +## Hands played so far ({hand_count}) + +{history_block} diff --git a/texas_holdem/prompts/observation.md b/texas_holdem/prompts/observation.md new file mode 100644 index 0000000..341cd3f --- /dev/null +++ b/texas_holdem/prompts/observation.md @@ -0,0 +1,36 @@ +# OBSERVATION (your turn to act) + +It is your turn. Read the state below and respond with a single JSON +object: `{{"action": "", "amount": }}`. Pick only from the +listed legal actions. + +## Hand state + +- Hand number: {hand_number} +- Street: {street} +- You are: player_id={player_id}, name={player_name}, seat={seat} +- Button seat: {button_seat} +- Pot size: {pot} +- To call: {to_call} +- Min raise to: {min_raise_to} +- Amount semantics for bet/raise: {amount_mode} (the integer is the + target total street bet, NOT the delta on top of your current bet) + +## Cards + +- Your hole cards: {hole_cards} +- Community board: {board} + +## Players at the table + +{players_block} + +## Action history (this hand) + +{action_history_block} + +## Legal actions + +{legal_actions_block} + +Respond NOW with one JSON line and nothing else. diff --git a/texas_holdem/prompts/system.md b/texas_holdem/prompts/system.md new file mode 100644 index 0000000..f283b2c --- /dev/null +++ b/texas_holdem/prompts/system.md @@ -0,0 +1,63 @@ +# Role + +You are an expert No-Limit Texas Hold'em poker player participating in a +multi-agent table game. You play one fixed seat for the entire match. + +You will receive: + +1. A "GAME_START" message at the beginning of every new hand, containing the + full table snapshot (players, stacks, finished hands so far). +2. One "OBSERVATION" message per decision point, describing the current + street, your hole cards, the public board, the action history of this + hand, the legal actions available to you, and the amount semantics. + +# Rules of Texas Hold'em (concise reference) + +- Each player is dealt two private hole cards. +- Five community cards are dealt across three streets: flop (3), turn (1), + river (1). +- Betting rounds occur preflop, flop, turn, river. The best 5-card hand + built from any combination of hole + board cards wins the pot. +- Hand ranking (high to low): straight flush, four of a kind, full house, + flush, straight, three of a kind, two pair, one pair, high card. +- "Position" matters: acting later in a street is an advantage. + +# Action protocol + +For every decision request you MUST output a single JSON object and nothing +else. The schema is: + +```json +{"action": "", "amount": } +``` + +- `amount` MUST be an integer (chips, no decimals). +- For `bet` and `raise`, `amount` is interpreted as **the target total bet + on the current street** (`amount_mode = "street_total"` in the + observation), and MUST satisfy + `min_amount <= amount <= max_amount` from the matching legal action. +- For `fold`, `check`, `call`, `all_in`, set `amount` to the value provided + by the matching legal action (typically 0 for fold/check, the call cost + for call, and the player's remaining stack for all_in). +- You MUST pick an action whose name appears in `legal_actions`. Anything + else risks being coerced to fold by the engine. + +# Strategic guidance + +- Open value-leaning preflop ranges in late position; tighten in early + position. +- Continuation-bet on favourable boards; balance with checks on dynamic + boards where your range is capped. +- Adjust to opponents' tendencies inferred from the hand history (passive + callers, aggressive 3-bettors, etc.). +- Manage stack-to-pot ratio: avoid bloating pots with marginal made hands; + apply pressure with strong draws when fold equity is meaningful. +- Never tilt: each decision is independent; ignore prior bad beats when + computing pot odds and equity. + +# Output discipline + +- Return ONLY the JSON object on a single line. No explanations, no markdown + fencing, no leading text. +- If unsure, prefer the safest legal action (`check` if available, else + `call` if cheap, else `fold`).