1135 lines
42 KiB
Python
1135 lines
42 KiB
Python
"""Standalone HTTP AI poker agent backed by an OpenAI-compatible LLM.
|
|
|
|
Run as a process exposing two endpoints that the Texas Hold'em service
|
|
calls:
|
|
|
|
* ``POST /game`` - delivered at the start of every new hand. We use it as
|
|
the boundary that opens a fresh chat session for that hand and seeds it
|
|
with a human-readable rendering of the table snapshot (history of past
|
|
hands included).
|
|
* ``POST /act`` - the per-decision request. We render the observation
|
|
with a templated user prompt, ask the configured LLM, parse the JSON
|
|
reply and return it to the server.
|
|
|
|
Run::
|
|
|
|
python -m texas_holdem.ai_client \\
|
|
--host 127.0.0.1 --port 9101 \\
|
|
--base-url https://api.openai.com/v1 \\
|
|
--api-key $OPENAI_API_KEY \\
|
|
--model gpt-4o-mini
|
|
|
|
Hook it up by passing the *base* URL when creating the game::
|
|
|
|
{"id": "ai", "name": "AI",
|
|
"agent": {"type": "http", "endpoint": "http://127.0.0.1:9101"}}
|
|
|
|
Design notes:
|
|
- Prompts live in ``texas_holdem/prompts/*.md`` so non-engineers can edit
|
|
them without touching code; we read them once at boot via
|
|
:class:`PromptLibrary`.
|
|
- :class:`LLMSession` owns the chat history for a single (game_id, hand)
|
|
scope and is reset by every ``/game`` push, matching the user's mental
|
|
model that one hand == one session.
|
|
- :class:`SessionRegistry` keeps a tiny LRU per ``(game_id, player_id)``
|
|
so the same process can serve multiple parallel games / seats.
|
|
- The LLM client speaks OpenAI's ``/v1/chat/completions`` schema, which
|
|
is what virtually every OpenAI-compatible provider implements.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
import threading
|
|
from collections import OrderedDict
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from http import HTTPStatus
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from pathlib import Path
|
|
from typing import IO, Any, Callable, Iterator
|
|
from urllib.error import HTTPError, URLError
|
|
from urllib.request import Request, urlopen
|
|
|
|
from texas_holdem.human_io import clear_screen, render_game_state, render_observation
|
|
|
|
# Prompt templates are looked up in a ``prompts`` directory that sits beside
# this module, so operators can edit them in place without re-installing the
# package.
PROMPTS_DIR = Path(__file__).resolve().parent.joinpath("prompts")
# Escape sequences used by the console to wrap text in gray and then restore
# the terminal's default rendition.
ANSI_GRAY = "\x1b[90m"
ANSI_RESET = "\x1b[0m"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Terminal diagnostics
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class _ThinkingIndicator:
|
|
"""Animated "thinking..." marquee for the AI agent console.
|
|
|
|
Design rationale:
|
|
- Encapsulated as its own class so the animation lifecycle (timer
|
|
thread, frame state, screen erase sequence) does not pollute the
|
|
surrounding console class.
|
|
- Runs in a daemon background thread driven by ``threading.Event`` so
|
|
``stop`` returns promptly even if the current frame is mid-sleep.
|
|
- Uses ANSI ``\\r`` plus a clearing escape sequence to overwrite the
|
|
previous frame in place, avoiding scrollback noise. The frames
|
|
cycle through 0/1/2/3 dots every 0.5s as requested.
|
|
- ``start``/``stop`` are idempotent so the higher-level console can
|
|
call ``stop`` defensively (e.g. on the fallback path) without
|
|
tracking whether a marquee is actually running.
|
|
"""
|
|
|
|
# Frame interval in seconds; matches the user-visible cadence.
|
|
_FRAME_INTERVAL = 0.5
|
|
# 0..3 dots, looping.
|
|
_FRAMES = ("thinking", "thinking.", "thinking..", "thinking...")
|
|
# ANSI escape that clears from the cursor to the end of the line; we
|
|
# combine it with a leading carriage return to redraw the frame in
|
|
# place.
|
|
_ERASE_LINE = "\r\x1b[K"
|
|
|
|
def __init__(
|
|
self,
|
|
write_fn: Callable[[str], None],
|
|
gray_fn: Callable[[str], str],
|
|
) -> None:
|
|
self._write = write_fn
|
|
self._gray = gray_fn
|
|
self._stop_event = threading.Event()
|
|
self._thread: threading.Thread | None = None
|
|
# ``_active`` reflects whether a frame is currently visible on
|
|
# screen; ``stop`` uses it to decide whether to emit the final
|
|
# erase sequence.
|
|
self._active = False
|
|
# Guard against concurrent start/stop calls from different
|
|
# threads (e.g. content-delta handler vs. end_llm_stream).
|
|
self._lifecycle_lock = threading.Lock()
|
|
|
|
def start(self) -> None:
|
|
"""Begin the marquee in a background thread.
|
|
|
|
Calling ``start`` while already running is a no-op.
|
|
"""
|
|
with self._lifecycle_lock:
|
|
if self._thread is not None and self._thread.is_alive():
|
|
return
|
|
self._stop_event.clear()
|
|
self._active = True
|
|
thread = threading.Thread(
|
|
target=self._run,
|
|
name="ai-thinking-indicator",
|
|
daemon=True,
|
|
)
|
|
self._thread = thread
|
|
thread.start()
|
|
|
|
def stop(self) -> None:
|
|
"""Stop the marquee and erase the current frame from the screen.
|
|
|
|
Safe to call when not running.
|
|
"""
|
|
with self._lifecycle_lock:
|
|
thread = self._thread
|
|
if thread is None:
|
|
return
|
|
self._stop_event.set()
|
|
self._thread = None
|
|
# Wait for the worker outside the lifecycle lock so an in-flight
|
|
# ``_render_frame`` cannot deadlock against ``start`` from
|
|
# another thread.
|
|
thread.join()
|
|
if self._active:
|
|
# Wipe the last frame so the model's actual content begins on
|
|
# a clean line.
|
|
self._write(self._ERASE_LINE)
|
|
self._active = False
|
|
|
|
def _run(self) -> None:
|
|
"""Background loop: redraw the next frame every ``_FRAME_INTERVAL``."""
|
|
index = 0
|
|
while not self._stop_event.is_set():
|
|
self._render_frame(self._FRAMES[index % len(self._FRAMES)])
|
|
index += 1
|
|
# ``Event.wait`` returns immediately when ``set`` is called,
|
|
# so ``stop`` is responsive even mid-frame.
|
|
if self._stop_event.wait(self._FRAME_INTERVAL):
|
|
return
|
|
|
|
def _render_frame(self, label: str) -> None:
|
|
"""Emit one frame in place using carriage-return + erase-EOL."""
|
|
self._write(f"{self._ERASE_LINE}{self._gray(label)}")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# AI agent console
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class AIAgentConsole:
    """Thread-safe terminal writer for the standalone AI agent.

    Modelled on :class:`HumanClientConsole`: unless ``keep_history=True``
    the screen is cleared before each decision is rendered. Everything that
    comes from the LLM is wrapped in gray (when colour is enabled) so it is
    easy to tell apart from table state and the final action.
    """

    def __init__(
        self,
        output_stream: IO[str] | None = None,
        keep_history: bool = False,
        use_color: bool = True,
        show_reasoning: bool = True,
    ) -> None:
        """Configure the console.

        ``output_stream`` defaults to ``sys.stdout``. ``show_reasoning``
        decides whether "reasoning" deltas are printed; "content" deltas
        are always printed.
        """
        self._output = sys.stdout if output_stream is None else output_stream
        self._keep_history = keep_history
        self._use_color = use_color
        self._show_reasoning = show_reasoning
        # Two locks with different granularity: ``_lock`` is held around a
        # whole game/act rendering block, ``_io_lock`` only around a single
        # write+flush so the indicator thread can write while ``_lock`` is
        # held by the request thread.
        self._lock = threading.Lock()
        self._io_lock = threading.Lock()
        # Built unconditionally so ``stop`` can always be called, whatever
        # the ``show_reasoning`` setting is.
        self._thinking = _ThinkingIndicator(
            write_fn=self._write,
            gray_fn=self._gray,
        )

    @contextmanager
    def act_log(self, observation: dict[str, Any]) -> Iterator[None]:
        """Print an ``/act`` observation and keep the console lock held.

        The lock is released when the ``with`` block of the caller exits.
        """
        with self._lock:
            if not self._keep_history:
                clear_screen(self._write)
            rendered = render_observation(observation)
            self._write(rendered)
            yield

    def announce_game(self, game_state: dict[str, Any]) -> None:
        """Print a ``/game`` payload under the console lock."""
        with self._lock:
            rendered = render_game_state(game_state)
            self._write(rendered)

    def begin_llm_stream(self) -> None:
        """Print the stream header and, if reasoning is hidden, animate."""
        header = self._gray("AI MODEL STREAM\n")
        self._write(header)
        # With reasoning suppressed nothing would be printed until the first
        # content delta, so the marquee provides the liveness signal.
        if not self._show_reasoning:
            self._thinking.start()

    def write_llm_delta(self, kind: str, text: str) -> None:
        """Print one streamed fragment; ``kind`` is "reasoning" or "content"."""
        if not text:
            return
        if not self._show_reasoning:
            if kind == "reasoning":
                # Hidden chain-of-thought: drop the fragment entirely.
                return
            if kind == "content":
                # The answer has started; remove the marquee before the
                # text so the two never share a line.
                self._thinking.stop()
        self._write(self._gray(text))

    def end_llm_stream(self) -> None:
        """Close the stream section with a newline."""
        # A reply without any content delta never stopped the marquee.
        if not self._show_reasoning:
            self._thinking.stop()
        self._write(self._gray("\n"))

    def announce_action(
        self,
        action: dict[str, Any],
        source: str = "model",
    ) -> None:
        """Print the chosen action and a separator line."""
        # Callers on the error path may not have closed the stream, so the
        # marquee is stopped here as well.
        self._thinking.stop()
        body = json.dumps(action, ensure_ascii=False)
        headline = f"\nAI ACTION ({source}) -> {body}\n"
        separator = "~" * 60 + "\n\n"
        self._write(headline)
        self._write(separator)

    def announce_warning(self, message: str) -> None:
        """Print a warning line, stopping the marquee first."""
        self._thinking.stop()
        self._write(f"\nAI WARNING -> {message}\n")

    def _gray(self, text: str) -> str:
        """Wrap ``text`` in the gray/reset escapes when colour is enabled."""
        return f"{ANSI_GRAY}{text}{ANSI_RESET}" if self._use_color else text

    def _write(self, text: str) -> None:
        """Write and flush ``text`` atomically with respect to other writers."""
        # Both the request thread and the indicator thread end up here; the
        # lock keeps one caller's write+flush from interleaving with another.
        with self._io_lock:
            stream = self._output
            stream.write(text)
            stream.flush()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Prompt rendering
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class PromptLibrary:
    """Reads ``<name>.md`` prompt templates from a directory and fills them.

    Templates are plain ``str.format`` strings. Keeping the wording in
    files and the substitution in one place means the agent logic never
    assembles prompt text by hand.
    """

    def __init__(self, directory: Path = PROMPTS_DIR) -> None:
        """Remember the template directory; no file is read yet."""
        self.directory = directory
        # name -> template text. Each file is read at most once per
        # instance, so later requests never touch the disk.
        self._cache: dict[str, str] = {}

    def load(self, name: str) -> str:
        """Return the raw text of template ``name``.

        Raises:
            FileNotFoundError: if ``<directory>/<name>.md`` does not exist.
        """
        cached = self._cache.get(name)
        if cached is not None:
            return cached
        template_path = self.directory / f"{name}.md"
        if not template_path.exists():
            raise FileNotFoundError(f"missing prompt template: {template_path}")
        text = template_path.read_text(encoding="utf-8")
        self._cache[name] = text
        return text

    def render(self, name: str, **fields: Any) -> str:
        """Load template ``name`` and substitute ``fields`` with ``format``."""
        return self.load(name).format(**fields)
|
|
|
|
|
|
def render_game_start_prompt(library: PromptLibrary, game_state: dict[str, Any]) -> str:
    """Render the ``game_start`` template for a ``/game`` payload.

    Scalar fields are passed through as-is (``None`` when absent); the
    player table and the hand history are pre-formatted by helpers so the
    template only has to place them.
    """
    passthrough = (
        "game_id",
        "hand_number",
        "status",
        "small_blind",
        "big_blind",
        "button_seat",
        "starting_stack",
    )
    fields: dict[str, Any] = {key: game_state.get(key) for key in passthrough}
    finished_hands = game_state.get("hands") or []
    fields["players_block"] = _format_players_block(game_state.get("players") or [])
    fields["hand_count"] = len(finished_hands)
    fields["history_block"] = _format_history_block(finished_hands)
    return library.render("game_start", **fields)
|
|
|
|
|
|
def render_observation_prompt(
    library: PromptLibrary, observation: dict[str, Any]
) -> str:
    """Render the ``observation`` template for one ``/act`` payload."""
    read = observation.get
    own_row = _find_self_player(observation)
    # Fall back to the player id when the acting player's row is missing
    # (or empty) and therefore cannot supply a display name.
    if own_row:
        display_name = own_row.get("name")
    else:
        display_name = read("player_id")
    options = list(read("legal_actions") or [])
    return library.render(
        "observation",
        hand_number=read("hand_number"),
        street=read("street"),
        player_id=read("player_id"),
        player_name=display_name,
        seat=read("seat"),
        button_seat=read("button_seat"),
        pot=read("pot"),
        to_call=read("to_call"),
        min_raise_to=read("min_raise_to"),
        amount_mode=read("amount_mode") or "street_total",
        hole_cards=_format_card_list(read("hole_cards")),
        board=_format_card_list(read("board")),
        players_block=_format_players_block(read("players") or []),
        action_history_block=_format_action_history(read("action_history") or []),
        legal_actions_block=_format_legal_actions(options),
    )
|
|
|
|
|
|
def _find_self_player(observation: dict[str, Any]) -> dict[str, Any] | None:
|
|
"""Locate the acting player's row inside the observation snapshot."""
|
|
pid = observation.get("player_id")
|
|
for player in observation.get("players") or []:
|
|
if player.get("player_id") == pid:
|
|
return player
|
|
return None
|
|
|
|
|
|
def _format_card_list(cards: Any) -> str:
|
|
"""Render a list of card labels for prompts; never returns an empty string."""
|
|
if not cards:
|
|
return "(none)"
|
|
return " ".join(str(card) for card in cards)
|
|
|
|
|
|
def _format_players_block(players: list[dict[str, Any]]) -> str:
|
|
"""Render the per-seat status table used in both prompt templates."""
|
|
if not players:
|
|
return "(no players)"
|
|
rows: list[str] = []
|
|
for player in players:
|
|
flags = []
|
|
if player.get("folded"):
|
|
flags.append("folded")
|
|
if player.get("all_in"):
|
|
flags.append("all_in")
|
|
if not player.get("in_hand"):
|
|
flags.append("out")
|
|
flag_text = ",".join(flags) if flags else "active"
|
|
rows.append(
|
|
f"- seat {player.get('seat')}: id={player.get('player_id')}, "
|
|
f"name={player.get('name')}, stack={player.get('stack')}, "
|
|
f"street_bet={player.get('street_bet', 0)}, "
|
|
f"total_bet={player.get('total_bet', 0)}, status={flag_text}"
|
|
)
|
|
return "\n".join(rows)
|
|
|
|
|
|
def _format_history_block(hands: list[dict[str, Any]]) -> str:
    """Summarise finished hands for the game-start prompt.

    Each hand contributes a header line (number, button seat, board)
    followed by one line per award and one line per showdown entry.
    """
    if not hands:
        return "(no hands played yet)"
    sections: list[str] = []
    for hand in hands:
        award_lines = []
        for award in hand.get("awards") or []:
            recipients = ",".join(award.get("winners") or [])
            ranking = (award.get("hand_value") or {}).get("name", "-")
            award_lines.append(f" pot {award.get('amount')}: -> {recipients} ({ranking})")
        reveal_lines = []
        for pid, cards in (hand.get("showdown_hands") or {}).items():
            reveal_lines.append(f" showdown {pid}: {' '.join(cards)}")
        header = (
            f"- Hand #{hand.get('hand_number')} "
            f"(button seat {hand.get('button_seat')}), "
            f"board: {_format_card_list(hand.get('board'))}"
        )
        sections.append("\n".join([header] + award_lines + reveal_lines))
    return "\n".join(sections)
|
|
|
|
|
|
def _format_action_history(history: list[dict[str, Any]]) -> str:
|
|
"""Render the per-action log; trims very old entries to keep prompts cheap."""
|
|
if not history:
|
|
return "(no actions yet)"
|
|
# The engine never produces unbounded history within a single hand, but
|
|
# we cap defensively so a malformed payload cannot blow up token usage.
|
|
rows = []
|
|
for record in history[-32:]:
|
|
rows.append(
|
|
f"- [{record.get('street')}] {record.get('player_id')} -> "
|
|
f"{record.get('action')} amount={record.get('amount', 0)}"
|
|
)
|
|
return "\n".join(rows)
|
|
|
|
|
|
def _format_legal_actions(legal: list[dict[str, Any]]) -> str:
|
|
"""Render a numbered list of legal actions including amount ranges."""
|
|
if not legal:
|
|
return "(no legal actions)"
|
|
rows: list[str] = []
|
|
for index, action in enumerate(legal, start=1):
|
|
amount = action.get("amount")
|
|
if action.get("action") in {"bet", "raise"}:
|
|
rows.append(
|
|
f"{index}. {action['action']}: street_total in "
|
|
f"[{action.get('min_amount')}, {action.get('max_amount')}]"
|
|
)
|
|
else:
|
|
rows.append(f"{index}. {action['action']} (amount={amount})")
|
|
return "\n".join(rows)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OpenAI-compatible chat client
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class LLMConfig:
    """Connection and sampling settings for the OpenAI-compatible provider.

    Bundling them in a dataclass keeps constructors short and allows new
    settings to be added without changing every call site.
    """

    base_url: str
    api_key: str
    model: str
    timeout_seconds: float = 60.0
    temperature: float = 0.4
    stream: bool = True

    def chat_completions_url(self) -> str:
        """Return ``base_url`` normalised to end in ``/chat/completions``.

        Trailing slashes are dropped, and a base URL that already ends with
        the endpoint path is returned as-is instead of getting it twice.
        """
        endpoint = "/chat/completions"
        trimmed = self.base_url.rstrip("/")
        return trimmed if trimmed.endswith(endpoint) else trimmed + endpoint
|
|
|
|
|
|
class LLMClient:
    """Thin wrapper around the OpenAI-compatible Chat Completions API.

    Implemented with ``urllib`` to honour the project's "no third-party
    dependency" constraint; swapping in ``httpx`` later would only touch
    this class.

    Every failure that originates from the provider (HTTP error status,
    connection problem, truncated or malformed response) is reported as
    ``RuntimeError`` so callers need a single ``except`` clause.
    """

    def __init__(self, config: LLMConfig) -> None:
        self.config = config

    def chat(
        self,
        messages: list[dict[str, Any]],
        on_delta: Callable[[str, str], None] | None = None,
    ) -> str:
        """Send a chat completion request and return the assistant text.

        Args:
            messages: Chat history in the ``role``/``content`` schema.
            on_delta: Optional callback invoked as ``on_delta(kind, text)``
                with ``kind`` being ``"reasoning"`` or ``"content"``.

        Raises:
            RuntimeError: if the request fails or the reply is malformed.
        """
        if self.config.stream:
            return self._chat_stream(messages, on_delta)
        return self._chat_once(messages, on_delta)

    def _request(self, body: dict[str, Any]) -> Request:
        """Build the authenticated JSON ``POST`` request for ``body``."""
        encoded = json.dumps(body).encode("utf-8")
        return Request(
            self.config.chat_completions_url(),
            data=encoded,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.config.api_key}",
            },
            method="POST",
        )

    def _completion_request(
        self, messages: list[dict[str, Any]], stream: bool
    ) -> Request:
        """Build the request shared by the streaming and one-shot paths."""
        return self._request(
            {
                "model": self.config.model,
                "messages": messages,
                "temperature": self.config.temperature,
                "stream": stream,
            }
        )

    @contextmanager
    def _transport_errors(self) -> Iterator[None]:
        """Translate transport-level failures into ``RuntimeError``."""
        # Imported here because the module's top-level imports do not
        # include ``http.client``.
        from http.client import HTTPException

        url = self.config.chat_completions_url()
        try:
            yield
        except HTTPError as exc:
            # HTTPError must be handled first: it subclasses URLError/OSError
            # and carries the provider's error body.
            detail = exc.read().decode("utf-8", errors="replace")
            raise RuntimeError(f"LLM HTTP {exc.code} from {url}: {detail}") from exc
        except (OSError, URLError, HTTPException) as exc:
            # HTTPException covers failures raised while reading the body
            # (for example a truncated chunked response), which are not
            # OSError subclasses.
            raise RuntimeError(f"LLM request failed: {url}") from exc

    @staticmethod
    def _parse_completion(raw: bytes) -> tuple[str, str]:
        """Decode a non-streaming response body into ``(reasoning, content)``.

        Raises:
            RuntimeError: if the body is not UTF-8 JSON or lacks
                ``choices[0].message``.
        """
        try:
            payload = json.loads(raw.decode("utf-8"))
        except ValueError as exc:
            # Both JSONDecodeError and UnicodeDecodeError subclass ValueError;
            # e.g. a proxy answering 200 with an HTML error page lands here.
            raise RuntimeError(
                f"LLM returned a non-JSON response: {raw[:200]!r}"
            ) from exc
        try:
            message = payload["choices"][0]["message"]
            return _reasoning_text(message), _message_text(message.get("content"))
        except (KeyError, IndexError, TypeError, AttributeError) as exc:
            raise RuntimeError(f"LLM returned unexpected payload: {payload}") from exc

    def _chat_once(
        self,
        messages: list[dict[str, Any]],
        on_delta: Callable[[str, str], None] | None = None,
    ) -> str:
        """Run a non-streaming completion and return the message content."""
        request = self._completion_request(messages, stream=False)
        with self._transport_errors():
            with urlopen(request, timeout=self.config.timeout_seconds) as resp:
                raw = resp.read()
        reasoning, content = self._parse_completion(raw)
        # The callbacks run outside the parsing guard so an error raised by
        # the callback is not misreported as a malformed payload.
        if on_delta and reasoning:
            on_delta("reasoning", reasoning)
        if on_delta and content:
            on_delta("content", content)
        return content

    def _chat_stream(
        self,
        messages: list[dict[str, Any]],
        on_delta: Callable[[str, str], None] | None = None,
    ) -> str:
        """Run a streaming completion and return the concatenated content."""
        request = self._completion_request(messages, stream=True)
        parts: list[str] = []
        with self._transport_errors():
            with urlopen(request, timeout=self.config.timeout_seconds) as resp:
                for event in _iter_sse_payloads(resp):
                    # The event stream ends with a literal "[DONE]" sentinel.
                    if event == "[DONE]":
                        break
                    try:
                        payload = json.loads(event)
                    except json.JSONDecodeError as exc:
                        raise RuntimeError(
                            f"LLM returned invalid stream event: {event!r}"
                        ) from exc
                    delta = _stream_delta(payload)
                    reasoning = _reasoning_text(delta)
                    content = _message_text(delta.get("content"))
                    if reasoning and on_delta:
                        on_delta("reasoning", reasoning)
                    if content:
                        parts.append(content)
                        if on_delta:
                            on_delta("content", content)
        return "".join(parts)
|
|
|
|
|
|
def _message_text(value: Any) -> str:
|
|
if value is None:
|
|
return ""
|
|
if isinstance(value, str):
|
|
return value
|
|
return str(value)
|
|
|
|
|
|
def _reasoning_text(message_or_delta: dict[str, Any]) -> str:
    """Return the first non-empty reasoning field, or ``""`` if none.

    The keys are tried in order: ``reasoning_content``, ``reasoning``,
    ``reasoning_text``.
    """
    candidates = (
        _message_text(message_or_delta.get(field_name))
        for field_name in ("reasoning_content", "reasoning", "reasoning_text")
    )
    return next((text for text in candidates if text), "")
|
|
|
|
|
|
def _stream_delta(payload: dict[str, Any]) -> dict[str, Any]:
|
|
try:
|
|
delta = payload["choices"][0].get("delta") or {}
|
|
except (KeyError, IndexError, TypeError) as exc:
|
|
raise RuntimeError(f"LLM returned unexpected stream payload: {payload}") from exc
|
|
if not isinstance(delta, dict):
|
|
raise RuntimeError(f"LLM returned unexpected stream delta: {payload}")
|
|
return delta
|
|
|
|
|
|
def _iter_sse_payloads(response: Any) -> Iterator[str]:
|
|
"""Yield ``data:`` payloads from an OpenAI-compatible SSE response."""
|
|
data_lines: list[str] = []
|
|
for raw in response:
|
|
line = raw.decode("utf-8", errors="replace").rstrip("\r\n")
|
|
if line == "":
|
|
if data_lines:
|
|
yield "\n".join(data_lines)
|
|
data_lines = []
|
|
continue
|
|
if line.startswith("data:"):
|
|
data_lines.append(line[5:].lstrip())
|
|
if data_lines:
|
|
yield "\n".join(data_lines)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class LLMSession:
    """Chat history for one player in one game, restarted every hand.

    The system prompt never changes; the rest of the history is thrown
    away by ``reset_with_game`` and then grows with each user/assistant
    exchange, so earlier turns of the hand remain visible to the model.
    """

    system_prompt: str
    game_id: str
    player_id: str
    messages: list[dict[str, Any]] = field(default_factory=list)
    hand_number: int = 0

    def reset_with_game(self, hand_number: int, game_user_message: str) -> None:
        """Replace the history with the system prompt plus the hand intro."""
        self.hand_number = hand_number
        opening = {"role": "system", "content": self.system_prompt}
        briefing = {"role": "user", "content": game_user_message}
        self.messages = [opening, briefing]

    def append_user(self, content: str) -> None:
        """Add a user turn to the history."""
        self._push("user", content)

    def append_assistant(self, content: str) -> None:
        """Add an assistant turn to the history."""
        self._push("assistant", content)

    def chat_messages(self) -> list[dict[str, Any]]:
        """Return a copy of the history suitable for a chat request.

        An empty history (no reset and no turn yet) is reported as just the
        system prompt so the request is still well-formed.
        """
        if self.messages:
            return list(self.messages)
        return [{"role": "system", "content": self.system_prompt}]

    def _push(self, role: str, content: str) -> None:
        """Append one ``role``/``content`` message."""
        self.messages.append({"role": role, "content": content})
|
|
|
|
|
|
class SessionRegistry:
    """Bounded, least-recently-used map of ``(game_id, player_id)`` sessions.

    Every game/seat pair gets its own :class:`LLMSession`, and the number
    of sessions kept alive is capped at ``max_sessions``.
    """

    def __init__(self, system_prompt: str, max_sessions: int = 64) -> None:
        """Create an empty registry whose sessions share ``system_prompt``."""
        self._system_prompt = system_prompt
        self._sessions: OrderedDict[tuple[str, str], LLMSession] = OrderedDict()
        self._max = max_sessions
        self._lock = threading.Lock()

    def get_or_create(self, game_id: str, player_id: str) -> LLMSession:
        """Return the session for the pair, creating it on first use.

        A hit marks the session as most recently used; a miss inserts a
        new session and evicts from the least recently used end until the
        cap is respected.
        """
        key = (game_id, player_id)
        with self._lock:
            if key in self._sessions:
                self._sessions.move_to_end(key)
                return self._sessions[key]
            created = LLMSession(
                system_prompt=self._system_prompt,
                game_id=game_id,
                player_id=player_id,
            )
            self._sessions[key] = created
            # Oldest entries sit at the front of the ordered dict.
            while len(self._sessions) > self._max:
                self._sessions.popitem(last=False)
            return created
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Action parsing
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
# Cheap pre-check that the reply contains an opening and a closing brace.
_JSON_OBJECT_RE = re.compile(r"\{[\s\S]*\}")


def parse_action_reply(reply: str) -> dict[str, Any]:
    """Extract the action JSON from a possibly chatty LLM response.

    LLMs occasionally wrap JSON in markdown fences or add a sentence of
    chatter despite explicit instructions. Every ``{`` in the reply is
    tried as the start of a JSON object, so braces in the surrounding
    chatter do not prevent the action from being found. The first object
    that has an ``"action"`` key wins; if no object has one, the first
    object found is used and the action defaults to ``"fold"``.

    Value ranges are not validated here.

    Returns:
        ``{"action": str, "amount": int}``; a missing or falsy amount
        becomes ``0``.

    Raises:
        ValueError: if the reply is empty, contains no JSON object, or its
            ``amount`` cannot be converted to an integer.
    """
    if not isinstance(reply, str) or not reply.strip():
        raise ValueError("empty LLM reply")
    if _JSON_OBJECT_RE.search(reply) is None:
        raise ValueError(f"no JSON object found in LLM reply: {reply!r}")

    decoder = json.JSONDecoder()
    payload: dict[str, Any] | None = None
    start = reply.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at ``start`` and
            # ignores whatever text follows it.
            candidate, _end = decoder.raw_decode(reply, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            if "action" in candidate:
                payload = candidate
                break
            if payload is None:
                payload = candidate
        start = reply.find("{", start + 1)
    if payload is None:
        raise ValueError(f"invalid JSON in LLM reply: {reply!r}")

    try:
        amount = int(payload.get("amount") or 0)
    except (TypeError, ValueError, OverflowError) as exc:
        # Normalise every conversion failure (list/dict amounts, NaN,
        # Infinity, non-numeric strings) to the documented ValueError.
        raise ValueError(f"invalid amount in LLM reply: {reply!r}") from exc
    return {
        "action": str(payload.get("action") or "fold"),
        "amount": amount,
    }
|
|
|
|
|
|
def fallback_action(observation: dict[str, Any]) -> dict[str, Any]:
    """Choose a conservative legal action without consulting the LLM.

    Preference order is ``check``, then ``call`` (with the amount listed
    for it), then ``fold``. If none of those is legal the first legal
    action is echoed; with no legal actions at all the result is ``fold``.
    """
    options = observation.get("legal_actions") or []
    lookup: dict[Any, dict[str, Any]] = {}
    for option in options:
        lookup[option.get("action")] = option

    for name in ("check", "call", "fold"):
        if name not in lookup:
            continue
        # Only a call carries a chip amount; check and fold are always 0.
        chips = int(lookup[name].get("amount") or 0) if name == "call" else 0
        return {"action": name, "amount": chips}

    if not options:
        return {"action": "fold", "amount": 0}
    # None of the preferred actions is available: echo the first option.
    head = options[0]
    return {
        "action": str(head.get("action")),
        "amount": int(head.get("amount") or 0),
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HTTP service
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class AIAgentService:
    """Coordinates prompts, sessions, the LLM client and the console.

    All decision logic lives here so the HTTP handler only has to parse
    requests and frame responses.
    """

    def __init__(
        self,
        llm: LLMClient,
        prompts: PromptLibrary,
        registry: SessionRegistry | None = None,
        console: AIAgentConsole | None = None,
    ) -> None:
        """Wire the collaborators; a default registry is built if omitted."""
        self.llm = llm
        self.prompts = prompts
        self.console = console
        # The system prompt is loaded here, once, and handed to the default
        # registry; edits to the markdown file therefore only take effect
        # after the agent is restarted.
        system_prompt = self.prompts.load("system")
        self.registry = registry or SessionRegistry(system_prompt=system_prompt)

    def handle_game(self, game_state: dict[str, Any], player_id: str) -> None:
        """Start a fresh chat history for the hand described by ``game_state``."""
        if self.console:
            self.console.announce_game(game_state)
        session = self.registry.get_or_create(
            str(game_state.get("game_id") or ""),
            player_id,
        )
        hand_number = int(game_state.get("hand_number") or 0)
        intro = render_game_start_prompt(self.prompts, game_state)
        session.reset_with_game(hand_number=hand_number, game_user_message=intro)

    def handle_act(self, observation: dict[str, Any]) -> dict[str, Any]:
        """Decide an action for ``observation`` and return it as a dict.

        With a console attached, the whole decision runs inside the
        console's ``act_log`` context.
        """
        if not self.console:
            return self._handle_act_locked(observation)
        with self.console.act_log(observation):
            return self._handle_act_locked(observation)

    def _handle_act_locked(self, observation: dict[str, Any]) -> dict[str, Any]:
        """Ask the LLM for an action, falling back to a safe one on failure."""
        console = self.console
        # NOTE(review): the session is keyed by the observation's own
        # ``player_id`` here, while ``handle_game`` keys it by the id passed
        # in by its caller - confirm both always agree.
        session = self.registry.get_or_create(
            str(observation.get("game_id") or ""),
            str(observation.get("player_id") or ""),
        )
        # An /act that arrives before any /game finds an empty history;
        # seed it so the system prompt is part of the request.
        if not session.messages:
            session.messages = [
                {"role": "system", "content": session.system_prompt}
            ]
        session.append_user(render_observation_prompt(self.prompts, observation))

        try:
            if console:
                console.begin_llm_stream()
            reply = self.llm.chat(
                session.chat_messages(),
                on_delta=console.write_llm_delta if console else None,
            )
            if console:
                console.end_llm_stream()
        except RuntimeError as exc:
            # The request got no answer: remove its user message so the
            # next turn does not inherit an unanswered prompt.
            session.messages.pop()
            return self._fall_back(observation, exc)

        session.append_assistant(reply)
        try:
            action = parse_action_reply(reply)
            if console:
                console.announce_action(action, source="model")
        except ValueError as exc:
            # The unparseable reply deliberately stays in the history so
            # the model can see it on the next turn.
            return self._fall_back(observation, exc)
        return action

    def _fall_back(
        self, observation: dict[str, Any], error: Exception
    ) -> dict[str, Any]:
        """Pick the safe action and report ``error`` on the console."""
        action = fallback_action(observation)
        if self.console:
            self.console.announce_warning(str(error))
            self.console.announce_action(action, source="fallback")
        return action
|
|
|
|
|
|
def _bind_player_id(handler: BaseHTTPRequestHandler) -> str:
    """Pick the player id that a request's session is keyed under.

    A non-empty ``X-Player-Id`` request header wins, which lets multi-seat
    deployments keep their sessions apart. Without it, the owning server's
    ``default_player_id`` attribute is used, and ``"ai"`` when the server
    carries no such attribute.
    """
    header_value = handler.headers.get("X-Player-Id")
    if not header_value:
        # Header absent or empty: fall back to the per-server default.
        return getattr(handler.server, "default_player_id", "ai")  # type: ignore[attr-defined]
    return header_value
|
|
|
|
|
|
class AIRequestHandler(BaseHTTPRequestHandler):
    """HTTP entry point for the AI agent.

    Routes:
    - ``GET /health`` - liveness probe.
    - ``POST /game`` - new hand boundary; opens a fresh session.
    - ``POST /act`` - returns the AI-decided action.

    Any other path is answered with a JSON ``404``.
    """

    server_version = "TexasHoldemAIAgent/0.1"
    service: AIAgentService  # injected by ``create_server``

    def do_GET(self) -> None:
        """Answer the ``/health`` probe; every other path gets a 404."""
        if self.path == "/health":
            self._json({"ok": True})
            return
        self._json({"error": "not found"}, HTTPStatus.NOT_FOUND)

    def do_POST(self) -> None:
        """Dispatch ``POST /game`` and ``POST /act``.

        Responds with 404 for unknown paths, 400 when the body is missing
        or is not a JSON object, and 500 when the route handler raises.
        """
        routes = {
            "/game": self._handle_game,
            "/act": self._handle_act,
        }
        handler = routes.get(self.path)
        if handler is None:
            self._json({"error": "not found"}, HTTPStatus.NOT_FOUND)
            return

        try:
            payload = self._read_json()
        except ValueError as exc:
            self._json({"error": str(exc)}, HTTPStatus.BAD_REQUEST)
            return

        try:
            handler(payload)
        except Exception as exc:  # pragma: no cover - last-resort guard
            self._json({"error": str(exc)}, HTTPStatus.INTERNAL_SERVER_ERROR)

    def _handle_game(self, payload: dict[str, Any]) -> None:
        """Forward a new-hand snapshot to the service and reply 204."""
        self.service.handle_game(payload, player_id=_bind_player_id(self))
        self._empty(HTTPStatus.NO_CONTENT)

    def _handle_act(self, payload: dict[str, Any]) -> None:
        """Ask the service for an action and return it as JSON."""
        action = self.service.handle_act(payload)
        self._json(action)

    def log_message(self, format: str, *args: Any) -> None:  # noqa: A002
        """Discard the base class's per-request log line."""
        return

    def _read_json(self) -> dict[str, Any]:
        """Read the request body and decode it as a JSON object.

        Returns:
            The decoded body.

        Raises:
            ValueError: If ``Content-Length`` is not an integer, the body
                is absent, the body is not valid UTF-8 encoded JSON, or
                the decoded value is not a JSON object.
        """
        raw_length = self.headers.get("Content-Length", "0")
        try:
            length = int(raw_length)
        except (TypeError, ValueError) as exc:
            # Report a stable message instead of leaking int()'s wording
            # to the client.
            raise ValueError("Content-Length header must be an integer") from exc
        if length <= 0:
            raise ValueError("request body is required")
        try:
            payload = json.loads(self.rfile.read(length).decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as exc:
            # Undecodable bytes are just another form of invalid JSON from
            # the caller's point of view.
            raise ValueError("request body must be valid JSON") from exc
        if not isinstance(payload, dict):
            raise ValueError("request body must be a JSON object")
        return payload

    def _json(
        self, payload: dict[str, Any], status: HTTPStatus = HTTPStatus.OK
    ) -> None:
        """Send ``payload`` as an ASCII-escaped JSON response.

        Args:
            payload: Object to serialise as the response body.
            status: HTTP status to send; defaults to ``200 OK``.
        """
        body = json.dumps(payload, ensure_ascii=True).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _empty(self, status: HTTPStatus) -> None:
        """Send a body-less response with the given ``status``."""
        self.send_response(status)
        self.send_header("Content-Length", "0")
        self.end_headers()
|
|
|
|
|
|
def create_server(
    host: str,
    port: int,
    service: AIAgentService,
    default_player_id: str = "ai",
) -> ThreadingHTTPServer:
    """Build and configure the HTTP server.

    The ``service`` is attached to a handler subclass created for this
    server, so all of the server's worker threads share one service (and
    therefore one session registry) while several servers created in the
    same process do not overwrite each other's service.

    Args:
        host: Interface to bind.
        port: TCP port to bind.
        service: Service that handles ``/game`` and ``/act`` requests.
        default_player_id: Stored on the server as ``default_player_id``
            for requests that do not name a seat themselves.

    Returns:
        The bound, not yet serving, ``ThreadingHTTPServer``.
    """
    # A per-server subclass keeps ``service`` out of shared class state;
    # ``type(self).service`` lookups on handler instances resolve here first.
    bound_handler = type(
        "BoundAIRequestHandler", (AIRequestHandler,), {"service": service}
    )
    server = ThreadingHTTPServer((host, port), bound_handler)
    # Still mirrored on the base class for backward compatibility with code
    # that reads ``AIRequestHandler.service`` directly.
    AIRequestHandler.service = service
    server.default_player_id = default_player_id  # type: ignore[attr-defined]
    return server
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI entry point
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Assemble the command-line parser used by :func:`main`."""
    cli = argparse.ArgumentParser(
        description=(
            "Run an OpenAI-compatible AI poker agent that exposes "
            "POST /act and POST /game."
        )
    )
    env = os.environ

    # Listening address.
    cli.add_argument("--host", default="127.0.0.1")
    cli.add_argument("--port", default=9101, type=int)

    # LLM endpoint; each default falls back to an environment variable.
    cli.add_argument(
        "--base-url",
        default=env.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        help="OpenAI-compatible base URL (default: $OPENAI_BASE_URL or "
        "https://api.openai.com/v1).",
    )
    cli.add_argument(
        "--api-key",
        default=env.get("OPENAI_API_KEY", ""),
        help="API key (default: $OPENAI_API_KEY).",
    )
    cli.add_argument(
        "--model",
        default=env.get("OPENAI_MODEL", "gpt-4o-mini"),
        help="Model identifier (default: $OPENAI_MODEL or gpt-4o-mini).",
    )

    # Float-valued LLM tuning knobs.
    float_options = (
        ("--timeout", 60.0, "LLM request timeout in seconds."),
        ("--temperature", 0.4, "LLM sampling temperature."),
    )
    for flag, fallback, text in float_options:
        cli.add_argument(flag, default=fallback, type=float, help=text)

    cli.add_argument(
        "--player-id",
        default="ai",
        help=(
            "Default player_id used to key sessions. Override per-request "
            "via the X-Player-Id header for multi-seat setups."
        ),
    )
    cli.add_argument(
        "--prompts-dir",
        default=str(PROMPTS_DIR),
        help="Directory containing the prompt markdown templates.",
    )

    # Boolean switches, all off unless given.
    switches = (
        (
            "--keep-history",
            "Keep previous terminal output when a new /act request arrives "
            "instead of clearing the screen.",
        ),
        ("--no-stream", "Disable streaming Chat Completions requests."),
        ("--no-color", "Disable ANSI gray coloring for streamed LLM output."),
        (
            "--no-reasoning",
            "Hide the LLM's reasoning/chain-of-thought stream from the "
            "terminal. The final answer (content) is still printed so "
            "operators can see the chosen action.",
        ),
    )
    for flag, text in switches:
        cli.add_argument(flag, action="store_true", help=text)

    return cli


def main() -> None:
    """Parse the command line, wire up the agent and serve until Ctrl-C.

    Exits through ``parser.error`` when no API key is supplied by either
    ``--api-key`` or ``$OPENAI_API_KEY``. The listening socket is closed
    on the way out, whether the loop ends by interrupt or by error.
    """
    cli = _build_arg_parser()
    opts = cli.parse_args()

    if not opts.api_key:
        cli.error("--api-key (or OPENAI_API_KEY) is required")

    llm_config = LLMConfig(
        base_url=opts.base_url,
        api_key=opts.api_key,
        model=opts.model,
        timeout_seconds=opts.timeout,
        temperature=opts.temperature,
        stream=not opts.no_stream,
    )
    prompt_library = PromptLibrary(directory=Path(opts.prompts_dir))
    agent_console = AIAgentConsole(
        keep_history=opts.keep_history,
        use_color=not opts.no_color,
        show_reasoning=not opts.no_reasoning,
    )
    agent = AIAgentService(LLMClient(llm_config), prompt_library, console=agent_console)
    server = create_server(
        opts.host, opts.port, agent, default_player_id=opts.player_id
    )

    # Startup banner goes to stderr.
    stream_state = "on" if llm_config.stream else "off"
    reasoning_state = "off (hidden)" if opts.no_reasoning else "on"
    clear_state = "off (keep history)" if opts.keep_history else "on"
    banner_lines = (
        f"AI HTTP agent listening on http://{opts.host}:{opts.port}",
        " POST /act - decision request",
        " POST /game - new-hand snapshot (opens a fresh session)",
        f" model : {llm_config.model}",
        f" base_url : {llm_config.base_url}",
        f" player_id : {opts.player_id}",
        f" stream : {stream_state}",
        f" reasoning : {reasoning_state}",
        f" clear-screen: {clear_state}",
    )
    print("\n".join(banner_lines), file=sys.stderr, flush=True)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the agent; leave quietly.
        pass
    finally:
        server.server_close()
|
|
|
|
|
|
# Run the CLI only when executed as a script or via ``python -m``; importing
# the module must not start a server.
if __name__ == "__main__":
    main()
|