feat: add ai agent http agent
This commit is contained in:
@@ -0,0 +1,972 @@
|
||||
"""Standalone HTTP AI poker agent backed by an OpenAI-compatible LLM.
|
||||
|
||||
Run as a process exposing two endpoints that the Texas Hold'em service
|
||||
calls:
|
||||
|
||||
* ``POST /game`` - delivered at the start of every new hand. We use it as
|
||||
the boundary that opens a fresh chat session for that hand and seeds it
|
||||
with a human-readable rendering of the table snapshot (history of past
|
||||
hands included).
|
||||
* ``POST /act`` - the per-decision request. We render the observation
|
||||
with a templated user prompt, ask the configured LLM, parse the JSON
|
||||
reply and return it to the server.
|
||||
|
||||
Run::
|
||||
|
||||
python -m texas_holdem.ai_client \\
|
||||
--host 127.0.0.1 --port 9101 \\
|
||||
--base-url https://api.openai.com/v1 \\
|
||||
--api-key $OPENAI_API_KEY \\
|
||||
--model gpt-4o-mini
|
||||
|
||||
Hook it up by passing the *base* URL when creating the game::
|
||||
|
||||
{"id": "ai", "name": "AI",
|
||||
"agent": {"type": "http", "endpoint": "http://127.0.0.1:9101"}}
|
||||
|
||||
Design notes:
|
||||
- Prompts live in ``texas_holdem/prompts/*.md`` so non-engineers can edit
|
||||
them without touching code; we read them once at boot via
|
||||
:class:`PromptLibrary`.
|
||||
- :class:`LLMSession` owns the chat history for a single (game_id, hand)
|
||||
scope and is reset by every ``/game`` push, matching the user's mental
|
||||
model that one hand == one session.
|
||||
- :class:`SessionRegistry` keeps a tiny LRU per ``(game_id, player_id)``
|
||||
so the same process can serve multiple parallel games / seats.
|
||||
- The LLM client speaks OpenAI's ``/v1/chat/completions`` schema, which
|
||||
is what virtually every OpenAI-compatible provider implements.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, Callable, Iterator
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from texas_holdem.human_io import clear_screen, render_game_state, render_observation
|
||||
|
||||
# Default location of the prompt templates. Living next to this module lets
# operators edit them in place without re-installing the package.
PROMPTS_DIR = Path(__file__).resolve().parent / "prompts"
# ANSI escape sequences wrapped around text by ``AIAgentConsole._gray``.
ANSI_GRAY = "\x1b[90m"  # switch the foreground colour to gray
ANSI_RESET = "\x1b[0m"  # restore the terminal's default attributes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Terminal diagnostics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AIAgentConsole:
    """Terminal writer for the standalone AI agent.

    The behaviour mirrors :class:`HumanClientConsole`: unless
    ``keep_history=True`` every decision starts by clearing the terminal,
    otherwise earlier game/act logs stay in scrollback. Text coming from
    the LLM is tinted gray so it is easy to tell apart from the rendered
    game state and the final action.

    Only :meth:`act_log` and :meth:`announce_game` take the internal lock;
    the remaining methods write directly and are meant to be used while
    the caller is inside an :meth:`act_log` block.
    """

    def __init__(
        self,
        output_stream: IO[str] | None = None,
        keep_history: bool = False,
        use_color: bool = True,
    ) -> None:
        # Fall back to the process stdout when no stream is supplied.
        self._output = sys.stdout if output_stream is None else output_stream
        self._keep_history = keep_history
        self._use_color = use_color
        self._lock = threading.Lock()

    @contextmanager
    def act_log(self, observation: dict[str, Any]) -> Iterator[None]:
        """Print one ``/act`` observation and keep the lock for the ``with`` body."""
        with self._lock:
            if not self._keep_history:
                clear_screen(self._write)
            rendered = render_observation(observation)
            self._write(rendered)
            # The lock stays held until the caller's block finishes.
            yield

    def announce_game(self, game_state: dict[str, Any]) -> None:
        """Print one ``/game`` snapshot under the console lock."""
        with self._lock:
            rendered = render_game_state(game_state)
            self._write(rendered)

    def begin_llm_stream(self) -> None:
        """Write the gray header that precedes the model output."""
        header = self._gray("AI MODEL STREAM\n")
        self._write(header)

    def write_llm_delta(self, kind: str, text: str) -> None:
        """Write one chunk of model output; ``kind`` is accepted but not used."""
        if text:
            self._write(self._gray(text))

    def end_llm_stream(self) -> None:
        """Terminate the model output with a (gray) newline."""
        self._write(self._gray("\n"))

    def announce_action(
        self,
        action: dict[str, Any],
        source: str = "model",
    ) -> None:
        """Print the chosen action as JSON followed by a separator rule."""
        encoded = json.dumps(action, ensure_ascii=False)
        self._write(f"\nAI ACTION ({source}) -> {encoded}\n")
        self._write("~" * 60 + "\n\n")

    def announce_warning(self, message: str) -> None:
        """Print a warning line."""
        self._write(f"\nAI WARNING -> {message}\n")

    def _gray(self, text: str) -> str:
        """Wrap ``text`` in the gray escape codes unless colour is disabled."""
        return f"{ANSI_GRAY}{text}{ANSI_RESET}" if self._use_color else text

    def _write(self, text: str) -> None:
        """Write ``text`` to the output stream and flush immediately."""
        stream = self._output
        stream.write(text)
        stream.flush()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prompt rendering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class PromptLibrary:
    """Reads ``<name>.md`` prompt templates from a directory and renders them.

    Templates carry Python ``str.format`` placeholders. Keeping the render
    step here means the LLM-facing wording lives in the template files and
    the agent logic never concatenates prompt text by hand.
    """

    def __init__(self, directory: Path = PROMPTS_DIR) -> None:
        self.directory = directory
        # Template text keyed by name: each file is read from disk at most
        # once per instance and served from memory afterwards.
        self._cache: dict[str, str] = {}

    def load(self, name: str) -> str:
        """Return the raw text of ``<directory>/<name>.md``.

        Raises:
            FileNotFoundError: if the template file does not exist.
        """
        cached = self._cache.get(name)
        if cached is not None:
            return cached
        template_path = self.directory / f"{name}.md"
        if not template_path.exists():
            raise FileNotFoundError(f"missing prompt template: {template_path}")
        text = template_path.read_text(encoding="utf-8")
        self._cache[name] = text
        return text

    def render(self, name: str, **fields: Any) -> str:
        """Load template ``name`` and fill its placeholders from ``fields``."""
        return self.load(name).format(**fields)
|
||||
|
||||
|
||||
def render_game_start_prompt(library: PromptLibrary, game_state: dict[str, Any]) -> str:
    """Render the ``game_start`` template for a new-hand snapshot.

    Scalar fields are copied straight from ``game_state`` (missing keys
    become ``None``); the player table and the hand history are formatted
    by dedicated helpers so the template stays declarative.
    """
    past_hands = game_state.get("hands") or []
    fields: dict[str, Any] = {
        key: game_state.get(key)
        for key in (
            "game_id",
            "hand_number",
            "status",
            "small_blind",
            "big_blind",
            "button_seat",
            "starting_stack",
        )
    }
    fields["players_block"] = _format_players_block(game_state.get("players") or [])
    fields["hand_count"] = len(past_hands)
    fields["history_block"] = _format_history_block(past_hands)
    return library.render("game_start", **fields)
|
||||
|
||||
|
||||
def render_observation_prompt(
    library: PromptLibrary, observation: dict[str, Any]
) -> str:
    """Render the ``observation`` template for one decision request.

    The acting player's display name is taken from their row in
    ``observation["players"]`` when that row is found (and non-empty),
    otherwise the raw ``player_id`` is used. ``amount_mode`` defaults to
    ``"street_total"`` when absent or empty.
    """
    legal_actions = list(observation.get("legal_actions") or [])
    me = _find_self_player(observation)
    display_name = me.get("name") if me else observation.get("player_id")

    # Fields that are passed through to the template unchanged.
    fields: dict[str, Any] = {
        key: observation.get(key)
        for key in (
            "hand_number",
            "street",
            "player_id",
            "seat",
            "button_seat",
            "pot",
            "to_call",
            "min_raise_to",
        )
    }
    # Fields that need a default or a formatting helper.
    fields.update(
        player_name=display_name,
        amount_mode=observation.get("amount_mode") or "street_total",
        hole_cards=_format_card_list(observation.get("hole_cards")),
        board=_format_card_list(observation.get("board")),
        players_block=_format_players_block(observation.get("players") or []),
        action_history_block=_format_action_history(
            observation.get("action_history") or []
        ),
        legal_actions_block=_format_legal_actions(legal_actions),
    )
    return library.render("observation", **fields)
|
||||
|
||||
|
||||
def _find_self_player(observation: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Locate the acting player's row inside the observation snapshot."""
|
||||
pid = observation.get("player_id")
|
||||
for player in observation.get("players") or []:
|
||||
if player.get("player_id") == pid:
|
||||
return player
|
||||
return None
|
||||
|
||||
|
||||
def _format_card_list(cards: Any) -> str:
|
||||
"""Render a list of card labels for prompts; never returns an empty string."""
|
||||
if not cards:
|
||||
return "(none)"
|
||||
return " ".join(str(card) for card in cards)
|
||||
|
||||
|
||||
def _format_players_block(players: list[dict[str, Any]]) -> str:
|
||||
"""Render the per-seat status table used in both prompt templates."""
|
||||
if not players:
|
||||
return "(no players)"
|
||||
rows: list[str] = []
|
||||
for player in players:
|
||||
flags = []
|
||||
if player.get("folded"):
|
||||
flags.append("folded")
|
||||
if player.get("all_in"):
|
||||
flags.append("all_in")
|
||||
if not player.get("in_hand"):
|
||||
flags.append("out")
|
||||
flag_text = ",".join(flags) if flags else "active"
|
||||
rows.append(
|
||||
f"- seat {player.get('seat')}: id={player.get('player_id')}, "
|
||||
f"name={player.get('name')}, stack={player.get('stack')}, "
|
||||
f"street_bet={player.get('street_bet', 0)}, "
|
||||
f"total_bet={player.get('total_bet', 0)}, status={flag_text}"
|
||||
)
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
def _format_history_block(hands: list[dict[str, Any]]) -> str:
|
||||
"""Render a compact digest of finished hands for the GAME_START message."""
|
||||
if not hands:
|
||||
return "(no hands played yet)"
|
||||
digests: list[str] = []
|
||||
for hand in hands:
|
||||
awards = hand.get("awards") or []
|
||||
winner_lines = [
|
||||
f" pot {a.get('amount')}: -> {','.join(a.get('winners') or [])} "
|
||||
f"({(a.get('hand_value') or {}).get('name', '-')})"
|
||||
for a in awards
|
||||
]
|
||||
showdown = hand.get("showdown_hands") or {}
|
||||
showdown_lines = [
|
||||
f" showdown {pid}: {' '.join(cards)}"
|
||||
for pid, cards in showdown.items()
|
||||
]
|
||||
digests.append(
|
||||
"\n".join(
|
||||
[
|
||||
f"- Hand #{hand.get('hand_number')} "
|
||||
f"(button seat {hand.get('button_seat')}), "
|
||||
f"board: {_format_card_list(hand.get('board'))}",
|
||||
*winner_lines,
|
||||
*showdown_lines,
|
||||
]
|
||||
)
|
||||
)
|
||||
return "\n".join(digests)
|
||||
|
||||
|
||||
def _format_action_history(history: list[dict[str, Any]]) -> str:
|
||||
"""Render the per-action log; trims very old entries to keep prompts cheap."""
|
||||
if not history:
|
||||
return "(no actions yet)"
|
||||
# The engine never produces unbounded history within a single hand, but
|
||||
# we cap defensively so a malformed payload cannot blow up token usage.
|
||||
rows = []
|
||||
for record in history[-32:]:
|
||||
rows.append(
|
||||
f"- [{record.get('street')}] {record.get('player_id')} -> "
|
||||
f"{record.get('action')} amount={record.get('amount', 0)}"
|
||||
)
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
def _format_legal_actions(legal: list[dict[str, Any]]) -> str:
|
||||
"""Render a numbered list of legal actions including amount ranges."""
|
||||
if not legal:
|
||||
return "(no legal actions)"
|
||||
rows: list[str] = []
|
||||
for index, action in enumerate(legal, start=1):
|
||||
amount = action.get("amount")
|
||||
if action.get("action") in {"bet", "raise"}:
|
||||
rows.append(
|
||||
f"{index}. {action['action']}: street_total in "
|
||||
f"[{action.get('min_amount')}, {action.get('max_amount')}]"
|
||||
)
|
||||
else:
|
||||
rows.append(f"{index}. {action['action']} (amount={amount})")
|
||||
return "\n".join(rows)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OpenAI-compatible chat client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class LLMConfig:
    """Connection and sampling settings for the OpenAI-compatible provider.

    Bundling them in a dataclass keeps constructor wiring short and lets
    new fields (e.g. an organisation id) be added without touching every
    call site.
    """

    base_url: str
    api_key: str
    model: str
    timeout_seconds: float = 60.0
    temperature: float = 0.4
    stream: bool = True

    def chat_completions_url(self) -> str:
        """Return the chat completions endpoint derived from ``base_url``.

        Trailing slashes are dropped. A ``base_url`` that already ends in
        ``/chat/completions`` is returned as-is (minus trailing slashes),
        since pasting the full endpoint is a common mistake.
        """
        suffix = "/chat/completions"
        trimmed = self.base_url.rstrip("/")
        return trimmed if trimmed.endswith(suffix) else f"{trimmed}{suffix}"
|
||||
|
||||
|
||||
class LLMClient:
    """Thin wrapper around the OpenAI-compatible Chat Completions API.

    Implemented with ``urllib`` to honour the project's "no third-party
    dependency" constraint; swapping in ``httpx`` later would only touch
    this class.

    Every failure of the remote call - HTTP error status, connection or
    timeout problems, a connection dropped mid-response, a body that is
    not JSON, or JSON with an unexpected shape - is reported as
    :class:`RuntimeError`, so callers need exactly one ``except`` clause
    to fall back to a safe action.
    """

    def __init__(self, config: LLMConfig) -> None:
        self.config = config

    def chat(
        self,
        messages: list[dict[str, Any]],
        on_delta: Callable[[str, str], None] | None = None,
    ) -> str:
        """Send a chat completion request and return the assistant text.

        Args:
            messages: chat history in the Chat Completions message format.
            on_delta: optional callback invoked as ``on_delta(kind, text)``
                with ``kind`` being ``"reasoning"`` or ``"content"`` for
                every piece of model output received.

        Raises:
            RuntimeError: if the request fails or the reply is malformed.
        """
        if self.config.stream:
            return self._chat_stream(messages, on_delta)
        return self._chat_once(messages, on_delta)

    def _request(self, body: dict[str, Any]) -> Request:
        """Build the authenticated JSON ``POST`` request for ``body``."""
        encoded = json.dumps(body).encode("utf-8")
        return Request(
            self.config.chat_completions_url(),
            data=encoded,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.config.api_key}",
            },
            method="POST",
        )

    def _completion_request(
        self, messages: list[dict[str, Any]], stream: bool
    ) -> Request:
        """Build the completion request shared by both transport modes."""
        return self._request(
            {
                "model": self.config.model,
                "messages": messages,
                "temperature": self.config.temperature,
                "stream": stream,
            }
        )

    def _transport_error(self, exc: BaseException) -> RuntimeError:
        """Translate a urllib / http.client failure into a ``RuntimeError``."""
        url = self.config.chat_completions_url()
        if isinstance(exc, HTTPError):
            # The error body usually carries the provider's explanation.
            detail = exc.read().decode("utf-8", errors="replace")
            return RuntimeError(f"LLM HTTP {exc.code} from {url}: {detail}")
        return RuntimeError(f"LLM request failed: {url}")

    def _chat_once(
        self,
        messages: list[dict[str, Any]],
        on_delta: Callable[[str, str], None] | None = None,
    ) -> str:
        """Perform a non-streaming completion and return the message content."""
        # Local import: only needed to classify connection-level failures.
        from http.client import HTTPException

        request = self._completion_request(messages, stream=False)
        try:
            with urlopen(request, timeout=self.config.timeout_seconds) as resp:
                raw = resp.read()
        except (OSError, HTTPException) as exc:
            # HTTPError and URLError are OSError subclasses; HTTPException
            # covers responses cut short by the peer (e.g. IncompleteRead).
            raise self._transport_error(exc) from exc

        try:
            payload = json.loads(raw.decode("utf-8"))
        except ValueError as exc:
            # JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError; typical cause is an HTML error page from a proxy.
            raise RuntimeError(
                f"LLM returned a non-JSON response: {raw[:200]!r}"
            ) from exc

        try:
            message = payload["choices"][0]["message"]
            reasoning = _reasoning_text(message)
            content = _message_text(message.get("content"))
        except (KeyError, IndexError, TypeError, AttributeError) as exc:
            raise RuntimeError(f"LLM returned unexpected payload: {payload}") from exc

        if on_delta and reasoning:
            on_delta("reasoning", reasoning)
        if on_delta and content:
            on_delta("content", content)
        return content

    def _chat_stream(
        self,
        messages: list[dict[str, Any]],
        on_delta: Callable[[str, str], None] | None = None,
    ) -> str:
        """Perform a streaming completion and return the concatenated content.

        Reasoning deltas are forwarded to ``on_delta`` only; content deltas
        are forwarded and also accumulated into the return value.
        """
        # Local import: only needed to classify connection-level failures.
        from http.client import HTTPException

        request = self._completion_request(messages, stream=True)
        parts: list[str] = []
        try:
            with urlopen(request, timeout=self.config.timeout_seconds) as resp:
                for event in _iter_sse_payloads(resp):
                    if event == "[DONE]":
                        break
                    try:
                        payload = json.loads(event)
                    except json.JSONDecodeError as exc:
                        raise RuntimeError(
                            f"LLM returned invalid stream event: {event!r}"
                        ) from exc
                    delta = _stream_delta(payload)
                    reasoning = _reasoning_text(delta)
                    content = _message_text(delta.get("content"))
                    if reasoning and on_delta:
                        on_delta("reasoning", reasoning)
                    if content:
                        parts.append(content)
                        if on_delta:
                            on_delta("content", content)
        except (OSError, HTTPException) as exc:
            # Also raised while iterating the body, e.g. when the provider
            # drops the connection in the middle of the stream.
            raise self._transport_error(exc) from exc
        return "".join(parts)
|
||||
|
||||
|
||||
def _message_text(value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return str(value)
|
||||
|
||||
|
||||
def _reasoning_text(message_or_delta: dict[str, Any]) -> str:
    """Return the first non-empty reasoning field of a message or stream delta.

    The keys ``reasoning_content``, ``reasoning`` and ``reasoning_text`` are
    tried in that order; ``""`` is returned when none carries text.
    """
    candidates = (
        _message_text(message_or_delta.get(key))
        for key in ("reasoning_content", "reasoning", "reasoning_text")
    )
    # The generator is lazy, so later keys are only read when needed.
    return next((text for text in candidates if text), "")
|
||||
|
||||
|
||||
def _stream_delta(payload: dict[str, Any]) -> dict[str, Any]:
|
||||
try:
|
||||
delta = payload["choices"][0].get("delta") or {}
|
||||
except (KeyError, IndexError, TypeError) as exc:
|
||||
raise RuntimeError(f"LLM returned unexpected stream payload: {payload}") from exc
|
||||
if not isinstance(delta, dict):
|
||||
raise RuntimeError(f"LLM returned unexpected stream delta: {payload}")
|
||||
return delta
|
||||
|
||||
|
||||
def _iter_sse_payloads(response: Any) -> Iterator[str]:
|
||||
"""Yield ``data:`` payloads from an OpenAI-compatible SSE response."""
|
||||
data_lines: list[str] = []
|
||||
for raw in response:
|
||||
line = raw.decode("utf-8", errors="replace").rstrip("\r\n")
|
||||
if line == "":
|
||||
if data_lines:
|
||||
yield "\n".join(data_lines)
|
||||
data_lines = []
|
||||
continue
|
||||
if line.startswith("data:"):
|
||||
data_lines.append(line[5:].lstrip())
|
||||
if data_lines:
|
||||
yield "\n".join(data_lines)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class LLMSession:
    """Chat history scoped to one hand of one ``(game_id, player_id)`` seat.

    The system prompt never changes, whereas the user/assistant exchange is
    thrown away and re-seeded each time :meth:`reset_with_game` is called
    (i.e. when ``/game`` announces a new hand). Keeping the assistant turns
    of the current hand gives the model intra-hand continuity without
    resending the long table snapshot.
    """

    system_prompt: str
    game_id: str
    player_id: str
    messages: list[dict[str, Any]] = field(default_factory=list)
    hand_number: int = 0

    def _system_message(self) -> dict[str, Any]:
        """Build a fresh system-role message from the stored prompt."""
        return {"role": "system", "content": self.system_prompt}

    def reset_with_game(self, hand_number: int, game_user_message: str) -> None:
        """Replace the history with the system prompt plus the new-hand message."""
        self.hand_number = hand_number
        opening = {"role": "user", "content": game_user_message}
        self.messages = [self._system_message(), opening]

    def append_user(self, content: str) -> None:
        """Append a user turn."""
        self.messages.append({"role": "user", "content": content})

    def append_assistant(self, content: str) -> None:
        """Append an assistant turn."""
        self.messages.append({"role": "assistant", "content": content})

    def chat_messages(self) -> list[dict[str, Any]]:
        """Return a shallow copy of the history to send to the LLM.

        With an empty history (``/act`` arrived before any ``/game``) the
        result is just the system prompt, so the request stays valid.
        """
        if self.messages:
            return list(self.messages)
        return [self._system_message()]
|
||||
|
||||
|
||||
class SessionRegistry:
    """Bounded, least-recently-used map from ``(game_id, player_id)`` to a session.

    Parallel games, or several seats served by one process, each get their
    own isolated chat history, while the cap on stored sessions keeps
    memory bounded.
    """

    def __init__(self, system_prompt: str, max_sessions: int = 64) -> None:
        self._system_prompt = system_prompt
        self._sessions: OrderedDict[tuple[str, str], LLMSession] = OrderedDict()
        self._max = max_sessions
        self._lock = threading.Lock()

    def get_or_create(self, game_id: str, player_id: str) -> LLMSession:
        """Return the session for the given seat, creating it on first use.

        A hit marks the session as most recently used. A miss inserts a new
        session and then evicts the least recently used entries until at
        most ``max_sessions`` remain.
        """
        key = (game_id, player_id)
        with self._lock:
            existing = self._sessions.get(key)
            if existing is not None:
                self._sessions.move_to_end(key)
                return existing

            created = LLMSession(
                system_prompt=self._system_prompt,
                game_id=game_id,
                player_id=player_id,
            )
            self._sessions[key] = created
            # Stale game histories are never needed again, so the oldest
            # entries are simply dropped once the cap is exceeded.
            while len(self._sessions) > self._max:
                self._sessions.popitem(last=False)
            return created
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Action parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Matches from the first "{" to the LAST "}" in the text ("*" is greedy and
# [\s\S] also crosses newlines), i.e. the widest brace-delimited span.
_JSON_OBJECT_RE = re.compile(r"\{[\s\S]*\}")
|
||||
|
||||
|
||||
def parse_action_reply(reply: str) -> dict[str, Any]:
    """Extract the action JSON from a possibly chatty LLM response.

    LLMs occasionally wrap JSON in markdown fences or add a sentence of
    chatter despite explicit instructions. The reply is scanned left to
    right and the first ``{`` that starts a complete, valid JSON object
    wins; text (including stray braces) before or after that object is
    ignored. Downstream code (engine ``_coerce_action``) sanitises illegal
    values, so ranges are not validated here.

    Args:
        reply: raw assistant text.

    Returns:
        ``{"action": str, "amount": int}``; a missing or empty action
        becomes ``"fold"`` and a missing or empty amount becomes ``0``.

    Raises:
        ValueError: if the reply is empty, contains no parseable JSON
            object, or carries an amount that cannot be turned into an int.
    """
    if not isinstance(reply, str) or not reply.strip():
        raise ValueError("empty LLM reply")

    start = reply.find("{")
    if start == -1:
        raise ValueError(f"no JSON object found in LLM reply: {reply!r}")

    # raw_decode parses exactly one JSON value starting at ``start`` and
    # tolerates trailing text, unlike a greedy "{...}" regex which would
    # swallow everything up to the last "}" in the reply.
    decoder = json.JSONDecoder()
    payload: Any = None
    while start != -1:
        try:
            payload, _end = decoder.raw_decode(reply, start)
        except json.JSONDecodeError:
            start = reply.find("{", start + 1)
        else:
            break
    if not isinstance(payload, dict):
        raise ValueError(f"invalid JSON in LLM reply: {reply!r}")

    try:
        amount = int(payload.get("amount") or 0)
    except (TypeError, ValueError, OverflowError) as exc:
        # Normalise to ValueError so callers need a single except clause
        # (int() raises TypeError for lists/dicts, OverflowError for inf).
        raise ValueError(f"invalid amount in LLM reply: {reply!r}") from exc

    return {
        "action": str(payload.get("action") or "fold"),
        "amount": amount,
    }
|
||||
|
||||
|
||||
def fallback_action(observation: dict[str, Any]) -> dict[str, Any]:
    """Choose an action without the LLM, from ``observation["legal_actions"]``.

    Preference order: ``check``, then ``call`` (with the amount listed for
    it), then ``fold``. If none of those is listed, the first legal action
    is echoed with its ``amount`` (``0`` when absent); with no legal
    actions at all the answer is ``fold``.
    """
    fold = {"action": "fold", "amount": 0}
    options = observation.get("legal_actions") or []

    # Index by action name; a later duplicate overrides an earlier one.
    named: dict[Any, dict[str, Any]] = {}
    for option in options:
        named[option.get("action")] = option

    if "check" in named:
        return {"action": "check", "amount": 0}
    if "call" in named:
        return {"action": "call", "amount": int(named["call"].get("amount") or 0)}
    if "fold" in named or not options:
        return fold

    # Last resort: echo the first legal action as-is.
    head = options[0]
    return {
        "action": str(head.get("action")),
        "amount": int(head.get("amount") or 0),
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP service
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AIAgentService:
    """Ties together the LLM client, the prompt library and the session registry.

    Having one object own the decision flow keeps the HTTP handler limited
    to request parsing and response framing.

    NOTE(review): ``handle_game`` keys the session by the ``player_id``
    supplied by its caller, whereas ``handle_act`` keys it by
    ``observation["player_id"]``; both must agree for the session seeded by
    ``/game`` to be the one used by ``/act`` - confirm against the server.
    """

    def __init__(
        self,
        llm: LLMClient,
        prompts: PromptLibrary,
        registry: SessionRegistry | None = None,
        console: AIAgentConsole | None = None,
    ) -> None:
        self.llm = llm
        self.prompts = prompts
        self.console = console
        # The system prompt is loaded a single time here; hot-edits to the
        # markdown file are intentionally ignored until the agent restarts.
        system_prompt = self.prompts.load("system")
        self.registry = registry or SessionRegistry(system_prompt=system_prompt)

    def handle_game(self, game_state: dict[str, Any], player_id: str) -> None:
        """Start a fresh per-hand session seeded with the rendered snapshot."""
        if self.console:
            self.console.announce_game(game_state)
        game_id = str(game_state.get("game_id") or "")
        hand_number = int(game_state.get("hand_number") or 0)
        opening_message = render_game_start_prompt(self.prompts, game_state)
        self.registry.get_or_create(game_id, player_id).reset_with_game(
            hand_number=hand_number,
            game_user_message=opening_message,
        )

    def handle_act(self, observation: dict[str, Any]) -> dict[str, Any]:
        """Decide one action for ``observation`` and return it as a dict.

        With a console attached the whole decision runs inside
        ``console.act_log`` so its output is not interleaved.
        """
        if not self.console:
            return self._handle_act_locked(observation)
        with self.console.act_log(observation):
            return self._handle_act_locked(observation)

    def _use_fallback(
        self, observation: dict[str, Any], reason: Exception
    ) -> dict[str, Any]:
        """Pick the fallback action and report why the model's answer was unusable."""
        action = fallback_action(observation)
        if self.console:
            self.console.announce_warning(str(reason))
            self.console.announce_action(action, source="fallback")
        return action

    def _handle_act_locked(self, observation: dict[str, Any]) -> dict[str, Any]:
        """Render the prompt, query the LLM and parse its reply.

        Falls back to :func:`fallback_action` when the LLM call raises
        ``RuntimeError`` or its reply cannot be parsed.
        """
        game_id = str(observation.get("game_id") or "")
        player_id = str(observation.get("player_id") or "")
        session = self.registry.get_or_create(game_id, player_id)

        # /act may arrive before any /game (e.g. the agent was restarted
        # mid-hand); make sure the system prompt still leads the history.
        if not session.messages:
            session.messages = [
                {"role": "system", "content": session.system_prompt}
            ]
        session.append_user(render_observation_prompt(self.prompts, observation))

        try:
            if self.console:
                self.console.begin_llm_stream()
            on_delta = self.console.write_llm_delta if self.console else None
            assistant_text = self.llm.chat(session.chat_messages(), on_delta=on_delta)
            if self.console:
                self.console.end_llm_stream()
        except RuntimeError as exc:
            # Drop the user message that got no reply so it does not
            # contaminate the next turn, then answer with a safe action.
            session.messages.pop()
            return self._use_fallback(observation, exc)

        session.append_assistant(assistant_text)
        try:
            action = parse_action_reply(assistant_text)
            if self.console:
                self.console.announce_action(action, source="model")
        except ValueError as exc:
            # The unparseable reply stays in the history so the model can
            # see it on its next turn; this turn gets a safe action.
            return self._use_fallback(observation, exc)
        return action
|
||||
|
||||
|
||||
def _bind_player_id(handler: BaseHTTPRequestHandler) -> str:
|
||||
"""Resolve which seat a request belongs to.
|
||||
|
||||
The standalone process serves *one* AI seat by default. Multi-seat
|
||||
deployments can pass ``X-Player-Id`` so the registry can keep the
|
||||
sessions isolated.
|
||||
"""
|
||||
explicit = handler.headers.get("X-Player-Id")
|
||||
if explicit:
|
||||
return explicit
|
||||
return getattr(handler.server, "default_player_id", "ai") # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class AIRequestHandler(BaseHTTPRequestHandler):
    """HTTP front end of the AI agent.

    Routes (matched against the exact request path):
    - ``GET /health`` - liveness probe, answers ``{"ok": true}``.
    - ``POST /game`` - new hand boundary; answers ``204 No Content``.
    - ``POST /act`` - answers with the chosen action as JSON.

    Anything else gets a JSON ``404``. A missing or malformed body gets a
    ``400``; any exception raised while handling a request gets a ``500``.
    """

    server_version = "TexasHoldemAIAgent/0.1"
    service: AIAgentService  # injected by ``create_server``

    def do_GET(self) -> None:
        """Serve the health probe."""
        if self.path != "/health":
            self._json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return
        self._json({"ok": True})

    def do_POST(self) -> None:
        """Dispatch ``/game`` and ``/act`` after parsing the JSON body."""
        if self.path == "/game":
            route = self._handle_game
        elif self.path == "/act":
            route = self._handle_act
        else:
            self._json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return

        try:
            payload = self._read_json()
        except ValueError as exc:
            self._json({"error": str(exc)}, HTTPStatus.BAD_REQUEST)
            return

        try:
            route(payload)
        except Exception as exc:  # pragma: no cover - last-resort guard
            self._json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR)

    def _handle_game(self, payload: dict[str, Any]) -> None:
        """Forward a game snapshot to the service and acknowledge with 204."""
        seat = _bind_player_id(self)
        self.service.handle_game(payload, player_id=seat)
        self._empty(HTTPStatus.NO_CONTENT)

    def _handle_act(self, payload: dict[str, Any]) -> None:
        """Ask the service for an action and send it back as JSON."""
        self._json(self.service.handle_act(payload))

    def log_message(self, format: str, *args: Any) -> None:  # noqa: A002
        """Suppress the base class's per-request logging."""
        return

    def _read_json(self) -> dict[str, Any]:
        """Read the request body and parse it as a JSON object.

        Raises:
            ValueError: if ``Content-Length`` is missing, non-numeric or
                not positive, or the body is not a valid JSON object.
        """
        declared = int(self.headers.get("Content-Length", "0"))
        if declared <= 0:
            raise ValueError("request body is required")
        raw = self.rfile.read(declared)
        try:
            parsed = json.loads(raw.decode("utf-8"))
        except json.JSONDecodeError as exc:
            raise ValueError("request body must be valid JSON") from exc
        if isinstance(parsed, dict):
            return parsed
        raise ValueError("request body must be a JSON object")

    def _json(
        self, payload: dict[str, Any], status: HTTPStatus = HTTPStatus.OK
    ) -> None:
        """Send ``payload`` as an ASCII-escaped JSON response with ``status``."""
        encoded = json.dumps(payload, ensure_ascii=True).encode("utf-8")
        self.send_response(status)
        for header, value in (
            ("Content-Type", "application/json"),
            ("Content-Length", str(len(encoded))),
        ):
            self.send_header(header, value)
        self.end_headers()
        self.wfile.write(encoded)

    def _empty(self, status: HTTPStatus) -> None:
        """Send a response with ``status`` and no body."""
        self.send_response(status)
        self.send_header("Content-Length", "0")
        self.end_headers()
|
||||
|
||||
|
||||
def create_server(
    host: str,
    port: int,
    service: AIAgentService,
    default_player_id: str = "ai",
) -> ThreadingHTTPServer:
    """Create the threaded HTTP server that fronts the agent.

    ``default_player_id`` is stored on the returned server object, while
    ``service`` is stored on the :class:`AIRequestHandler` class itself.
    Every handler instance (and so every worker thread) therefore sees one
    shared service, which in turn means a single registry of sessions.
    """
    httpd = ThreadingHTTPServer((host, port), AIRequestHandler)
    httpd.default_player_id = default_player_id  # type: ignore[attr-defined]
    AIRequestHandler.service = service
    return httpd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> None:
    """Parse the command line, build the agent service and serve until Ctrl-C.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour, so
            existing ``main()`` callers are unaffected.

    Exits through ``parser.error`` when no API key is given on the command
    line or in ``$OPENAI_API_KEY``.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run an OpenAI-compatible AI poker agent that exposes "
            "POST /act and POST /game."
        )
    )
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", default=9101, type=int)
    parser.add_argument(
        "--base-url",
        default=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        help="OpenAI-compatible base URL (default: $OPENAI_BASE_URL or "
        "https://api.openai.com/v1).",
    )
    # NOTE(review): a key passed via --api-key is visible in the process
    # list; the $OPENAI_API_KEY default avoids that.
    parser.add_argument(
        "--api-key",
        default=os.environ.get("OPENAI_API_KEY", ""),
        help="API key (default: $OPENAI_API_KEY).",
    )
    parser.add_argument(
        "--model",
        default=os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
        help="Model identifier (default: $OPENAI_MODEL or gpt-4o-mini).",
    )
    parser.add_argument(
        "--timeout",
        default=60.0,
        type=float,
        help="LLM request timeout in seconds.",
    )
    parser.add_argument(
        "--temperature",
        default=0.4,
        type=float,
        help="LLM sampling temperature.",
    )
    parser.add_argument(
        "--player-id",
        default="ai",
        help=(
            "Default player_id used to key sessions. Override per-request "
            "via the X-Player-Id header for multi-seat setups."
        ),
    )
    parser.add_argument(
        "--prompts-dir",
        default=str(PROMPTS_DIR),
        help="Directory containing the prompt markdown templates.",
    )
    parser.add_argument(
        "--keep-history",
        action="store_true",
        help=(
            "Keep previous terminal output when a new /act request arrives "
            "instead of clearing the screen."
        ),
    )
    parser.add_argument(
        "--no-stream",
        action="store_true",
        help="Disable streaming Chat Completions requests.",
    )
    parser.add_argument(
        "--no-color",
        action="store_true",
        help="Disable ANSI gray coloring for streamed LLM output.",
    )
    args = parser.parse_args(argv)

    if not args.api_key:
        parser.error("--api-key (or OPENAI_API_KEY) is required")

    config = LLMConfig(
        base_url=args.base_url,
        api_key=args.api_key,
        model=args.model,
        timeout_seconds=args.timeout,
        temperature=args.temperature,
        stream=not args.no_stream,
    )
    prompts = PromptLibrary(directory=Path(args.prompts_dir))
    console = AIAgentConsole(
        keep_history=args.keep_history,
        use_color=not args.no_color,
    )
    service = AIAgentService(LLMClient(config), prompts, console=console)
    server = create_server(args.host, args.port, service, default_player_id=args.player_id)

    # Report the port the socket is actually bound to: with ``--port 0`` the
    # OS picks one, and echoing ``args.port`` would print a useless ``:0``.
    bound_port = server.server_address[1]
    print(
        f"AI HTTP agent listening on http://{args.host}:{bound_port}\n"
        f" POST /act - decision request\n"
        f" POST /game - new-hand snapshot (opens a fresh session)\n"
        f" model : {config.model}\n"
        f" base_url : {config.base_url}\n"
        f" player_id : {args.player_id}\n"
        f" stream : {'on' if config.stream else 'off'}\n"
        f" clear-screen: {'off (keep history)' if args.keep_history else 'on'}",
        file=sys.stderr,
        flush=True,
    )
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the agent; exit without a traceback.
        pass
    finally:
        server.server_close()
|
||||
|
||||
|
||||
# Start the agent only when this module is executed directly, never on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user