feat: basic function
This commit is contained in:
@@ -0,0 +1,128 @@
|
||||
import unittest
|
||||
from random import Random
|
||||
|
||||
from texas_holdem.agents import CallingStationAgent, PokerAgent
|
||||
from texas_holdem.cards import Card
|
||||
from texas_holdem.engine import TableGame
|
||||
from texas_holdem.models import Observation, PlayerAction
|
||||
|
||||
|
||||
class RecordingAgent(PokerAgent):
|
||||
def __init__(self, seen: list[tuple[str, str, int]]) -> None:
|
||||
self.seen = seen
|
||||
|
||||
def decide(self, observation: Observation) -> PlayerAction:
|
||||
self.seen.append((observation.street, observation.player_id, observation.to_call))
|
||||
for action in observation.legal_actions:
|
||||
if action["action"] == "check":
|
||||
return PlayerAction("check")
|
||||
return PlayerAction("call")
|
||||
|
||||
|
||||
class ScriptedAgent(PokerAgent):
|
||||
def __init__(
|
||||
self,
|
||||
actions: list[PlayerAction],
|
||||
seen: list[tuple[str, str, list[str]]],
|
||||
) -> None:
|
||||
self.actions = actions
|
||||
self.seen = seen
|
||||
|
||||
def decide(self, observation: Observation) -> PlayerAction:
|
||||
self.seen.append(
|
||||
(
|
||||
observation.street,
|
||||
observation.player_id,
|
||||
[str(action["action"]) for action in observation.legal_actions],
|
||||
)
|
||||
)
|
||||
if self.actions:
|
||||
return self.actions.pop(0)
|
||||
for action in observation.legal_actions:
|
||||
if action["action"] == "check":
|
||||
return PlayerAction("check")
|
||||
return PlayerAction("call")
|
||||
|
||||
|
||||
class EngineTests(unittest.TestCase):
|
||||
def test_full_hand_preserves_total_chips(self) -> None:
|
||||
players = [
|
||||
("p1", "Player 1", CallingStationAgent()),
|
||||
("p2", "Player 2", CallingStationAgent()),
|
||||
("p3", "Player 3", CallingStationAgent()),
|
||||
("p4", "Player 4", CallingStationAgent()),
|
||||
]
|
||||
game = TableGame("g1", players, starting_stack=1000, small_blind=5, big_blind=10, rng=Random(7))
|
||||
|
||||
summary = game.run_hand()
|
||||
|
||||
self.assertEqual(sum(player.stack for player in game.players), 4000)
|
||||
self.assertEqual(len(summary.board), 5)
|
||||
self.assertGreaterEqual(len(summary.awards), 1)
|
||||
|
||||
def test_preflop_observations_follow_table_order(self) -> None:
|
||||
seen: list[tuple[str, str, int]] = []
|
||||
players = [
|
||||
("p1", "Button", RecordingAgent(seen)),
|
||||
("p2", "Small Blind", RecordingAgent(seen)),
|
||||
("p3", "Big Blind", RecordingAgent(seen)),
|
||||
]
|
||||
game = TableGame("g2", players, starting_stack=100, small_blind=5, big_blind=10, rng=Random(3))
|
||||
|
||||
game.run_hand()
|
||||
preflop = [player_id for street, player_id, _ in seen if street == "preflop"]
|
||||
|
||||
self.assertEqual(preflop[:3], ["p1", "p2", "p3"])
|
||||
|
||||
def test_side_pots_are_awarded_to_eligible_players(self) -> None:
|
||||
players = [
|
||||
("p1", "Short", CallingStationAgent()),
|
||||
("p2", "Middle", CallingStationAgent()),
|
||||
("p3", "Deep", CallingStationAgent()),
|
||||
]
|
||||
game = TableGame("g3", players, starting_stack=0 + 100, small_blind=5, big_blind=10, rng=Random(1))
|
||||
board = [Card.parse(value) for value in "2h 7d 9c Js 3h".split()]
|
||||
holes = {
|
||||
"p1": "Ah Ac",
|
||||
"p2": "Kh Kc",
|
||||
"p3": "Qh Qc",
|
||||
}
|
||||
bets = {"p1": 50, "p2": 100, "p3": 100}
|
||||
for player in game.players:
|
||||
player.stack = 0
|
||||
player.in_hand = True
|
||||
player.folded = False
|
||||
player.hole_cards = [Card.parse(value) for value in holes[player.player_id].split()]
|
||||
player.total_bet = bets[player.player_id]
|
||||
game.board = board
|
||||
game.button_index = 0
|
||||
|
||||
awards = game._award_pots()
|
||||
|
||||
self.assertEqual([award.amount for award in awards], [150, 100])
|
||||
self.assertEqual(game.players[0].stack, 150)
|
||||
self.assertEqual(game.players[1].stack, 100)
|
||||
self.assertEqual(game.players[2].stack, 0)
|
||||
|
||||
def test_short_all_in_does_not_reopen_raising_to_prior_actor(self) -> None:
|
||||
seen: list[tuple[str, str, list[str]]] = []
|
||||
players = [
|
||||
("p1", "Button", ScriptedAgent([PlayerAction("raise", 20)], seen)),
|
||||
("p2", "Short Blind", ScriptedAgent([PlayerAction("all_in", 25)], seen)),
|
||||
("p3", "Big Blind", ScriptedAgent([PlayerAction("call")], seen)),
|
||||
]
|
||||
game = TableGame("g4", players, starting_stack=100, small_blind=5, big_blind=10, rng=Random(13))
|
||||
game.players[1].stack = 25
|
||||
|
||||
game.run_hand()
|
||||
p1_second_preflop = [
|
||||
legal
|
||||
for street, player_id, legal in seen
|
||||
if street == "preflop" and player_id == "p1"
|
||||
][1]
|
||||
|
||||
self.assertEqual(p1_second_preflop, ["fold", "call"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user