Files
texas_hold_x/tests/test_engine.py
T
mamamiyear c0bc5384f4 feat: add hand detail API and enrich hand summary fields
- HandSummary: add hole_cards, starting_stacks, ending_stacks, pot_contributions
- Engine: capture all players' hole cards (not just showdown), pre/post hand stacks, per-level pot contributions
- Server: new GET /game/<game_id>/hands/<hand_number> route
- Service: add get_hand_state() method
- Tests: add ServerTests for new endpoint, update existing tests
- Existing GET /game/<game_id> auto-inherits new fields via shared to_dict()
2026-05-23 22:11:45 +08:00

169 lines
6.4 KiB
Python

import unittest
from collections import deque
from random import Random

from texas_holdem.agents import CallingStationAgent, PokerAgent
from texas_holdem.cards import Card
from texas_holdem.engine import TableGame
from texas_holdem.models import Observation, PlayerAction
class RecordingAgent(PokerAgent):
    """Passive test agent that logs every decision point it is asked about.

    Each ``decide`` call appends ``(street, player_id, to_call)`` to ``seen``
    and then checks when a check is among the legal actions, calling
    otherwise.
    """

    def __init__(self, seen: list[tuple[str, str, int]]) -> None:
        # The log list may be shared by several agents so a test can see the
        # global order in which players were asked to act.
        self.seen = seen

    def decide(self, observation: Observation) -> PlayerAction:
        record = (observation.street, observation.player_id, observation.to_call)
        self.seen.append(record)
        can_check = any(
            option["action"] == "check" for option in observation.legal_actions
        )
        return PlayerAction("check") if can_check else PlayerAction("call")
class ScriptedAgent(PokerAgent):
    """Test agent that replays a fixed script of actions, then plays passively.

    Every ``decide`` call records ``(street, player_id, legal_action_names)``
    in ``seen``. Scripted actions are returned in order without any legality
    check. Once the script is exhausted, the agent checks when a check is
    legal and calls otherwise.
    """

    def __init__(
        self,
        actions: list[PlayerAction],
        seen: list[tuple[str, str, list[str]]],
    ) -> None:
        # Copy the script so consuming it does not mutate the caller's list;
        # a deque also makes taking the next action O(1) instead of O(n).
        self.actions: deque[PlayerAction] = deque(actions)
        # Log of decision points; the list may be shared with other agents.
        self.seen = seen

    def decide(self, observation: Observation) -> PlayerAction:
        legal_names = [str(action["action"]) for action in observation.legal_actions]
        self.seen.append((observation.street, observation.player_id, legal_names))
        if self.actions:
            return self.actions.popleft()
        # Script exhausted: fall back to check-if-possible, otherwise call.
        if any(action["action"] == "check" for action in observation.legal_actions):
            return PlayerAction("check")
        return PlayerAction("call")
class EngineTests(unittest.TestCase):
    """Behavioral tests for ``TableGame``: hand flow, pot awarding and summaries."""

    def test_full_hand_preserves_total_chips(self) -> None:
        """A complete four-handed hand must neither create nor destroy chips."""
        players = [
            ("p1", "Player 1", CallingStationAgent()),
            ("p2", "Player 2", CallingStationAgent()),
            ("p3", "Player 3", CallingStationAgent()),
            ("p4", "Player 4", CallingStationAgent()),
        ]
        game = TableGame(
            "g1", players, starting_stack=1000, small_blind=5, big_blind=10, rng=Random(7)
        )
        summary = game.run_hand()
        # Four players started with 1000 chips each.
        self.assertEqual(sum(player.stack for player in game.players), 4000)
        # The hand is expected to run out the full five-card board.
        self.assertEqual(len(summary.board), 5)
        self.assertGreaterEqual(len(summary.awards), 1)

    def test_preflop_observations_follow_table_order(self) -> None:
        """The first three preflop decisions go to the button, SB, then BB."""
        seen: list[tuple[str, str, int]] = []
        players = [
            ("p1", "Button", RecordingAgent(seen)),
            ("p2", "Small Blind", RecordingAgent(seen)),
            ("p3", "Big Blind", RecordingAgent(seen)),
        ]
        game = TableGame(
            "g2", players, starting_stack=100, small_blind=5, big_blind=10, rng=Random(3)
        )
        game.run_hand()
        # All agents share the same ``seen`` list, so it keeps the global order.
        preflop = [player_id for street, player_id, _ in seen if street == "preflop"]
        self.assertEqual(preflop[:3], ["p1", "p2", "p3"])

    def test_side_pots_are_awarded_to_eligible_players(self) -> None:
        """A short all-in player can only win the pot it contributed to.

        The end-of-hand state is built by hand and ``_award_pots`` is called
        directly: p1 (AA) committed 50, p2 (KK) and p3 (QQ) committed 100
        each. That yields a 150 main pot (50 from each player) and a 100 side
        pot (the extra 50 from p2 and p3).
        """
        players = [
            ("p1", "Short", CallingStationAgent()),
            ("p2", "Middle", CallingStationAgent()),
            ("p3", "Deep", CallingStationAgent()),
        ]
        game = TableGame(
            "g3", players, starting_stack=100, small_blind=5, big_blind=10, rng=Random(1)
        )
        board = [Card.parse(value) for value in "2h 7d 9c Js 3h".split()]
        holes = {
            "p1": "Ah Ac",
            "p2": "Kh Kc",
            "p3": "Qh Qc",
        }
        bets = {"p1": 50, "p2": 100, "p3": 100}
        for player in game.players:
            # Zero the stacks so the final stacks equal exactly what was won.
            player.stack = 0
            player.in_hand = True
            player.folded = False
            player.hole_cards = [Card.parse(value) for value in holes[player.player_id].split()]
            player.total_bet = bets[player.player_id]
        game.board = board
        # NOTE(review): presumably pins seat order for award resolution — confirm.
        game.button_index = 0
        awards = game._award_pots()
        # Main pot -> p1 (best hand overall); side pot -> p2 (best of p2/p3).
        self.assertEqual([award.amount for award in awards], [150, 100])
        self.assertEqual(
            [contribution["amount"] for contribution in game._last_pot_contributions],
            [150, 100],
        )
        self.assertEqual(
            game._last_pot_contributions[0]["contributors"],
            {"p1": 50, "p2": 50, "p3": 50},
        )
        self.assertEqual(
            game._last_pot_contributions[1]["contributors"],
            {"p2": 50, "p3": 50},
        )
        self.assertEqual(game.players[0].stack, 150)
        self.assertEqual(game.players[1].stack, 100)
        self.assertEqual(game.players[2].stack, 0)

    def test_hand_summary_includes_full_hand_snapshots(self) -> None:
        """``HandSummary`` exposes hole cards, stacks and pot contributions.

        The fields are checked both as attributes and in ``to_dict()`` output.
        """
        players = [
            ("p1", "Player 1", CallingStationAgent()),
            ("p2", "Player 2", CallingStationAgent()),
            ("p3", "Player 3", CallingStationAgent()),
        ]
        game = TableGame(
            "g5", players, starting_stack=100, small_blind=5, big_blind=10, rng=Random(23)
        )
        summary = game.run_hand()
        payload = summary.to_dict()
        # Every player's hole cards are captured, not only those shown down.
        self.assertEqual(set(summary.hole_cards), {"p1", "p2", "p3"})
        self.assertTrue(all(len(cards) == 2 for cards in summary.hole_cards.values()))
        self.assertEqual(summary.starting_stacks, {"p1": 100, "p2": 100, "p3": 100})
        self.assertEqual(set(summary.ending_stacks), {"p1", "p2", "p3"})
        # Chips are conserved between the start and the end of the hand.
        self.assertEqual(sum(summary.starting_stacks.values()), sum(summary.ending_stacks.values()))
        self.assertGreaterEqual(len(summary.pot_contributions), 1)
        # Each pot's amount equals the sum of its contributors' shares.
        self.assertTrue(
            all(
                contribution["amount"] == sum(contribution["contributors"].values())
                for contribution in summary.pot_contributions
            )
        )
        self.assertEqual(set(payload["hole_cards"]), {"p1", "p2", "p3"})
        self.assertEqual(payload["starting_stacks"], {"p1": 100, "p2": 100, "p3": 100})
        self.assertIn("ending_stacks", payload)
        self.assertIn("pot_contributions", payload)

    def test_short_all_in_does_not_reopen_raising_to_prior_actor(self) -> None:
        """A short all-in must not give the original raiser another raise.

        p1 raises and p2, holding only 25 chips, goes all in short. p3 calls.
        When action returns to p1, the only legal options must be fold or call.
        """
        seen: list[tuple[str, str, list[str]]] = []
        players = [
            ("p1", "Button", ScriptedAgent([PlayerAction("raise", 20)], seen)),
            ("p2", "Short Blind", ScriptedAgent([PlayerAction("all_in", 25)], seen)),
            ("p3", "Big Blind", ScriptedAgent([PlayerAction("call")], seen)),
        ]
        game = TableGame(
            "g4", players, starting_stack=100, small_blind=5, big_blind=10, rng=Random(13)
        )
        game.players[1].stack = 25
        game.run_hand()
        p1_preflop = [
            legal
            for street, player_id, legal in seen
            if street == "preflop" and player_id == "p1"
        ]
        # Fail with a clear message instead of an IndexError if p1 never
        # got a second preflop decision.
        self.assertGreaterEqual(
            len(p1_preflop), 2, f"p1 preflop decisions: {p1_preflop!r}"
        )
        self.assertEqual(p1_preflop[1], ["fold", "call"])
if __name__ == "__main__":
    # Allow running this module directly, e.g. `python test_engine.py`.
    unittest.main()