Files
texas_hold_x/texas_holdem/agents.py
T
2026-05-11 15:46:30 +08:00

131 lines
4.9 KiB
Python

from __future__ import annotations
import json
import sys
from abc import ABC, abstractmethod
from random import Random
from typing import IO, Any
from urllib.error import URLError
from urllib.request import Request, urlopen
from texas_holdem.human_io import prompt_action, render_observation
from texas_holdem.models import Observation, PlayerAction
class PokerAgent(ABC):
    """Abstract interface for anything that can choose a poker action.

    Concrete agents override :meth:`decide`; the class itself cannot be
    instantiated because ``decide`` is abstract.
    """

    @abstractmethod
    def decide(self, observation: Observation) -> PlayerAction:
        """Return the action to take for ``observation``.

        Raises:
            NotImplementedError: Whenever this base implementation is
                invoked directly instead of an override.
        """
        raise NotImplementedError()
class RandomAgent(PokerAgent):
    """Agent that picks one of the offered legal actions at random.

    ``bet`` and ``raise`` entries get an amount drawn (inclusive) from the
    entry's ``min_amount``..``max_amount`` range; every other action uses
    the entry's ``amount`` field, or 0 when that field is missing or falsy.
    """

    def __init__(self, rng: Random | None = None) -> None:
        """Keep ``rng`` for all draws, or create a new ``Random`` if omitted."""
        self._rng = rng or Random()

    def decide(self, observation: Observation) -> PlayerAction:
        """Choose a random entry from ``observation.legal_actions``."""
        picked = self._rng.choice(observation.legal_actions)
        kind = str(picked["action"])

        # Fixed-size actions carry their amount directly on the entry.
        if kind not in {"bet", "raise"}:
            return PlayerAction(kind, int(picked.get("amount") or 0))

        # Sized actions: draw the amount from the advertised range.
        low = int(picked["min_amount"])
        high = int(picked["max_amount"])
        return PlayerAction(kind, self._rng.randint(low, high))
class CallingStationAgent(PokerAgent):
    """Passive agent: checks when possible, otherwise calls, otherwise folds."""

    def decide(self, observation: Observation) -> PlayerAction:
        """Return the most passive action present in the legal actions."""
        offered = observation.legal_actions

        # A free check always wins over paying to call.
        if any(entry["action"] == "check" for entry in offered):
            return PlayerAction("check")

        call_entry = next(
            (entry for entry in offered if entry["action"] == "call"),
            None,
        )
        if call_entry is None:
            return PlayerAction("fold")
        return PlayerAction("call", int(call_entry.get("amount") or 0))
class HttpAgent(PokerAgent):
    """Agent that delegates every decision to a remote HTTP endpoint.

    The observation is POSTed as a UTF-8 JSON document. The response body
    must be a UTF-8 JSON object, which is passed to
    ``PlayerAction.from_dict``.
    """

    def __init__(self, endpoint: str, timeout_seconds: float = 10.0) -> None:
        """Remember the target URL and the per-request timeout in seconds."""
        # NOTE: urlopen() also honours non-HTTP schemes (e.g. ``file:``), so
        # the endpoint should only come from trusted configuration.
        self.endpoint = endpoint
        self.timeout_seconds = timeout_seconds

    def decide(self, observation: Observation) -> PlayerAction:
        """POST ``observation`` to the endpoint and parse the reply.

        Raises:
            RuntimeError: If the request fails, the body cannot be read or
                decoded as UTF-8 JSON, or the decoded JSON is not an object.
        """
        # Imported locally so this block does not depend on a new
        # module-level import.
        from http.client import HTTPException

        body = json.dumps(observation.to_dict()).encode("utf-8")
        request = Request(
            self.endpoint,
            data=body,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        try:
            with urlopen(request, timeout=self.timeout_seconds) as response:
                raw = response.read()
            payload: Any = json.loads(raw.decode("utf-8"))
        except (
            OSError,
            URLError,
            # read() can fail with e.g. IncompleteRead, which is not an OSError.
            HTTPException,
            # A non-UTF-8 body is a ValueError but not a JSONDecodeError.
            UnicodeDecodeError,
            json.JSONDecodeError,
        ) as exc:
            raise RuntimeError(f"agent endpoint failed: {self.endpoint}") from exc
        if not isinstance(payload, dict):
            raise RuntimeError("agent endpoint must return a JSON object")
        return PlayerAction.from_dict(payload)
class HumanAgent(PokerAgent):
    """Console-driven agent for manual play and debugging.

    Each decision prints a rendering of the observation, then hands the
    legal actions to the interactive prompt from ``texas_holdem.human_io``.
    Both streams are injectable (defaulting to ``sys.stdin`` and
    ``sys.stdout``) so the agent can be exercised with in-memory streams or
    pointed at another console.
    """

    def __init__(
        self,
        input_stream: IO[str] | None = None,
        output_stream: IO[str] | None = None,
    ) -> None:
        """Bind the streams, falling back to the process's stdin/stdout."""
        if input_stream is None:
            input_stream = sys.stdin
        if output_stream is None:
            output_stream = sys.stdout
        self._input = input_stream
        self._output = output_stream

    def decide(self, observation: Observation) -> PlayerAction:
        """Show the observation and return the action the operator picks."""
        # The dict form is what the human_io helpers consume, which keeps this
        # code path shared with the standalone HTTP human client.
        snapshot = observation.to_dict()
        self._write(render_observation(snapshot))

        legal_actions = list(snapshot.get("legal_actions") or [])
        selection = prompt_action(legal_actions, self._read_line, self._write)
        return PlayerAction.from_dict(selection)

    def _write(self, text: str) -> None:
        """Emit ``text`` on the output stream and flush immediately."""
        sink = self._output
        sink.write(text)
        sink.flush()

    def _read_line(self, prompt: str) -> str:
        """Print ``prompt`` and return the next input line without its newline.

        The injected stream is read directly instead of calling builtin
        ``input()``, so tests can substitute ``StringIO``.

        Raises:
            EOFError: If the input stream is exhausted.
        """
        self._write(prompt)
        raw = self._input.readline()
        # readline() yields "" only at end of stream; a blank line is "\n".
        if raw == "":
            raise EOFError("input stream closed while waiting for human action")
        return raw.rstrip("\n")
def build_agent(spec: dict[str, Any], rng: Random | None = None) -> PokerAgent:
    """Construct the agent described by ``spec``.

    Recognised ``spec`` keys:
        type: Agent kind, matched case-insensitively; defaults to
            ``"calling"``. Accepted values are ``random``; ``calling`` /
            ``call`` / ``calling_station``; ``http``; and ``human`` /
            ``cli`` / ``interactive``.
        endpoint: Required, non-empty URL for ``http`` agents.
        timeout_seconds: Optional timeout for ``http`` agents (default 10.0).

    ``rng`` is forwarded only to ``RandomAgent``.

    Raises:
        ValueError: For an unknown type, or an ``http`` spec without an
            endpoint.
    """
    kind = str(spec.get("type", "calling")).lower()

    if kind in ("calling", "call", "calling_station"):
        return CallingStationAgent()
    if kind in ("human", "cli", "interactive"):
        return HumanAgent()
    if kind == "random":
        return RandomAgent(rng)
    if kind != "http":
        raise ValueError(f"unknown agent type: {kind}")

    # Only the HTTP agent needs further configuration from the spec.
    url = spec.get("endpoint")
    if not url:
        raise ValueError("http agent requires an endpoint")
    endpoint_text = str(url)
    timeout = float(spec.get("timeout_seconds", 10.0))
    return HttpAgent(endpoint_text, timeout)