Files
texas_hold_x/texas_holdem/ai_client.py
T
2026-05-11 21:09:55 +08:00

728 lines
26 KiB
Python

"""Standalone HTTP AI poker agent backed by an OpenAI-compatible LLM.

Run as a process exposing two endpoints that the Texas Hold'em service
calls:

* ``POST /game`` - delivered at the start of every new hand. We use it as
  the boundary that opens a fresh chat session for that hand and seeds it
  with a human-readable rendering of the table snapshot (history of past
  hands included).
* ``POST /act`` - the per-decision request. We render the observation
  with a templated user prompt, ask the configured LLM, parse the JSON
  reply and return it to the server.

Run::

    python -m texas_holdem.ai_client \\
        --host 127.0.0.1 --port 9101 \\
        --base-url https://api.openai.com/v1 \\
        --api-key $OPENAI_API_KEY \\
        --model gpt-4o-mini

``--api-key`` may be omitted when ``OPENAI_API_KEY`` is set in the
environment (it is the flag's default), which also keeps the key out of the
process list.

Hook it up by passing the *base* URL when creating the game::

    {"id": "ai", "name": "AI",
     "agent": {"type": "http", "endpoint": "http://127.0.0.1:9101"}}

Design notes:

- Prompts live in ``texas_holdem/prompts/*.md`` so non-engineers can edit
  them without touching code. :class:`PromptLibrary` reads each template
  from disk once, on first use, and caches it; restart the agent to pick up
  edits.
- :class:`LLMSession` owns the chat history for one seat in one game and
  is reset by every ``/game`` push, matching the user's mental model that
  one hand == one session.
- :class:`SessionRegistry` keeps a tiny LRU per ``(game_id, player_id)``
  so the same process can serve multiple parallel games / seats.
- The LLM client speaks OpenAI's ``/v1/chat/completions`` schema, which
  is what virtually every OpenAI-compatible provider implements.
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sys
import threading
from collections import OrderedDict
from dataclasses import dataclass, field
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import Any
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
# Prompt templates ship in a ``prompts`` directory right beside this module,
# so operators can tweak them in place without re-installing the package.
PROMPTS_DIR = Path(__file__).resolve().parent.joinpath("prompts")
# ---------------------------------------------------------------------------
# Prompt rendering
# ---------------------------------------------------------------------------
class PromptLibrary:
    """Loads and renders prompt templates from disk.

    Templates use Python ``str.format`` placeholders, so any literal brace
    in a template (for example a JSON sample) must be doubled as ``{{`` /
    ``}}``. Centralising the render step here keeps the LLM-facing wording
    out of the code base and avoids ad-hoc string concatenation in the
    agent logic.
    """

    def __init__(self, directory: Path = PROMPTS_DIR) -> None:
        """Create a library that reads ``<directory>/<name>.md`` templates."""
        self.directory = directory
        # Cache prompts by name. They are small and read-only at runtime,
        # so each file is read at most once (on first use) instead of on
        # every request.
        self._cache: dict[str, str] = {}

    def load(self, name: str) -> str:
        """Return the raw text of template ``name``, reading it on first use.

        Raises:
            FileNotFoundError: if ``<directory>/<name>.md`` does not exist.
        """
        cached = self._cache.get(name)
        if cached is not None:
            return cached
        path = self.directory / f"{name}.md"
        try:
            # EAFP: a separate exists() check could race with the file being
            # renamed or removed between the check and the read.
            text = path.read_text(encoding="utf-8")
        except FileNotFoundError as exc:
            raise FileNotFoundError(f"missing prompt template: {path}") from exc
        self._cache[name] = text
        return text

    def render(self, name: str, **fields: Any) -> str:
        """Return template ``name`` with ``fields`` substituted via ``format``.

        Raises:
            FileNotFoundError: if the template does not exist.
            KeyError: if the template references a placeholder missing from
                ``fields``. The message names the template so whoever edited
                the markdown can locate the typo.
        """
        template = self.load(name)
        try:
            return template.format(**fields)
        except KeyError as exc:
            placeholder = exc.args[0] if exc.args else "?"
            raise KeyError(
                f"prompt template {name!r} uses unknown placeholder {placeholder!r}"
            ) from exc
def render_game_start_prompt(library: PromptLibrary, game_state: dict[str, Any]) -> str:
    """Render the ``game_start`` template that opens a hand's chat session.

    Scalar table settings are copied straight out of ``game_state``. The
    seat table and the digest of finished hands are pre-formatted by
    helpers, so the template itself stays declarative.
    """
    finished = game_state.get("hands") or []
    settings = {
        key: game_state.get(key)
        for key in (
            "game_id",
            "hand_number",
            "status",
            "small_blind",
            "big_blind",
            "button_seat",
            "starting_stack",
        )
    }
    return library.render(
        "game_start",
        **settings,
        players_block=_format_players_block(game_state.get("players") or []),
        hand_count=len(finished),
        history_block=_format_history_block(finished),
    )
def render_observation_prompt(
    library: PromptLibrary, observation: dict[str, Any]
) -> str:
    """Render the ``observation`` template for a single decision request.

    Scalars are copied from the server's observation as-is. List-valued
    parts (cards, seats, the action log, legal actions) are pre-formatted
    by the ``_format_*`` helpers. The player's display name comes from
    their own row in ``players``, falling back to their ``player_id``.
    """
    me = _find_self_player(observation)
    display_name = observation.get("player_id") if not me else me.get("name")
    scalars = {
        key: observation.get(key)
        for key in (
            "hand_number",
            "street",
            "player_id",
            "seat",
            "button_seat",
            "pot",
            "to_call",
            "min_raise_to",
        )
    }
    return library.render(
        "observation",
        **scalars,
        player_name=display_name,
        amount_mode=observation.get("amount_mode") or "street_total",
        hole_cards=_format_card_list(observation.get("hole_cards")),
        board=_format_card_list(observation.get("board")),
        players_block=_format_players_block(observation.get("players") or []),
        action_history_block=_format_action_history(
            observation.get("action_history") or []
        ),
        legal_actions_block=_format_legal_actions(
            list(observation.get("legal_actions") or [])
        ),
    )
def _find_self_player(observation: dict[str, Any]) -> dict[str, Any] | None:
    """Return the first ``players`` row whose ``player_id`` is the observer's.

    Yields ``None`` when the snapshot has no ``players`` list or no row for
    the acting player.
    """
    target = observation.get("player_id")
    candidates = observation.get("players") or []
    return next(
        (row for row in candidates if row.get("player_id") == target), None
    )
def _format_card_list(cards: Any) -> str:
    """Join card labels with single spaces; ``"(none)"`` for an empty/None input."""
    return " ".join(map(str, cards)) if cards else "(none)"
def _format_players_block(players: list[dict[str, Any]]) -> str:
    """Describe each seat on its own ``- seat N: ...`` line.

    Shared by both prompt templates. A seat's status is the comma-joined
    subset of ``folded`` / ``all_in`` / ``out`` that applies (``out`` when
    ``in_hand`` is falsy or missing), or ``active`` when none do.
    """
    if not players:
        return "(no players)"

    def describe(entry: dict[str, Any]) -> str:
        labels = [
            label
            for label, applies in (
                ("folded", entry.get("folded")),
                ("all_in", entry.get("all_in")),
                ("out", not entry.get("in_hand")),
            )
            if applies
        ]
        status = ",".join(labels) or "active"
        return (
            f"- seat {entry.get('seat')}: id={entry.get('player_id')}, "
            f"name={entry.get('name')}, stack={entry.get('stack')}, "
            f"street_bet={entry.get('street_bet', 0)}, "
            f"total_bet={entry.get('total_bet', 0)}, status={status}"
        )

    return "\n".join(describe(entry) for entry in players)
def _format_history_block(hands: list[dict[str, Any]]) -> str:
    """Render a compact digest of finished hands for the GAME_START message.

    Each hand contributes a header line (hand number, button seat, board),
    then one line per awarded pot and one line per revealed showdown hand.
    Winner ids are passed through ``str`` and showdown cards are rendered
    like the board, so non-string ids/labels or a missing card list cannot
    make ``str.join`` raise ``TypeError``. An empty or missing card list is
    shown as ``(none)``.
    """
    if not hands:
        return "(no hands played yet)"
    digests: list[str] = []
    for hand in hands:
        lines = [
            f"- Hand #{hand.get('hand_number')} "
            f"(button seat {hand.get('button_seat')}), "
            f"board: {_format_card_list(hand.get('board'))}"
        ]
        for award in hand.get("awards") or []:
            winners = ",".join(str(winner) for winner in award.get("winners") or [])
            hand_name = (award.get("hand_value") or {}).get("name", "-")
            lines.append(f" pot {award.get('amount')}: -> {winners} ({hand_name})")
        for pid, cards in (hand.get("showdown_hands") or {}).items():
            lines.append(f" showdown {pid}: {_format_card_list(cards)}")
        digests.append("\n".join(lines))
    return "\n".join(digests)
def _format_action_history(history: list[dict[str, Any]]) -> str:
    """List recent actions one per line, keeping only the last 32 entries.

    Within one hand the engine is not expected to produce unbounded history,
    but the cap keeps a malformed payload from inflating token usage.
    """
    if not history:
        return "(no actions yet)"
    newest = history[-32:]
    lines = (
        f"- [{entry.get('street')}] {entry.get('player_id')} -> "
        f"{entry.get('action')} amount={entry.get('amount', 0)}"
        for entry in newest
    )
    return "\n".join(lines)
def _format_legal_actions(
    legal: list[dict[str, Any]], amount_mode: str = "street_total"
) -> str:
    """Render a numbered list of legal actions including amount ranges.

    Args:
        legal: Legal-action dicts from the observation.
        amount_mode: Label for how ``bet`` / ``raise`` amounts are measured.
            Defaults to ``"street_total"``, the same default the observation
            prompt uses for its ``amount_mode`` field.

    Returns:
        One line per action: ``bet``/``raise`` show their
        ``[min_amount, max_amount]`` range, every other action shows its
        fixed ``amount``; ``"(no legal actions)"`` for an empty list.
    """
    if not legal:
        return "(no legal actions)"
    rows: list[str] = []
    for index, action in enumerate(legal, start=1):
        # .get() like the other formatters: a malformed entry renders as
        # "None" instead of failing the whole /act request with a KeyError.
        name = action.get("action")
        if name in {"bet", "raise"}:
            rows.append(
                f"{index}. {name}: {amount_mode} in "
                f"[{action.get('min_amount')}, {action.get('max_amount')}]"
            )
        else:
            rows.append(f"{index}. {name} (amount={action.get('amount')})")
    return "\n".join(rows)
# ---------------------------------------------------------------------------
# OpenAI-compatible chat client
# ---------------------------------------------------------------------------
@dataclass
class LLMConfig:
    """Static configuration for the OpenAI-compatible provider.

    A small dataclass keeps the constructor wiring readable and lets us add
    fields (e.g. organisation id) without touching every call site.
    """

    base_url: str
    # Excluded from the generated ``__repr__`` so printing or logging the
    # config (or a traceback that shows it) never reveals the secret.
    api_key: str = field(repr=False)
    model: str
    timeout_seconds: float = 60.0
    temperature: float = 0.4

    def chat_completions_url(self) -> str:
        """Return the canonical chat completions URL for the configured base."""
        base = self.base_url.rstrip("/")
        # Tolerate users passing the full ``/chat/completions`` path; only
        # the OpenAI-style base URL is documented but mistakes are common.
        if base.endswith("/chat/completions"):
            return base
        return f"{base}/chat/completions"
class LLMClient:
    """Thin wrapper around the OpenAI-compatible Chat Completions API.

    Implemented with ``urllib`` to honour the project's "no third-party
    dependency" constraint; swapping in ``httpx`` later would only touch
    this class. Every failure mode of :meth:`chat` is reported as
    :class:`RuntimeError`, so callers need to catch only that one type.
    """

    def __init__(self, config: LLMConfig) -> None:
        self.config = config

    def chat(self, messages: list[dict[str, Any]]) -> str:
        """Send a chat completion request and return the assistant text.

        Args:
            messages: Chat turns (``{"role": ..., "content": ...}`` dicts),
                sent verbatim as the request's ``messages`` field.

        Returns:
            ``choices[0].message.content`` of the response.

        Raises:
            RuntimeError: on an HTTP error status, a network failure or
                timeout, a response body that is not UTF-8 JSON, or a
                payload whose first choice carries no text ``content``.
        """
        url = self.config.chat_completions_url()
        body = json.dumps(
            {
                "model": self.config.model,
                "messages": messages,
                "temperature": self.config.temperature,
            }
        ).encode("utf-8")
        request = Request(
            url,
            data=body,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.config.api_key}",
            },
            method="POST",
        )
        try:
            with urlopen(request, timeout=self.config.timeout_seconds) as resp:
                raw = resp.read()
        except HTTPError as exc:
            # HTTPError subclasses URLError (an OSError), so it must come first.
            try:
                detail = exc.read().decode("utf-8", errors="replace")
            except OSError:
                detail = "<error body unavailable>"
            raise RuntimeError(f"LLM HTTP {exc.code} from {url}: {detail}") from exc
        except OSError as exc:
            # URLError (DNS / connection failures) and socket timeouts are
            # all OSError subclasses.
            raise RuntimeError(f"LLM request failed: {url}") from exc
        try:
            payload = json.loads(raw.decode("utf-8"))
        except ValueError as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses; e.g. a proxy answering 200 with an HTML page.
            raise RuntimeError(f"LLM returned a non-JSON response from {url}") from exc
        try:
            content = payload["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            raise RuntimeError(f"LLM returned unexpected payload: {payload}") from exc
        if not isinstance(content, str):
            # e.g. ``null`` content; returning it would break the ``-> str``
            # contract and end up stored in the chat history.
            raise RuntimeError(f"LLM returned non-text content: {payload}")
        return content
# ---------------------------------------------------------------------------
# Session lifecycle
# ---------------------------------------------------------------------------
@dataclass
class LLMSession:
    """Chat history for one seat, rebuilt at the start of every hand.

    ``system_prompt`` stays fixed for the whole match, while the
    user/assistant turns are replaced wholesale by :meth:`reset_with_game`
    whenever a ``/game`` push arrives. Keeping the current hand's assistant
    turns lets the model stay consistent with its own earlier decisions
    without re-sending the long table snapshot.
    """

    system_prompt: str
    game_id: str
    player_id: str
    messages: list[dict[str, Any]] = field(default_factory=list)
    hand_number: int = 0

    def reset_with_game(self, hand_number: int, game_user_message: str) -> None:
        """Discard prior turns and seed a fresh exchange for ``hand_number``."""
        system_turn = dict(role="system", content=self.system_prompt)
        opening_turn = dict(role="user", content=game_user_message)
        self.hand_number = hand_number
        self.messages = [system_turn, opening_turn]

    def append_user(self, content: str) -> None:
        """Record a user turn at the end of the history."""
        self.messages.append(dict(role="user", content=content))

    def append_assistant(self, content: str) -> None:
        """Record an assistant turn at the end of the history."""
        self.messages.append(dict(role="assistant", content=content))

    def chat_messages(self) -> list[dict[str, Any]]:
        """Return a shallow copy of the turns to send to the LLM.

        If no ``/game`` has seeded the session yet (e.g. ``/act`` arrived
        first), a lone system turn is returned so the request stays valid.
        """
        if self.messages:
            return self.messages[:]
        return [dict(role="system", content=self.system_prompt)]
class SessionRegistry:
    """Tiny LRU-style registry keyed by ``(game_id, player_id)``.

    Multiple parallel games or multiple seats served by the same process
    each need an isolated chat history; the registry provides exactly that
    while keeping at most ``max_sessions`` sessions. Lookups and evictions
    run under a lock because the HTTP server handles requests on worker
    threads.
    """

    def __init__(self, system_prompt: str, max_sessions: int = 64) -> None:
        """Create an empty registry.

        Args:
            system_prompt: Text given to every session this registry creates.
            max_sessions: Upper bound on retained sessions; must be >= 1.

        Raises:
            ValueError: if ``max_sessions`` is below 1. Zero would evict each
                session as soon as it is created (losing the ``/game``
                seeding), and a negative cap made eviction pop from an empty
                dict, raising ``KeyError`` on the first lookup.
        """
        if max_sessions < 1:
            raise ValueError(f"max_sessions must be >= 1, got {max_sessions!r}")
        self._system_prompt = system_prompt
        self._sessions: OrderedDict[tuple[str, str], LLMSession] = OrderedDict()
        self._max = max_sessions
        self._lock = threading.Lock()

    def get_or_create(self, game_id: str, player_id: str) -> LLMSession:
        """Return the session for ``(game_id, player_id)``, creating it if needed.

        Reusing a session marks it most recently used; creating one evicts
        the least recently used sessions beyond the cap.
        """
        key = (game_id, player_id)
        with self._lock:
            session = self._sessions.get(key)
            if session is not None:
                self._sessions.move_to_end(key)
                return session
            session = LLMSession(
                system_prompt=self._system_prompt,
                game_id=game_id,
                player_id=player_id,
            )
            self._sessions[key] = session
            # Drop the oldest entries once over the cap; stale game
            # histories are never needed again.
            while len(self._sessions) > self._max:
                self._sessions.popitem(last=False)
            return session
# ---------------------------------------------------------------------------
# Action parsing
# ---------------------------------------------------------------------------
# Greedy span from the first "{" to the last "}" of a reply. It only tells
# us whether the reply holds anything object-shaped; the actual object is
# located inside this span by ``_first_json_object``.
_JSON_OBJECT_RE = re.compile(r"\{[\s\S]*\}")
# ``raw_decode`` parses one JSON value starting at an index and reports where
# it ended, ignoring whatever text follows.
_JSON_DECODER = json.JSONDecoder()


def _first_json_object(text: str) -> dict[str, Any] | None:
    """Return the first JSON object that decodes at some ``{`` in ``text``."""
    start = text.find("{")
    while start != -1:
        try:
            value, _end = _JSON_DECODER.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(value, dict):
                return value
        start = text.find("{", start + 1)
    return None


def parse_action_reply(reply: str) -> dict[str, Any]:
    """Extract the action JSON from a possibly chatty LLM response.

    LLMs occasionally wrap JSON in markdown fences or add chatter despite
    explicit instructions. We return the first ``{...}`` object that parses,
    so trailing text (even text containing braces) is ignored. Downstream
    code (engine ``_coerce_action``) sanitises illegal values, so ranges are
    not validated here.

    Returns:
        ``{"action": str, "amount": int}``; a missing/empty action becomes
        ``"fold"`` and a missing/empty amount becomes ``0``.

    Raises:
        ValueError: for an empty/non-string reply, a reply with no
            parseable JSON object, or an ``amount`` that cannot be
            converted to ``int``.
    """
    if not isinstance(reply, str) or not reply.strip():
        raise ValueError("empty LLM reply")
    span = _JSON_OBJECT_RE.search(reply)
    if span is None:
        raise ValueError(f"no JSON object found in LLM reply: {reply!r}")
    payload = _first_json_object(span.group(0))
    if payload is None:
        raise ValueError(f"invalid JSON in LLM reply: {reply!r}")
    try:
        amount = int(payload.get("amount") or 0)
    except (TypeError, ValueError, OverflowError) as exc:
        # e.g. a list, an object, a non-numeric string or JSON Infinity.
        raise ValueError(f"invalid amount in LLM reply: {reply!r}") from exc
    return {
        "action": str(payload.get("action") or "fold"),
        "amount": amount,
    }
def fallback_action(observation: dict[str, Any]) -> dict[str, Any]:
    """Choose a conservative legal action when no LLM decision is usable.

    Preference order: ``check`` (free), then ``call`` at the listed amount,
    then ``fold``. If none of those is offered, the first legal action is
    echoed back; with no legal actions at all the answer is a fold.
    """
    legal = observation.get("legal_actions") or []
    offered = {entry.get("action"): entry for entry in legal}
    for choice in ("check", "call", "fold"):
        if choice not in offered:
            continue
        if choice == "call":
            return {"action": "call", "amount": int(offered["call"].get("amount") or 0)}
        return {"action": choice, "amount": 0}
    if not legal:
        return {"action": "fold", "amount": 0}
    head = legal[0]
    return {
        "action": str(head.get("action")),
        "amount": int(head.get("amount") or 0),
    }
# ---------------------------------------------------------------------------
# HTTP service
# ---------------------------------------------------------------------------
class AIAgentService:
    """Glues the LLM client, prompt library and session registry together.

    Exposed as a single object so that the HTTP handler stays thin and
    purely concerned with request parsing and response framing.

    NOTE(review): ``handle_game`` keys its session by the ``player_id``
    argument it is given, while ``handle_act`` keys by
    ``observation["player_id"]``. The hand-opening context seeded by
    ``/game`` only reaches ``/act`` when those two ids agree.
    """

    def __init__(
        self,
        llm: LLMClient,
        prompts: PromptLibrary,
        registry: SessionRegistry | None = None,
    ) -> None:
        """Wire the collaborators and load the ``system`` prompt immediately.

        Raises:
            Whatever ``prompts.load("system")`` raises (``FileNotFoundError``
            for :class:`PromptLibrary` when the template is missing).
        """
        self.llm = llm
        self.prompts = prompts
        # Read the system prompt once, up front: a missing template fails at
        # boot rather than on the first request, and every session shares
        # this one copy (restart the agent to pick up edits to the file).
        system_prompt = self.prompts.load("system")
        # ``is not None`` rather than ``or``: a caller-supplied registry must
        # be kept even if it happens to define a falsy __len__/__bool__.
        self.registry = (
            registry
            if registry is not None
            else SessionRegistry(system_prompt=system_prompt)
        )

    def handle_game(self, game_state: dict[str, Any], player_id: str) -> None:
        """Open or refresh the per-hand session for ``player_id``."""
        game_id = str(game_state.get("game_id") or "")
        hand_number = int(game_state.get("hand_number") or 0)
        session = self.registry.get_or_create(game_id, player_id)
        session.reset_with_game(
            hand_number=hand_number,
            game_user_message=render_game_start_prompt(self.prompts, game_state),
        )

    def handle_act(self, observation: dict[str, Any]) -> dict[str, Any]:
        """Render the prompt, call the LLM and parse its reply into an action.

        LLM-side problems never propagate: a failed call, a non-text reply
        and an unparseable reply all fall back to :func:`fallback_action`.

        Returns:
            An ``{"action": ..., "amount": ...}`` dict.
        """
        game_id = str(observation.get("game_id") or "")
        player_id = str(observation.get("player_id") or "")
        session = self.registry.get_or_create(game_id, player_id)
        # Ensure we never silently drop the system prompt even if /game
        # arrives after /act (e.g. process started mid-hand on a restart).
        if not session.messages:
            session.messages = [
                {"role": "system", "content": session.system_prompt}
            ]
        session.append_user(render_observation_prompt(self.prompts, observation))
        try:
            assistant_text = self.llm.chat(session.chat_messages())
        except (RuntimeError, ValueError):
            # RuntimeError is LLMClient's documented failure type; ValueError
            # also covers a JSON/UTF-8 decode error of the response body that
            # a client lets escape.
            assistant_text = None
        if not isinstance(assistant_text, str):
            # No usable reply: drop the user turn we just added so the next
            # request carries neither an unanswered message nor a non-text
            # assistant turn, and answer with a safe action.
            session.messages.pop()
            return fallback_action(observation)
        session.append_assistant(assistant_text)
        try:
            return parse_action_reply(assistant_text)
        except (ValueError, TypeError):
            # Reply was unparseable; keep it in the history so the LLM can
            # see what it did wrong on the next turn (no extra prompt
            # needed - the next observation will fill that role) and
            # answer with a safe action this turn.
            return fallback_action(observation)
def _bind_player_id(handler: BaseHTTPRequestHandler) -> str:
    """Return the seat id a request should be filed under.

    A non-empty ``X-Player-Id`` header wins. Otherwise the server's
    ``default_player_id`` attribute (set by ``create_server``) is used, or
    ``"ai"`` if the server has none. The standalone process serves one AI
    seat by default; multi-seat deployments send the header so each seat
    keeps an isolated session.
    """
    from_header = handler.headers.get("X-Player-Id")
    return from_header or getattr(handler.server, "default_player_id", "ai")  # type: ignore[attr-defined]
class AIRequestHandler(BaseHTTPRequestHandler):
    """HTTP entry point for the AI agent.

    Routes (matched on the path only, so a ``?query`` suffix is ignored):
    - ``GET /health`` - liveness probe.
    - ``POST /game`` - new hand boundary; opens a fresh session.
    - ``POST /act`` - returns the AI-decided action.
    """

    server_version = "TexasHoldemAIAgent/0.1"
    service: AIAgentService  # injected by ``create_server``

    def _route_path(self) -> str:
        """Return the request target without its ``?query`` part."""
        return self.path.partition("?")[0]

    def do_GET(self) -> None:
        """Answer the liveness probe; every other GET path is a 404."""
        if self._route_path() == "/health":
            self._json({"ok": True})
            return
        self._json({"error": "not found"}, HTTPStatus.NOT_FOUND)

    def do_POST(self) -> None:
        """Dispatch ``/game`` and ``/act``: bad bodies -> 400, crashes -> 500."""
        routes = {
            "/game": self._handle_game,
            "/act": self._handle_act,
        }
        handler = routes.get(self._route_path())
        if handler is None:
            self._json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return
        try:
            payload = self._read_json()
        except ValueError as exc:
            self._json({"error": str(exc)}, HTTPStatus.BAD_REQUEST)
            return
        try:
            handler(payload)
        except Exception as exc:  # pragma: no cover - last-resort guard
            self._json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR)

    def _handle_game(self, payload: dict[str, Any]) -> None:
        """Seed the session for the seat resolved by ``_bind_player_id``."""
        self.service.handle_game(payload, player_id=_bind_player_id(self))
        self._empty(HTTPStatus.NO_CONTENT)

    def _handle_act(self, payload: dict[str, Any]) -> None:
        """Return the service's decision as JSON."""
        # NOTE(review): unlike /game this does not consult X-Player-Id; the
        # service keys the session by the observation's own player_id.
        action = self.service.handle_act(payload)
        self._json(action)

    def log_message(self, format: str, *args: Any) -> None:  # noqa: A002
        # Suppress BaseHTTPRequestHandler's default stderr logging (error
        # logging routes through this method too).
        return

    def _read_json(self) -> dict[str, Any]:
        """Read and decode the request body as a JSON object.

        Raises:
            ValueError: for a non-integer Content-Length, a missing/empty
                body, a body that is not UTF-8 JSON, or a top-level value
                that is not an object.
        """
        raw_length = self.headers.get("Content-Length", "0")
        try:
            length = int(raw_length)
        except (TypeError, ValueError) as exc:
            raise ValueError("Content-Length header must be an integer") from exc
        if length <= 0:
            raise ValueError("request body is required")
        try:
            payload = json.loads(self.rfile.read(length).decode("utf-8"))
        except UnicodeDecodeError as exc:
            raise ValueError("request body must be UTF-8 encoded JSON") from exc
        except json.JSONDecodeError as exc:
            raise ValueError("request body must be valid JSON") from exc
        if not isinstance(payload, dict):
            raise ValueError("request body must be a JSON object")
        return payload

    def _json(
        self, payload: dict[str, Any], status: HTTPStatus = HTTPStatus.OK
    ) -> None:
        """Send ``payload`` as an ASCII-safe JSON response with ``status``."""
        body = json.dumps(payload, ensure_ascii=True).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _empty(self, status: HTTPStatus) -> None:
        """Send a body-less response with ``status``."""
        self.send_response(status)
        self.send_header("Content-Length", "0")
        self.end_headers()
def create_server(
    host: str,
    port: int,
    service: AIAgentService,
    default_player_id: str = "ai",
) -> ThreadingHTTPServer:
    """Build and configure the HTTP server.

    ``service`` is attached to a handler subclass created for this server
    alone, instead of being assigned onto :class:`AIRequestHandler` itself.
    With the old approach, building a second server in the same process
    silently replaced the service used by the first. All worker threads of
    one server still share a single service instance, and therefore a
    single session registry. ``default_player_id`` is stored on the server
    object, where ``_bind_player_id`` looks it up.
    """
    handler_class = type(
        "BoundAIRequestHandler",
        (AIRequestHandler,),
        {"service": service},
    )
    server = ThreadingHTTPServer((host, port), handler_class)
    server.default_player_id = default_player_id  # type: ignore[attr-defined]
    return server
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Parse CLI flags, build the agent service and serve until interrupted.

    Configuration errors (missing API key, non-positive timeout, missing
    prompt templates) are reported as argparse usage errors. Failing to
    bind the listening socket exits with status 1 and a one-line message
    instead of a traceback.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run an OpenAI-compatible AI poker agent that exposes "
            "POST /act and POST /game."
        )
    )
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=9101, type=int)
    parser.add_argument(
        "--base-url",
        default=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        help="OpenAI-compatible base URL (default: $OPENAI_BASE_URL or "
        "https://api.openai.com/v1).",
    )
    parser.add_argument(
        "--api-key",
        default=os.environ.get("OPENAI_API_KEY", ""),
        help="API key (default: $OPENAI_API_KEY).",
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
        help="Model identifier (default: $OPENAI_MODEL or gpt-4o-mini).",
    )
    parser.add_argument(
        "--timeout",
        default=60.0,
        type=float,
        help="LLM request timeout in seconds.",
    )
    parser.add_argument(
        "--temperature",
        default=0.4,
        type=float,
        help="LLM sampling temperature.",
    )
    parser.add_argument(
        "--player-id",
        default="ai",
        help=(
            "Default player_id used to key sessions. Override per-request "
            "via the X-Player-Id header for multi-seat setups."
        ),
    )
    parser.add_argument(
        "--prompts-dir",
        default=str(PROMPTS_DIR),
        help="Directory containing the prompt markdown templates.",
    )
    args = parser.parse_args()
    if not args.api_key:
        parser.error("--api-key (or OPENAI_API_KEY) is required")
    # ``not x > 0`` also rejects NaN. A zero socket timeout means
    # non-blocking mode and a negative one is refused by the socket layer,
    # so neither can ever produce a working LLM call.
    if not args.timeout > 0:
        parser.error("--timeout must be a positive number of seconds")
    config = LLMConfig(
        base_url=args.base_url,
        api_key=args.api_key,
        model=args.model,
        timeout_seconds=args.timeout,
        temperature=args.temperature,
    )
    prompts = PromptLibrary(directory=Path(args.prompts_dir))
    try:
        # The service loads the system prompt eagerly, so a wrong
        # --prompts-dir surfaces here.
        service = AIAgentService(LLMClient(config), prompts)
    except FileNotFoundError as exc:
        parser.error(str(exc))
    try:
        server = create_server(
            args.host, args.port, service, default_player_id=args.player_id
        )
    except OSError as exc:
        # e.g. the port is already in use or the host is not a local address.
        parser.exit(1, f"cannot listen on {args.host}:{args.port}: {exc}\n")
    print(
        f"AI HTTP agent listening on http://{args.host}:{args.port}\n"
        " POST /act - decision request\n"
        " POST /game - new-hand snapshot (opens a fresh session)\n"
        f" model : {config.model}\n"
        f" base_url : {config.base_url}\n"
        f" player_id : {args.player_id}",
        file=sys.stderr,
        flush=True,
    )
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        server.server_close()
# Entry point when executed as a script or via ``python -m``.
if __name__ == "__main__":
    main()