# !pip -q install pygame chess torch numpy
"""
Interactive Chess GUI with GRPO Model Predictions

- Movable pieces (click-to-move or drag & drop)
- Arrows showing the top-3 predicted moves
- Legal move enforcement

STANDALONE VERSION: Contains necessary model classes to run without model.py.
"""
import sys
import os
import subprocess
import threading
import traceback
import queue
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Must be set before pygame is imported to silence its banner.
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"

try:
    import pygame
    import chess
except ImportError:
    # Install into the interpreter that is running this script; a bare `pip`
    # on PATH may belong to a different Python installation. If the install
    # fails, the imports below raise ImportError as before.
    subprocess.call([sys.executable, "-m", "pip", "install", "-q", "pygame", "chess"])
    import pygame
    import chess

# ----------------------------------------------------------------------
# Core High-Performance Flags
# ----------------------------------------------------------------------
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if hasattr(torch, 'set_float32_matmul_precision'):
    torch.set_float32_matmul_precision('high')


# ----------------------------------------------------------------------
# Model Components (Included for standalone execution)
# ----------------------------------------------------------------------
class ActionMapper:
    """Bijective mapping between UCI move strings and policy-head indices.

    Every ordered (from, to) square pair with from != to gets an index. For
    origin squares on the 2nd or 7th rank whose destination is at most one
    file away, four promotion variants ("n", "b", "r", "q") are appended
    immediately after the plain move. The enumeration order defines the
    policy layout, so it must not change or trained checkpoints will be
    misinterpreted.
    """
    __slots__ = ['move_to_idx', 'idx_to_move', 'num_actions']

    def __init__(self):
        self.move_to_idx = {}
        self.idx_to_move = []
        idx = 0
        for f in range(64):
            for t in range(64):
                if f == t:
                    continue
                uci = chess.SQUARE_NAMES[f] + chess.SQUARE_NAMES[t]
                self.move_to_idx[uci] = idx
                self.idx_to_move.append(uci)
                idx += 1
                if chess.square_rank(f) in (1, 6) and abs(chess.square_file(f) - chess.square_file(t)) <= 1:
                    for promo in "nbrq":
                        promo_uci = uci + promo
                        self.move_to_idx[promo_uci] = idx
                        self.idx_to_move.append(promo_uci)
                        idx += 1
        self.num_actions = idx


ACTION_MAPPER = ActionMapper()


class ChessNet(nn.Module):
    """Residual CNN with a policy head (one logit per action) and a tanh value head.

    Input is a (batch, 14, 8, 8) tensor as produced by
    `boards_to_tensor_vectorized`. Attribute names must stay as-is because
    they are the state_dict keys of saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv2d(14, 128, kernel_size=3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(128)
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(128, 128, 3, padding=1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 128, 3, padding=1, bias=False),
                nn.BatchNorm2d(128),
            )
            for _ in range(6)
        ])
        self.policy_head = nn.Sequential(
            nn.Conv2d(128, 32, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, ACTION_MAPPER.num_actions),
        )
        self.value_head = nn.Sequential(
            nn.Conv2d(128, 32, 1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(32 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return (policy_logits, value) for a batch of encoded boards."""
        x = F.relu(self.bn_in(self.conv_in(x)), inplace=True)
        for block in self.res_blocks:
            # Residual connection followed by ReLU.
            x = F.relu(x + block(x), inplace=True)
        return self.policy_head(x), self.value_head(x)


def boards_to_tensor_vectorized(envs: List[chess.Board], out_tensor: np.ndarray):
    """Encode `envs` into `out_tensor` in place.

    Args:
        envs: boards to encode; may be empty.
        out_tensor: preallocated float array of shape (len(envs), 14, 8, 8).

    Channel layout:
        0-5   white pawn/knight/bishop/rook/queen/king occupancy
        6-11  black pawn/knight/bishop/rook/queen/king occupancy
        12    side to move (filled plane)
        13    castling feature (filled plane), with the en-passant flag
              overwritten at [0, 1]

    Each 8x8 plane is indexed [rank, file] with rank 0 = rank 1.
    """
    batch = len(envs)
    bbs = np.zeros((batch, 12), dtype=np.uint64)
    meta = np.zeros((batch, 3), dtype=np.float32)
    for b, env in enumerate(envs):
        white = env.occupied_co[chess.WHITE]
        black = env.occupied_co[chess.BLACK]
        piece_sets = (env.pawns, env.knights, env.bishops, env.rooks, env.queens, env.kings)
        for i, pieces in enumerate(piece_sets):
            bbs[b, i] = pieces & white
            bbs[b, 6 + i] = pieces & black
        meta[b, 0] = 1.0 if env.turn else -1.0
        # NOTE(review): python-chess `castling_rights` is a square bitboard
        # (e.g. BB_A1 | BB_H1 | ...), not a 0..15 flag set, so this value is
        # huge and its int8 cast below is not meaningful. It is kept unchanged
        # because it must match what model.py produced at training time —
        # confirm against the training pipeline before changing it.
        meta[b, 1] = env.castling_rights * 0.1333333 - 1.0
        meta[b, 2] = 1.0 if env.ep_square else -1.0

    # Bit unpacking (equivalent to model.py torch logic). Force little-endian
    # bytes so square i always maps to bit i regardless of host byte order.
    bbs_bytes = bbs.astype('<u8', copy=False).view(np.uint8).reshape(batch, 12, 8)
    unpacked = np.unpackbits(bbs_bytes, axis=2, bitorder='little').reshape(batch, 12, 8, 8)

    # Quantize meta to int8 to match model.py training buffer behavior.
    meta_int8 = meta.astype(np.int8)

    out_tensor[:, :12, :, :] = unpacked.astype(np.float32)
    # Channel 12: Turn (filled 8x8)
    out_tensor[:, 12, :, :] = meta_int8[:, 0].reshape(batch, 1, 1)
    # Channel 13: Castling (filled 8x8) then EP at (0, 1)
    out_tensor[:, 13, :, :] = meta_int8[:, 1].reshape(batch, 1, 1)
    out_tensor[:, 13, 0, 1] = meta_int8[:, 2]


# ----------------------------------------------------------------------
# Pygame Initialization
# ----------------------------------------------------------------------
# Suppress video driver errors if running in headless Colab environment
if "google.colab" in sys.modules:
    print("WARNING: You are running test.py in Google Colab.")
    print("Pygame requires a GUI display which Colab does not have natively.")
    print("It is recommended to run test.py locally on your Windows PC and load the latest.pt file.")

try:
    # Probe for a usable display by opening (and immediately closing) a tiny window.
    pygame.init()
    pygame.display.set_mode((1, 1))
    pygame.display.quit()
    HAS_DISPLAY = True
except pygame.error:
    print("ERROR: No display detected. Pygame GUI cannot start without a screen.")
    HAS_DISPLAY = False

WIDTH, HEIGHT = 800, 800
SQUARE_SIZE = WIDTH // 8
FPS = 60

LIGHT = (240, 217, 181)
DARK = (181, 136, 99)
HIGHLIGHT = (255, 255, 0, 100)
MOVE_HINT = (0, 255, 0, 60)
ARROW_COLOR = (50, 150, 250)
TEXT_COLOR = (0, 0, 0)

# Unicode chess glyphs keyed by python-chess piece symbol.
_PIECE_GLYPHS = {
    'P': '♙', 'N': '♘', 'B': '♗', 'R': '♖', 'Q': '♕', 'K': '♔',
    'p': '♟', 'n': '♞', 'b': '♝', 'r': '♜', 'q': '♛', 'k': '♚',
}
# Offsets at which the outline copy of a glyph is stamped behind it.
_OUTLINE_OFFSETS = [(-1, -1), (-1, 1), (1, -1), (1, 1), (-2, 0), (2, 0), (0, -2), (0, 2)]


def create_piece_surface(piece: chess.Piece) -> pygame.Surface:
    """Render `piece` as an outlined glyph on a transparent SQUARE_SIZE surface.

    Requires pygame's font module to be initialized. A Unicode-capable
    system font is tried first; if none of the candidates is installed,
    pygame's default font is used with plain letters instead of glyphs.
    """
    surf = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)

    # Try Windows-specific Unicode fonts first, fallback to default.
    font_names = ["segoeuisymbol", "arial", "msgothic"]
    available = set(pygame.font.get_fonts())
    chosen = next((fn for fn in font_names if fn in available), None)
    if chosen is not None:
        font = pygame.font.SysFont(chosen, int(SQUARE_SIZE * 0.7))
        char = _PIECE_GLYPHS[piece.symbol()]
    else:
        font = pygame.font.Font(None, int(SQUARE_SIZE * 0.6))
        # The default font may render Unicode chess pieces as missing boxes,
        # so fall back to standard English letters to guarantee visibility.
        char = piece.symbol().upper()

    color = (255, 255, 255) if piece.color == chess.WHITE else (30, 30, 30)
    outline_color = (0, 0, 0) if piece.color == chess.WHITE else (255, 255, 255)

    text = font.render(char, True, color)
    text_rect = text.get_rect(center=(SQUARE_SIZE // 2, SQUARE_SIZE // 2))

    # Draw outline for better visibility; the outline glyph is identical for
    # every offset, so render it once.
    outline = font.render(char, True, outline_color)
    for dx, dy in _OUTLINE_OFFSETS:
        surf.blit(outline, text_rect.move(dx, dy))
    surf.blit(text, text_rect)
    return surf


# ----------------------------------------------------------------------
# Model Inference (Async Worker)
# ----------------------------------------------------------------------
class ModelInference:
    """Runs ChessNet on a background daemon thread.

    Requests are (FEN, callback) pairs. Only the most recent pending request
    is kept. The callback is invoked ON THE WORKER THREAD with
    (top_moves, value), where top_moves is a list of up to three
    (chess.Move, probability) pairs restricted to legal moves and
    renormalized over them.
    """

    def __init__(self, checkpoint_path: str):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ChessNet().to(self.device).to(memory_format=torch.channels_last)

        if os.path.exists(checkpoint_path):
            # NOTE: torch.load unpickles the file; only load checkpoints you trust.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            # Accept both {'model_state_dict': ...} training checkpoints and bare state dicts.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint
            # Handle compiled model prefix (torch.compile wraps modules in `_orig_mod`).
            new_state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
            self.model.load_state_dict(new_state_dict)
            print(f"Loaded checkpoint from {checkpoint_path}")
        else:
            print(f"Warning: Checkpoint '{checkpoint_path}' not found. Using untrained weights.")

        self.model.eval()
        self.queue = queue.Queue()
        self.running = True
        self.thread = threading.Thread(target=self._worker, daemon=True)
        self.thread.start()

    def _predict(self, board_fen: str, state_np: np.ndarray):
        """Evaluate one position; return (top_moves, value) as described on the class."""
        board = chess.Board(board_fen)
        boards_to_tensor_vectorized([board], state_np)
        tensor = torch.tensor(state_np, dtype=torch.float32, device=self.device).to(
            memory_format=torch.channels_last)

        with torch.no_grad():
            if hasattr(torch.amp, 'autocast'):
                with torch.amp.autocast(self.device):
                    logits, value = self.model(tensor)
            else:
                with torch.cuda.amp.autocast():
                    logits, value = self.model(tensor)

        legal_moves = list(board.legal_moves)
        if not legal_moves:
            # Checkmate or stalemate: no arrows, but still report the value.
            return [], value.item()

        probs = torch.softmax(logits.to(torch.float32), dim=-1).cpu().numpy().flatten()
        legal_by_index = {ACTION_MAPPER.move_to_idx[m.uci()]: m for m in legal_moves}
        legal_indices = list(legal_by_index)

        # Mask out illegal actions and renormalize over the legal ones.
        probs_filtered = np.zeros_like(probs)
        probs_filtered[legal_indices] = probs[legal_indices]
        total = probs_filtered.sum()
        if total > 0:
            probs_filtered /= total

        top_indices = np.argsort(probs_filtered)[-3:][::-1]
        top_moves = [
            (legal_by_index[int(i)], probs_filtered[i])
            for i in top_indices
            if probs_filtered[i] > 0 and int(i) in legal_by_index
        ]
        return top_moves, value.item()

    def _worker(self):
        """Serve requests until `running` is cleared; errors are logged, not fatal."""
        state_np = np.zeros((1, 14, 8, 8), dtype=np.float32)
        while self.running:
            try:
                board_fen, callback = self.queue.get(timeout=0.1)
            except queue.Empty:
                continue
            try:
                top_moves, value = self._predict(board_fen, state_np)
                callback(top_moves, value)
            except Exception:
                # Keep the worker alive so later positions still get predictions.
                traceback.print_exc()

    def predict_async(self, board: chess.Board, callback):
        """Queue `board` for evaluation, discarding any request not yet started."""
        while not self.queue.empty():
            try:
                self.queue.get_nowait()
            except queue.Empty:
                pass
        self.queue.put((board.fen(), callback))

    def shutdown(self):
        """Stop the worker loop and wait for the thread to exit."""
        self.running = False
        self.thread.join()


# ----------------------------------------------------------------------
# Main GUI Application
# ----------------------------------------------------------------------
class ChessApp:
    """Pygame window that lets a human play both sides and overlays model predictions.

    Keys: R resets the board. Pieces move by clicking the piece then the
    target square, or by dragging the piece onto the target square.
    Promotions always promote to a queen.
    """

    def __init__(self, model_checkpoint: str):
        if not HAS_DISPLAY:
            print("\nExiting because no GUI display is available.")
            sys.exit(1)

        pygame.init()
        self.screen = pygame.display.set_mode((WIDTH, HEIGHT))
        pygame.display.set_caption("GRPO Chess Agent - Real-Time Predictions")
        self.clock = pygame.time.Clock()

        self.board = chess.Board()
        self.selected_square = None
        self.valid_moves = []
        self.predicted_arrows = []
        self.prediction_value = 0.0

        self.piece_images = {}
        self._load_pieces()

        # Created once instead of every frame.
        self.info_font = pygame.font.Font(None, 36)
        self.arrow_surface = pygame.Surface((WIDTH, HEIGHT), pygame.SRCALPHA)

        # Drag & drop state: pointer position and grab offset inside the square.
        self.dragging = False
        self.drag_pos = (0, 0)
        self.drag_offset = (0, 0)

        # Prediction bookkeeping: results are tagged with the id of the request
        # that produced them and only applied while that request is current.
        self._request_id = 0
        self._latest_result = None

        self.inference = ModelInference(model_checkpoint)
        self.running = True
        self.update_predictions()

    def _load_pieces(self):
        """Pre-render one surface per (piece_type, color)."""
        for piece_type in chess.PIECE_TYPES:
            for color in (chess.WHITE, chess.BLACK):
                piece = chess.Piece(piece_type, color)
                self.piece_images[(piece_type, color)] = create_piece_surface(piece)

    def square_to_xy(self, square: chess.Square) -> Tuple[int, int]:
        """Top-left pixel of `square` with white at the bottom."""
        file_idx = chess.square_file(square)
        rank_idx = 7 - chess.square_rank(square)
        return file_idx * SQUARE_SIZE, rank_idx * SQUARE_SIZE

    def xy_to_square(self, x: int, y: int) -> Optional[chess.Square]:
        """Square under pixel (x, y), or None if outside the board."""
        file_idx = x // SQUARE_SIZE
        rank_idx = 7 - (y // SQUARE_SIZE)
        if 0 <= file_idx < 8 and 0 <= rank_idx < 8:
            return chess.square(file_idx, rank_idx)
        return None

    def draw_board(self):
        for row in range(8):
            for col in range(8):
                color = LIGHT if (row + col) % 2 == 0 else DARK
                rect = pygame.Rect(col * SQUARE_SIZE, row * SQUARE_SIZE, SQUARE_SIZE, SQUARE_SIZE)
                pygame.draw.rect(self.screen, color, rect)

    def draw_pieces(self):
        """Draw all pieces; a piece being dragged follows the mouse pointer."""
        dragged = self.selected_square if self.dragging else None
        for square in chess.SQUARES:
            if square == dragged:
                continue
            piece = self.board.piece_at(square)
            if piece:
                self.screen.blit(self.piece_images[(piece.piece_type, piece.color)],
                                 self.square_to_xy(square))
        if dragged is not None:
            piece = self.board.piece_at(dragged)
            if piece:
                x = self.drag_pos[0] - self.drag_offset[0]
                y = self.drag_pos[1] - self.drag_offset[1]
                self.screen.blit(self.piece_images[(piece.piece_type, piece.color)], (x, y))

    def draw_highlights(self):
        """Tint the selected square and every legal destination of the selected piece."""
        if self.selected_square is not None:
            s = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
            s.fill(HIGHLIGHT)
            self.screen.blit(s, self.square_to_xy(self.selected_square))

            hint = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
            hint.fill(MOVE_HINT)
            for move in self.valid_moves:
                self.screen.blit(hint, self.square_to_xy(move.to_square))

    def draw_arrows(self):
        """Draw one translucent arrow per predicted move; likelier moves are wider and more opaque."""
        overlay = self.arrow_surface
        half = SQUARE_SIZE // 2
        for move, prob in self.predicted_arrows:
            start_x, start_y = self.square_to_xy(move.from_square)
            end_x, end_y = self.square_to_xy(move.to_square)
            start = (start_x + half, start_y + half)
            end = (end_x + half, end_y + half)

            alpha = max(100, int(255 * prob))
            color = (ARROW_COLOR[0], ARROW_COLOR[1], ARROW_COLOR[2], alpha)
            width = max(3, int(12 * prob))

            # Reuse one overlay but clear it per arrow, so overlapping arrows
            # blend with each other exactly as with a fresh surface per arrow.
            overlay.fill((0, 0, 0, 0))
            pygame.draw.line(overlay, color, start, end, width)

            angle = math.atan2(end[1] - start[1], end[0] - start[0])
            arrow_len = max(15, int(25 * prob))
            arrow_angle = math.pi / 6
            x2 = end[0] - arrow_len * math.cos(angle - arrow_angle)
            y2 = end[1] - arrow_len * math.sin(angle - arrow_angle)
            x3 = end[0] - arrow_len * math.cos(angle + arrow_angle)
            y3 = end[1] - arrow_len * math.sin(angle + arrow_angle)
            pygame.draw.polygon(overlay, color, [end, (x2, y2), (x3, y3)])

            self.screen.blit(overlay, (0, 0))

    def draw_info(self):
        """Draw the model's value estimate in a boxed label at the bottom-left."""
        text = f"Eval Value: {self.prediction_value:.3f} (W/B)"
        surf = self.info_font.render(text, True, TEXT_COLOR)
        bg_rect = surf.get_rect(topleft=(10, HEIGHT - 40))
        bg_rect.inflate_ip(10, 10)
        pygame.draw.rect(self.screen, (255, 255, 255), bg_rect)
        pygame.draw.rect(self.screen, (0, 0, 0), bg_rect, 2)
        self.screen.blit(surf, (15, HEIGHT - 35))

    def update_predictions(self):
        """Request predictions for the current board and drop arrows for the old one."""
        self._request_id += 1
        request_id = self._request_id
        self.predicted_arrows = []

        def callback(top_moves, value):
            # Runs on the worker thread: publish with a single attribute
            # assignment; the main thread decides whether it is still current.
            self._latest_result = (request_id, top_moves, value)

        self.inference.predict_async(self.board, callback)

    def _sync_predictions(self):
        """Apply the latest worker result if it belongs to the current request."""
        result = self._latest_result
        if result is not None and result[0] == self._request_id:
            _, self.predicted_arrows, self.prediction_value = result

    def _select(self, square: Optional[chess.Square]):
        """Select `square` (or clear the selection when None) and cache its legal moves."""
        self.selected_square = square
        if square is None:
            self.valid_moves = []
        else:
            self.valid_moves = [m for m in self.board.legal_moves if m.from_square == square]

    def handle_click(self, pos):
        """Select an own piece, or try to move the selected piece to the clicked square."""
        square = self.xy_to_square(*pos)
        if square is None:
            return

        if self.selected_square is None:
            piece = self.board.piece_at(square)
            if piece and piece.color == self.board.turn:
                self._select(square)
            return

        move = chess.Move(self.selected_square, square)
        if (chess.square_rank(square) in (0, 7)
                and self.board.piece_at(self.selected_square).piece_type == chess.PAWN):
            # Auto-promote to a queen.
            move = chess.Move(self.selected_square, square, promotion=chess.QUEEN)

        if move in self.board.legal_moves:
            self.board.push(move)
            self._select(None)
            self.update_predictions()
        else:
            piece = self.board.piece_at(square)
            if piece and piece.color == self.board.turn:
                self._select(square)
            else:
                self._select(None)

    def _on_mouse_down(self, pos):
        """Handle a left-button press: click logic, then start dragging any selected piece."""
        self.handle_click(pos)
        self.dragging = self.selected_square is not None
        if self.dragging:
            sx, sy = self.square_to_xy(self.selected_square)
            self.drag_offset = (pos[0] - sx, pos[1] - sy)
            self.drag_pos = pos

    def _on_mouse_up(self, pos):
        """Finish a drag: dropping on a different square acts like clicking it."""
        if not self.dragging:
            return
        self.dragging = False
        square = self.xy_to_square(*pos)
        if square is not None and square != self.selected_square:
            self.handle_click(pos)

    def run(self):
        """Main loop: process input, render, and shut down cleanly on exit."""
        while self.running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.running = False
                elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                    self._on_mouse_down(event.pos)
                elif event.type == pygame.MOUSEMOTION and self.dragging:
                    self.drag_pos = event.pos
                elif event.type == pygame.MOUSEBUTTONUP and event.button == 1:
                    self._on_mouse_up(event.pos)
                elif event.type == pygame.KEYDOWN and event.key == pygame.K_r:
                    self.board.reset()
                    self.dragging = False
                    self._select(None)
                    self.update_predictions()

            self._sync_predictions()

            self.screen.fill((0, 0, 0))
            self.draw_board()
            self.draw_highlights()
            self.draw_arrows()
            self.draw_pieces()
            self.draw_info()

            pygame.display.flip()
            self.clock.tick(FPS)

        self.inference.shutdown()
        pygame.quit()
        sys.exit()


if __name__ == "__main__":
    checkpoint_path = "./checkpoints/latest.pt"
    # Ignore flag-like arguments (e.g. Jupyter's "-f <kernel.json>").
    if len(sys.argv) >= 2 and not sys.argv[1].startswith('-'):
        checkpoint_path = sys.argv[1]

    app = ChessApp(checkpoint_path)
    if HAS_DISPLAY:
        app.run()