# GRPO_chessagent / test.py
# Uploaded by algorembrant ("Upload test.py"), commit 1af5cb8 (verified)
# !pip -q install pygame chess torch numpy
"""
Interactive Chess GUI with GRPO Model Predictions
- Movable pieces (click a piece, then click its destination square)
- Arrows showing the model's top-3 predicted moves
- Legal move enforcement
STANDALONE VERSION: Contains necessary model classes to run without model.py.
"""
import sys
import os
import threading
import queue
import math
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# Hide pygame's import-time support prompt; this must be set before the import.
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
try:
    import pygame
    import chess
except ImportError:
    # Install into the interpreter that is running this script. A bare "pip"
    # found on PATH may belong to a different Python installation. Also fail
    # loudly if installation fails instead of ignoring the exit status.
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pygame", "chess"])
    import pygame
    import chess
# ----------------------------------------------------------------------
# Core High-Performance Flags
# ----------------------------------------------------------------------
# Allow TF32 math for cuDNN convolutions and CUDA matmuls, let cuDNN
# autotune its kernels, and request "high" float32 matmul precision when the
# installed PyTorch exposes that setting.
for _backend in (torch.backends.cudnn, torch.backends.cuda.matmul):
    _backend.allow_tf32 = True
torch.backends.cudnn.benchmark = True
_set_precision = getattr(torch, 'set_float32_matmul_precision', None)
if _set_precision is not None:
    _set_precision('high')
del _backend, _set_precision
# ----------------------------------------------------------------------
# Model Components (Included for standalone execution)
# ----------------------------------------------------------------------
class ActionMapper:
    """Bidirectional map between UCI move strings and policy-head indices.

    Moves are enumerated by from-square (0..63), then to-square (0..63),
    skipping from == to. Each plain move is immediately followed by its four
    promotion variants ("n", "b", "r", "q") when the from-square is on rank
    index 1 or 6 and the files differ by at most one. The target rank is not
    checked, so some entries (e.g. "a2a3q") can never be legal.

    The policy logits are interpreted through this ordering, so changing it
    would silently mis-map a trained checkpoint's outputs.
    """
    __slots__ = ['move_to_idx', 'idx_to_move', 'num_actions']

    def __init__(self):
        moves = []
        for src in range(64):
            src_name = chess.SQUARE_NAMES[src]
            promotes_from_here = chess.square_rank(src) in (1, 6)
            for dst in range(64):
                if dst == src:
                    continue
                plain = src_name + chess.SQUARE_NAMES[dst]
                moves.append(plain)
                near_file = abs(chess.square_file(src) - chess.square_file(dst)) <= 1
                if promotes_from_here and near_file:
                    moves.extend(plain + piece for piece in "nbrq")
        self.idx_to_move = moves
        self.move_to_idx = {uci: i for i, uci in enumerate(moves)}
        self.num_actions = len(moves)


# Shared module-wide instance used by the network and the inference worker.
ACTION_MAPPER = ActionMapper()
class ChessNet(nn.Module):
    """Residual conv net mapping 14 planes of 8x8 to (policy logits, value).

    Trunk: 3x3 conv stem to 128 channels, then six residual blocks
    (conv-BN-ReLU-conv-BN plus identity skip, ReLU after the sum).
    Policy head: 1x1 conv to 32 channels, then a linear layer over every index
    in ACTION_MAPPER. Value head: 1x1 conv to 32 channels, then an MLP and a
    final Tanh, so the value lies in [-1, 1].

    Attribute names and the layer order inside each nn.Sequential define the
    state_dict keys. They must not change, or saved checkpoints will no
    longer load.
    """

    def __init__(self):
        super().__init__()
        channels = 128
        head_channels = 32
        flat_features = head_channels * 8 * 8

        self.conv_in = nn.Conv2d(14, channels, kernel_size=3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(channels)

        blocks = []
        for _ in range(6):
            blocks.append(nn.Sequential(
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, 3, padding=1, bias=False),
                nn.BatchNorm2d(channels),
            ))
        self.res_blocks = nn.ModuleList(blocks)

        def head_stem():
            # Shared prefix of both heads: 1x1 conv, BN, ReLU, flatten.
            return [
                nn.Conv2d(channels, head_channels, 1, bias=False),
                nn.BatchNorm2d(head_channels),
                nn.ReLU(inplace=True),
                nn.Flatten(),
            ]

        self.policy_head = nn.Sequential(
            *head_stem(),
            nn.Linear(flat_features, ACTION_MAPPER.num_actions),
        )
        self.value_head = nn.Sequential(
            *head_stem(),
            nn.Linear(flat_features, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return (policy_logits, value) for a batch of encoded boards."""
        features = F.relu(self.bn_in(self.conv_in(x)), inplace=True)
        for residual in self.res_blocks:
            features = F.relu(features + residual(features), inplace=True)
        policy_logits = self.policy_head(features)
        value = self.value_head(features)
        return policy_logits, value
def boards_to_tensor_vectorized(envs: List[chess.Board], out_tensor: np.ndarray):
    """Encode each board in *envs* as 14 feature planes, writing into *out_tensor*.

    out_tensor is filled in place and nothing is returned. It is expected to
    be (len(envs), 14, 8, 8); the caller in this file passes a float32 buffer.

    Planes 0-5 hold white pawn/knight/bishop/rook/queen/king occupancy and
    planes 6-11 the same for black, as 0/1 values indexed
    [square // 8, square % 8].
    Plane 12 is filled with the side-to-move flag (white +1, black -1).
    Plane 13 is filled with the castling feature. Cell (0, 1) of that plane
    is then overwritten with the en-passant flag (+1 if ep_square is set,
    else -1).
    All three flags go through an int8 cast first, mirroring model.py.
    """
    B = len(envs)
    bbs = np.zeros((B, 12), dtype=np.uint64)
    meta = np.zeros((B, 3), dtype=np.float32)
    for b, env in enumerate(envs):
        white = env.occupied_co[chess.WHITE]
        black = env.occupied_co[chess.BLACK]
        by_type = (env.pawns, env.knights, env.bishops, env.rooks, env.queens, env.kings)
        for t, mask in enumerate(by_type):
            bbs[b, t] = mask & white
            bbs[b, 6 + t] = mask & black
        meta[b, 0] = 1.0 if env.turn else -1.0
        # NOTE(review): python-chess's Board.castling_rights is a bitmask of
        # rook squares, not a 0..15 flag set. The 0.1333333 (~2/15) scale
        # suggests a 4-bit value was intended; if so, this feature is garbage
        # and the int8 cast below is out of range. It is kept as-is so that
        # inference matches whatever model.py fed the checkpoint. Confirm
        # against model.py before changing it.
        meta[b, 1] = env.castling_rights * 0.1333333 - 1.0
        meta[b, 2] = 1.0 if env.ep_square else -1.0
    # Bit unpacking (equivalent to model.py torch logic). Reading the
    # bitboards as explicit little-endian bytes makes byte k hold squares
    # 8k..8k+7 on any host. On little-endian machines this is a no-op.
    bbs_bytes = bbs.astype('<u8', copy=False).view(np.uint8).reshape(B, 12, 8)
    unpacked = np.unpackbits(bbs_bytes, axis=2, bitorder='little').reshape(B, 12, 8, 8)
    # Quantize meta to int8 to match model.py training buffer behavior
    meta_int8 = meta.astype(np.int8)
    out_tensor[:, :12, :, :] = unpacked
    # Channel 12: turn, broadcast over the 8x8 plane.
    out_tensor[:, 12, :, :] = meta_int8[:, 0, None, None]
    # Channel 13: castling, broadcast over the plane, then EP at (0, 1).
    out_tensor[:, 13, :, :] = meta_int8[:, 1, None, None]
    out_tensor[:, 13, 0, 1] = meta_int8[:, 2]
# ----------------------------------------------------------------------
# Pygame Initialization
# ----------------------------------------------------------------------
# Colab has no native GUI display, so warn early. The probe below decides
# whether the GUI can actually start.
if "google.colab" in sys.modules:
    print("WARNING: You are running test.py in Google Colab.")
    print("Pygame requires a GUI display which Colab does not have natively.")
    print("It is recommended to run test.py locally on your Windows PC and load the latest.pt file.")
# Probe for a usable display by briefly opening a 1x1 window. The result is
# stored in HAS_DISPLAY, which gates ChessApp start-up.
try:
    pygame.init()
    pygame.display.set_mode((1, 1))
    pygame.display.quit()
    HAS_DISPLAY = True
except pygame.error:
    print("ERROR: No display detected. Pygame GUI cannot start without a screen.")
    HAS_DISPLAY = False
# Window geometry: the 8x8 board fills the whole square window.
WIDTH = HEIGHT = 800
SQUARE_SIZE = WIDTH // 8
FPS = 60

# Colours. LIGHT/DARK are the board squares; HIGHLIGHT carries an alpha
# channel for translucent overlays.
LIGHT = (240, 217, 181)
DARK = (181, 136, 99)
HIGHLIGHT = (255, 255, 0, 100)
ARROW_COLOR = (50, 150, 250)
TEXT_COLOR = (0, 0, 0)
def create_piece_surface(piece: chess.Piece) -> pygame.Surface:
    """Render *piece* centred on a transparent SQUARE_SIZE x SQUARE_SIZE surface.

    The Unicode chess glyph is drawn with the first installed font among
    segoeuisymbol, arial and msgothic. If none is installed, pygame's default
    font is used with the upper-case piece letter instead. White pieces are
    drawn white with a black outline; black pieces are dark grey with a white
    outline.
    """
    surf = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
    # Try Windows-specific Unicode fonts first. The installed-font list is
    # queried once instead of once per candidate name.
    installed = set(pygame.font.get_fonts())
    font_name = next(
        (fn for fn in ("segoeuisymbol", "arial", "msgothic") if fn in installed), None
    )
    if font_name is not None:
        # NOTE(review): an installed font is not guaranteed to contain the
        # chess glyphs; plain Arial, for example, may render them as boxes.
        # TODO confirm per platform.
        font = pygame.font.SysFont(font_name, int(SQUARE_SIZE * 0.7))
        symbols = {
            'P': '♙', 'N': '♘', 'B': '♗', 'R': '♖', 'Q': '♕', 'K': '♔',
            'p': '♟', 'n': '♞', 'b': '♝', 'r': '♜', 'q': '♛', 'k': '♚'
        }
        char = symbols[piece.symbol()]
    else:
        # With the default font (freesansbold), Unicode chess pieces often
        # render as missing boxes, so use plain letters to guarantee visibility.
        font = pygame.font.Font(None, int(SQUARE_SIZE * 0.6))
        char = piece.symbol().upper()

    if piece.color == chess.WHITE:
        fill_color, outline_color = (255, 255, 255), (0, 0, 0)
    else:
        fill_color, outline_color = (30, 30, 30), (255, 255, 255)

    text = font.render(char, True, fill_color)
    # The outline glyph is the same at every offset, so render it only once.
    outline = font.render(char, True, outline_color)
    text_rect = text.get_rect(center=(SQUARE_SIZE // 2, SQUARE_SIZE // 2))
    # Draw outline for better visibility
    for dx, dy in ((-1, -1), (-1, 1), (1, -1), (1, 1), (-2, 0), (2, 0), (0, -2), (0, 2)):
        surf.blit(outline, text_rect.move(dx, dy))
    surf.blit(text, text_rect)
    return surf
# ----------------------------------------------------------------------
# Model Inference (Async Worker)
# ----------------------------------------------------------------------
class ModelInference:
    """Evaluate positions with ChessNet on a background worker thread.

    predict_async() queues a position and drops any request the worker has
    not picked up yet. The worker later calls ``callback(top_moves, value)``
    on the worker thread:
    - ``top_moves`` holds up to three ``(chess.Move, probability)`` pairs,
      sorted by descending probability, with probabilities renormalised over
      the legal moves.
    - ``value`` is the value head's output as a Python float.
    """

    def __init__(self, checkpoint_path: str):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ChessNet().to(self.device).to(memory_format=torch.channels_last)
        if os.path.exists(checkpoint_path):
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            # Accept training checkpoints ({'model_state_dict': ...}) as well
            # as bare state dicts.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            # Handle compiled model prefix: strip '_orig_mod.' from the keys.
            state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
            self.model.load_state_dict(state_dict)
            print(f"Loaded checkpoint from {checkpoint_path}")
        else:
            print(f"Warning: Checkpoint '{checkpoint_path}' not found. Using untrained weights.")
        self.model.eval()
        self.queue = queue.Queue()
        self.running = True
        self.thread = threading.Thread(target=self._worker, daemon=True)
        self.thread.start()

    def _worker(self):
        """Serve queued (fen, callback) requests until shutdown() clears self.running."""
        state_np = np.zeros((1, 14, 8, 8), dtype=np.float32)  # reused input buffer
        while self.running:
            try:
                board_fen, callback = self.queue.get(timeout=0.1)
            except queue.Empty:
                continue
            try:
                top_moves, value = self._predict(board_fen, state_np)
                callback(top_moves, value)
            except Exception:
                # An uncaught exception would end this thread and stop every
                # later prediction. Report it and keep serving requests.
                import traceback
                traceback.print_exc()

    def _predict(self, board_fen: str, state_np: np.ndarray):
        """Run the network on one FEN; return (top-3 legal (move, prob) pairs, value)."""
        board = chess.Board(board_fen)
        boards_to_tensor_vectorized([board], state_np)
        tensor = torch.tensor(state_np, dtype=torch.float32, device=self.device).to(
            memory_format=torch.channels_last)
        with torch.no_grad():
            if hasattr(torch.amp, 'autocast'):
                autocast = torch.amp.autocast(self.device)
            else:  # older PyTorch without torch.amp.autocast
                autocast = torch.cuda.amp.autocast()
            with autocast:
                logits, value = self.model(tensor)
        probs = torch.softmax(logits.to(torch.float32), dim=-1).cpu().numpy().flatten()

        legal_moves = list(board.legal_moves)
        if not legal_moves:  # game over: nothing to suggest
            return [], value.item()
        # Keep only the legal moves' probabilities and renormalise them.
        legal_probs = probs[[ACTION_MAPPER.move_to_idx[m.uci()] for m in legal_moves]]
        total = legal_probs.sum()
        if total > 0:
            legal_probs = legal_probs / total
        best = np.argsort(legal_probs)[-3:][::-1]
        top_moves = [(legal_moves[i], legal_probs[i]) for i in best if legal_probs[i] > 0]
        return top_moves, value.item()

    def predict_async(self, board: chess.Board, callback):
        """Queue *board* for evaluation, discarding any request not yet started.

        Only the FEN is sent, so the caller may keep mutating *board*.
        """
        while not self.queue.empty():
            try:
                self.queue.get_nowait()
            except queue.Empty:
                pass
        self.queue.put((board.fen(), callback))

    def shutdown(self):
        """Stop the worker loop and wait for the thread to exit."""
        self.running = False
        self.thread.join()
# ----------------------------------------------------------------------
# Main GUI Application
# ----------------------------------------------------------------------
class ChessApp:
    """Pygame front end: board, click-to-move input, model arrows and eval.

    Controls:
    - Left-click a piece of the side to move, then left-click its
      destination. A pawn reaching the first or last rank is promoted to a
      queen.
    - Press R to reset the game.
    - Close the window to quit.
    """

    def __init__(self, model_checkpoint: str):
        if not HAS_DISPLAY:
            print("\nExiting because no GUI display is available.")
            sys.exit(1)
        pygame.init()
        self.screen = pygame.display.set_mode((WIDTH, HEIGHT))
        pygame.display.set_caption("GRPO Chess Agent - Real-Time Predictions")
        self.clock = pygame.time.Clock()
        self.board = chess.Board()
        self.selected_square = None
        self.valid_moves = []
        self.predicted_arrows = []
        self.prediction_value = 0.0
        # Prediction callbacks run on the inference worker thread. The token,
        # guarded by the lock, lets them discard results for outdated positions.
        self._prediction_lock = threading.Lock()
        self._prediction_token = 0
        # Reused every frame instead of being rebuilt on each draw call.
        self._info_font = pygame.font.Font(None, 36)
        self._arrow_layer = pygame.Surface((WIDTH, HEIGHT), pygame.SRCALPHA)
        self.piece_images = {}
        self._load_pieces()
        self.inference = ModelInference(model_checkpoint)
        self.running = True
        self.update_predictions()

    def _load_pieces(self):
        """Pre-render one surface per (piece type, colour)."""
        for piece_type in chess.PIECE_TYPES:
            for color in (chess.WHITE, chess.BLACK):
                piece = chess.Piece(piece_type, color)
                self.piece_images[(piece_type, color)] = create_piece_surface(piece)

    def square_to_xy(self, square: chess.Square) -> Tuple[int, int]:
        """Top-left pixel of *square*, with rank 1 drawn at the bottom."""
        file_idx = chess.square_file(square)
        rank_idx = 7 - chess.square_rank(square)
        return file_idx * SQUARE_SIZE, rank_idx * SQUARE_SIZE

    def xy_to_square(self, x: int, y: int) -> Optional[chess.Square]:
        """Square under pixel (x, y), or None if the point is off the board."""
        file_idx = x // SQUARE_SIZE
        rank_idx = 7 - (y // SQUARE_SIZE)
        if 0 <= file_idx < 8 and 0 <= rank_idx < 8:
            return chess.square(file_idx, rank_idx)
        return None

    def draw_board(self):
        """Paint the alternating LIGHT/DARK squares (top-left is LIGHT)."""
        for row in range(8):
            for col in range(8):
                color = LIGHT if (row + col) % 2 == 0 else DARK
                rect = pygame.Rect(col * SQUARE_SIZE, row * SQUARE_SIZE, SQUARE_SIZE, SQUARE_SIZE)
                pygame.draw.rect(self.screen, color, rect)

    def draw_pieces(self):
        """Blit the pre-rendered image of every piece on the board."""
        for square in chess.SQUARES:
            piece = self.board.piece_at(square)
            if piece:
                x, y = self.square_to_xy(square)
                self.screen.blit(self.piece_images[(piece.piece_type, piece.color)], (x, y))

    def draw_highlights(self):
        """Tint the selected square (HIGHLIGHT) and its legal destinations (green)."""
        if self.selected_square is not None:
            selected = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
            selected.fill(HIGHLIGHT)
            self.screen.blit(selected, self.square_to_xy(self.selected_square))
        if self.valid_moves:
            # One overlay serves every destination square.
            target = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
            target.fill((0, 255, 0, 60))
            for move in self.valid_moves:
                self.screen.blit(target, self.square_to_xy(move.to_square))

    def draw_arrows(self):
        """Draw one translucent arrow per predicted move; likelier moves are thicker and more opaque."""
        half = SQUARE_SIZE // 2
        layer = self._arrow_layer
        for move, prob in self.predicted_arrows:
            start_x, start_y = self.square_to_xy(move.from_square)
            end_x, end_y = self.square_to_xy(move.to_square)
            start = (start_x + half, start_y + half)
            end = (end_x + half, end_y + half)
            alpha = max(100, int(255 * prob))
            color = (ARROW_COLOR[0], ARROW_COLOR[1], ARROW_COLOR[2], alpha)
            width = max(3, int(12 * prob))
            # Each arrow is drawn on a cleared transparent layer and blitted
            # separately, so overlapping arrows alpha-blend with each other.
            layer.fill((0, 0, 0, 0))
            pygame.draw.line(layer, color, start, end, width)
            angle = math.atan2(end[1] - start[1], end[0] - start[0])
            arrow_len = max(15, int(25 * prob))
            arrow_angle = math.pi / 6
            x2 = end[0] - arrow_len * math.cos(angle - arrow_angle)
            y2 = end[1] - arrow_len * math.sin(angle - arrow_angle)
            x3 = end[0] - arrow_len * math.cos(angle + arrow_angle)
            y3 = end[1] - arrow_len * math.sin(angle + arrow_angle)
            pygame.draw.polygon(layer, color, [end, (x2, y2), (x3, y3)])
            self.screen.blit(layer, (0, 0))

    def draw_info(self):
        """Show the latest value-head output in a boxed label at the bottom-left."""
        text = f"Eval Value: {self.prediction_value:.3f} (W/B)"
        surf = self._info_font.render(text, True, TEXT_COLOR)
        bg_rect = surf.get_rect(topleft=(10, HEIGHT - 40))
        bg_rect.inflate_ip(10, 10)
        pygame.draw.rect(self.screen, (255, 255, 255), bg_rect)
        pygame.draw.rect(self.screen, (0, 0, 0), bg_rect, 2)
        self.screen.blit(surf, (15, HEIGHT - 35))

    def update_predictions(self):
        """Ask the model about the current position.

        Arrows for the previous position are cleared immediately. A result is
        applied only if no newer request has been made since, so a slow reply
        cannot paint arrows for a position that is no longer on the board.
        """
        with self._prediction_lock:
            self._prediction_token += 1
            token = self._prediction_token
            self.predicted_arrows = []

        def callback(top_moves, value):
            # Invoked on the inference worker thread.
            with self._prediction_lock:
                if token == self._prediction_token:
                    self.predicted_arrows = top_moves
                    self.prediction_value = value

        self.inference.predict_async(self.board, callback)

    def _select_if_own_piece(self, square):
        """Select *square* if it holds a piece of the side to move; otherwise clear the selection."""
        piece = self.board.piece_at(square)
        if piece and piece.color == self.board.turn:
            self.selected_square = square
            self.valid_moves = [m for m in self.board.legal_moves if m.from_square == square]
        else:
            self.selected_square = None
            self.valid_moves = []

    def handle_click(self, pos):
        """First click selects a piece; second click moves it, or reselects / deselects."""
        square = self.xy_to_square(*pos)
        if square is None:
            return
        if self.selected_square is None:
            self._select_if_own_piece(square)
            return
        move = chess.Move(self.selected_square, square)
        mover = self.board.piece_at(self.selected_square)
        # Auto-promote pawns reaching the first or last rank to a queen.
        if mover is not None and mover.piece_type == chess.PAWN and chess.square_rank(square) in (0, 7):
            move = chess.Move(self.selected_square, square, promotion=chess.QUEEN)
        if move in self.board.legal_moves:
            self.board.push(move)
            self.selected_square = None
            self.valid_moves = []
            self.update_predictions()
        else:
            self._select_if_own_piece(square)

    def _reset_game(self):
        """Return to the initial position and request fresh predictions."""
        self.board.reset()
        self.selected_square = None
        self.valid_moves = []
        self.update_predictions()

    def run(self):
        """Main loop: handle input and redraw at FPS until the window is closed."""
        try:
            while self.running:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        self.running = False
                    elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                        self.handle_click(event.pos)
                    elif event.type == pygame.KEYDOWN and event.key == pygame.K_r:
                        self._reset_game()
                self.screen.fill((0, 0, 0))
                self.draw_board()
                self.draw_highlights()
                self.draw_arrows()
                self.draw_pieces()
                self.draw_info()
                pygame.display.flip()
                self.clock.tick(FPS)
        finally:
            # Always stop the worker thread and close pygame, even after an error.
            self.inference.shutdown()
            pygame.quit()
        sys.exit()
if __name__ == "__main__":
    # The first CLI argument overrides the default checkpoint path, unless it
    # starts with '-' (flag-like arguments are ignored).
    cli_args = sys.argv[1:]
    if cli_args and not cli_args[0].startswith('-'):
        checkpoint_path = cli_args[0]
    else:
        checkpoint_path = "./checkpoints/latest.pt"
    app = ChessApp(checkpoint_path)
    if HAS_DISPLAY:
        app.run()