| |
|
|
| """ |
| Interactive Chess GUI with GRPO Model Predictions |
| - Movable pieces (click a piece, then click its destination square) |
| - Arrow showing top-3 predicted moves |
| - Legal move enforcement |
| STANDALONE VERSION: Contains necessary model classes to run without model.py. |
| """ |
|
|
| import sys |
| import os |
| import threading |
| import queue |
| import math |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
# Must be set before pygame is imported, otherwise pygame prints its banner.
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
try:
    import pygame
    import chess
except ImportError:
    # Best-effort first-run install. Install into *this* interpreter
    # (`sys.executable -m pip`) rather than whatever `pip` happens to be first
    # on PATH, which may belong to a different Python and leave the re-import
    # below failing. check_call surfaces a failed install immediately instead
    # of silently continuing to a confusing ImportError.
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "pygame", "chess"])
    import pygame
    import chess
|
|
| |
| |
| |
# Prefer speed over bit-exact float32: let cuDNN autotune convolution
# algorithms and permit TF32 tensor-core math for convolutions and matmuls.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
# Older torch releases do not provide set_float32_matmul_precision.
if callable(getattr(torch, 'set_float32_matmul_precision', None)):
    torch.set_float32_matmul_precision('high')
|
|
|
|
| |
| |
| |
class ActionMapper:
    """Bidirectional mapping between UCI move strings and policy-head indices.

    The index space enumerates every ordered (from, to) square pair with
    ``from != to`` in square order. Immediately after each pair whose
    from-square lies on rank 2 or rank 7 (rank index 1 or 6) and whose files
    differ by at most one, the four promotion variants follow in the order
    n, b, r, q. The enumeration is deliberately generous (it contains pairs
    that can never be legal moves).

    The ORDER of this enumeration defines the layout of ChessNet's policy
    output (``num_actions`` logits indexed via ``move_to_idx``), so it
    presumably has to match the ordering used when the checkpoint was
    trained -- do not reorder.

    Attributes:
        move_to_idx: dict mapping a UCI string to its policy index.
        idx_to_move: list mapping a policy index back to its UCI string.
        num_actions: total number of indices (size of the policy output).
    """
    __slots__ = ['move_to_idx', 'idx_to_move', 'num_actions']

    def __init__(self):
        ordered = []
        for src in range(64):
            src_name = chess.SQUARE_NAMES[src]
            src_file = chess.square_file(src)
            on_promo_rank = chess.square_rank(src) in (1, 6)
            for dst in range(64):
                if dst == src:
                    continue
                base = src_name + chess.SQUARE_NAMES[dst]
                ordered.append(base)
                if on_promo_rank and abs(src_file - chess.square_file(dst)) <= 1:
                    ordered.extend(base + promo for promo in "nbrq")

        self.idx_to_move = ordered
        self.move_to_idx = {uci: i for i, uci in enumerate(ordered)}
        self.num_actions = len(ordered)
|
|
| ACTION_MAPPER = ActionMapper() |
|
|
class ChessNet(nn.Module):
    """Residual convolutional policy/value network over 14 board planes.

    Layout:

    * stem: 3x3 conv (14 -> 128 channels, no bias), BatchNorm, ReLU;
    * trunk: six residual blocks, each conv-BN-ReLU-conv-BN whose output is
      added to the block input and passed through ReLU;
    * policy head: 1x1 conv to 32 channels, BN, ReLU, flatten, linear to
      ``ACTION_MAPPER.num_actions`` logits;
    * value head: 1x1 conv to 32 channels, BN, ReLU, flatten, linear to 256,
      ReLU, linear to 1, tanh (so the value lies in [-1, 1]).

    Submodule attribute names and ``nn.Sequential`` positions determine the
    ``state_dict`` keys used when loading checkpoints, so they are kept stable.
    """

    def __init__(self):
        super().__init__()
        trunk_ch = 128
        head_ch = 32
        head_flat = head_ch * 8 * 8  # heads flatten an 8x8 spatial map

        self.conv_in = nn.Conv2d(14, trunk_ch, kernel_size=3, padding=1, bias=False)
        self.bn_in = nn.BatchNorm2d(trunk_ch)

        def make_residual():
            return nn.Sequential(
                nn.Conv2d(trunk_ch, trunk_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(trunk_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(trunk_ch, trunk_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(trunk_ch),
            )

        self.res_blocks = nn.ModuleList([make_residual() for _ in range(6)])

        self.policy_head = nn.Sequential(
            nn.Conv2d(trunk_ch, head_ch, 1, bias=False),
            nn.BatchNorm2d(head_ch),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(head_flat, ACTION_MAPPER.num_actions),
        )

        self.value_head = nn.Sequential(
            nn.Conv2d(trunk_ch, head_ch, 1, bias=False),
            nn.BatchNorm2d(head_ch),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(head_flat, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return ``(policy_logits, value)`` for a batch of board planes.

        ``x`` must have 14 channels and an 8x8 spatial size (required by the
        stem conv and the heads' flattened linear layers).
        """
        features = F.relu(self.bn_in(self.conv_in(x)), inplace=True)
        for residual in self.res_blocks:
            features = F.relu(features + residual(features), inplace=True)
        policy_logits = self.policy_head(features)
        value = self.value_head(features)
        return policy_logits, value
|
|
def boards_to_tensor_vectorized(envs: List[chess.Board], out_tensor: np.ndarray):
    """Encode a batch of boards into 14 feature planes, written into ``out_tensor``.

    Planes (per board ``b``):

    * 0-5:  white pawns, knights, bishops, rooks, queens, kings;
    * 6-11: the same piece types for black.
      Each is a 0/1 occupancy map indexed ``[row, col]``, where square index
      ``s`` (python-chess numbering) lands at ``[s // 8, s % 8]``.
    * 12:   filled with +1 if white is to move, else -1.
    * 13:   filled with an integer derived from ``castling_rights`` (see the
      note below), then cell ``[0, 1]`` is overwritten with +1 if an
      en-passant square is set, else -1.

    Args:
        envs: boards to encode; their count ``B`` is expected to match
            ``out_tensor.shape[0]``.
        out_tensor: float buffer of shape (B, 14, 8, 8), fully overwritten in
            place.

    NOTE(review): python-chess documents ``Board.castling_rights`` as a
    bitmask of rook squares (e.g. ``0x8100000000000081`` at the start), not a
    0-15 flag count. ``castling_rights * 0.1333333 - 1.0`` can therefore be
    astronomically large, and the float32 -> int8 cast below overflows with a
    platform-dependent result. The arithmetic is intentionally left untouched
    because the trained checkpoint presumably saw exactly these values --
    verify against the training-side encoder before changing it.
    """
    B = len(envs)
    bbs = np.zeros((B, 12), dtype=np.uint64)
    meta = np.zeros((B, 3), dtype=np.float32)

    for b, env in enumerate(envs):
        white = env.occupied_co[chess.WHITE]
        black = env.occupied_co[chess.BLACK]
        piece_bbs = (env.pawns, env.knights, env.bishops, env.rooks, env.queens, env.kings)
        for p, piece_bb in enumerate(piece_bbs):
            bbs[b, p] = piece_bb & white
            bbs[b, p + 6] = piece_bb & black

        meta[b, 0] = 1.0 if env.turn else -1.0
        meta[b, 1] = env.castling_rights * 0.1333333 - 1.0
        meta[b, 2] = 1.0 if env.ep_square else -1.0

    # Force little-endian byte order so that on ANY host byte j of a bitboard
    # holds squares 8*j .. 8*j+7 and (with bitorder='little') bit k within it
    # is square 8*j+k. On little-endian hosts this is a no-copy view, identical
    # to viewing the native array; on big-endian hosts it prevents a
    # byte-mirrored board.
    bb_bytes = bbs.astype('<u8', copy=False).view(np.uint8).reshape(B, 12, 8)
    planes = np.unpackbits(bb_bytes, axis=2, bitorder='little').reshape(B, 12, 8, 8)
    out_tensor[:, :12, :, :] = planes

    # Truncate toward zero exactly as the original encoder did.
    meta_int8 = meta.astype(np.int8)
    out_tensor[:, 12, :, :] = meta_int8[:, 0].reshape(B, 1, 1)
    out_tensor[:, 13, :, :] = meta_int8[:, 1].reshape(B, 1, 1)
    out_tensor[:, 13, 0, 1] = meta_int8[:, 2]
|
|
|
|
| |
| |
| |
| |
# Colab has no native display for pygame windows; warn up front.
if "google.colab" in sys.modules:
    print(
        "WARNING: You are running test.py in Google Colab.",
        "Pygame requires a GUI display which Colab does not have natively.",
        "It is recommended to run test.py locally on your Windows PC and load the latest.pt file.",
        sep="\n",
    )

# Probe for a usable video device by briefly opening (and closing) a 1x1
# window. HAS_DISPLAY gates whether the GUI is started at all.
HAS_DISPLAY = False
try:
    pygame.init()
    screen_test = pygame.display.set_mode((1, 1))
    pygame.display.quit()
except pygame.error:
    print("ERROR: No display detected. Pygame GUI cannot start without a screen.")
else:
    HAS_DISPLAY = True
|
|
# Window geometry: a square window split into an 8x8 grid of board squares.
WIDTH = HEIGHT = 800
SQUARE_SIZE = WIDTH // 8
FPS = 60  # frame-rate cap for the main loop's clock

# Colors are RGB tuples, or RGBA where a fourth (alpha) component is given.
LIGHT = (240, 217, 181)          # light board squares
DARK = (181, 136, 99)            # dark board squares
HIGHLIGHT = (255, 255, 0, 100)   # translucent yellow selection tint (RGBA)
ARROW_COLOR = (50, 150, 250)     # base RGB of prediction arrows
TEXT_COLOR = (0, 0, 0)           # info-bar text
|
|
def create_piece_surface(piece: chess.Piece) -> pygame.Surface:
    """Render a chess piece glyph centered on a transparent square-sized surface.

    The first installed font among "segoeuisymbol", "arial" and "msgothic"
    (sized to 70% of a square) is used to draw the Unicode chess symbol. If
    none is installed, pygame's default font (60% of a square) draws the
    upper-cased piece letter instead, presumably because the default font
    lacks the chess glyphs.

    White pieces are filled white and black pieces near-black, over an
    outline in the opposite color drawn at eight small offsets for contrast.

    Args:
        piece: the piece to draw.

    Returns:
        A ``SQUARE_SIZE`` x ``SQUARE_SIZE`` surface with per-pixel alpha.
    """
    surf = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)

    symbols = {
        'P': '♙', 'N': '♘', 'B': '♗', 'R': '♖', 'Q': '♕', 'K': '♔',
        'p': '♟', 'n': '♞', 'b': '♝', 'r': '♜', 'q': '♛', 'k': '♚'
    }

    # Query the installed font list once rather than once per candidate.
    available = set(pygame.font.get_fonts())
    font_name = next(
        (fn for fn in ("segoeuisymbol", "arial", "msgothic") if fn in available),
        None,
    )
    if font_name is not None:
        font = pygame.font.SysFont(font_name, int(SQUARE_SIZE * 0.7))
        char = symbols[piece.symbol()]
    else:
        font = pygame.font.Font(None, int(SQUARE_SIZE * 0.6))
        char = piece.symbol().upper()

    color = (255, 255, 255) if piece.color == chess.WHITE else (30, 30, 30)
    outline_color = (0, 0, 0) if piece.color == chess.WHITE else (255, 255, 255)

    text = font.render(char, True, color)
    text_rect = text.get_rect(center=(SQUARE_SIZE // 2, SQUARE_SIZE // 2))

    # The outline glyph is identical for every offset, so render it only once.
    outline = font.render(char, True, outline_color)
    for dx, dy in ((-1, -1), (-1, 1), (1, -1), (1, 1), (-2, 0), (2, 0), (0, -2), (0, 2)):
        surf.blit(outline, text_rect.move(dx, dy))

    surf.blit(text, text_rect)
    return surf
|
|
| |
| |
| |
class ModelInference:
    """Run ChessNet move predictions on a background daemon thread.

    Requests are submitted with :meth:`predict_async`, which discards any
    still-unprocessed older requests before queuing the new one. For each
    processed request the worker calls ``callback(top_moves, value)`` *on the
    worker thread*. ``top_moves`` is a list of up to three
    ``(chess.Move, probability)`` pairs: the most likely legal moves, with
    policy probabilities renormalised over the legal moves, in descending
    order. ``value`` is the value head's scalar output as a Python float.
    """

    def __init__(self, checkpoint_path: str):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = ChessNet().to(self.device).to(memory_format=torch.channels_last)

        if os.path.exists(checkpoint_path):
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from trusted sources.
            checkpoint = torch.load(checkpoint_path, map_location=self.device)
            # Accept a training checkpoint ({'model_state_dict': ...}) as well
            # as a bare state_dict.
            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Strip the '_orig_mod.' prefix that torch.compile-wrapped models
            # add to their state_dict keys.
            cleaned = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
            self.model.load_state_dict(cleaned)
            print(f"Loaded checkpoint from {checkpoint_path}")
        else:
            print(f"Warning: Checkpoint '{checkpoint_path}' not found. Using untrained weights.")

        self.model.eval()
        self.queue = queue.Queue()
        self.running = True
        self.thread = threading.Thread(target=self._worker, daemon=True)
        self.thread.start()

    def _worker(self):
        """Serve queued requests until :meth:`shutdown` clears ``running``."""
        import traceback

        # Reused input buffer; the encoder overwrites every element per call.
        state_np = np.zeros((1, 14, 8, 8), dtype=np.float32)
        while self.running:
            try:
                board_fen, callback = self.queue.get(timeout=0.1)
            except queue.Empty:
                continue

            try:
                top_moves, value = self._predict(board_fen, state_np)
                callback(top_moves, value)
            except Exception:
                # Keep serving: an uncaught exception would end this daemon
                # thread silently and freeze predictions for the whole session.
                traceback.print_exc()

    def _predict(self, board_fen: str, state_np: np.ndarray):
        """Return ``(top_moves, value)`` for the position given as a FEN string."""
        board = chess.Board(board_fen)
        boards_to_tensor_vectorized([board], state_np)
        tensor = torch.tensor(state_np, dtype=torch.float32, device=self.device).to(
            memory_format=torch.channels_last
        )

        with torch.no_grad():
            if hasattr(torch.amp, 'autocast'):
                with torch.amp.autocast(self.device):
                    logits, value = self.model(tensor)
            else:
                with torch.cuda.amp.autocast():
                    logits, value = self.model(tensor)

        probs = torch.softmax(logits.to(torch.float32), dim=-1).cpu().numpy().flatten()

        # Policy index -> legal Move, so results need no re-parsing/re-checking.
        legal_by_idx = {ACTION_MAPPER.move_to_idx[m.uci()]: m for m in board.legal_moves}
        legal_indices = list(legal_by_idx)

        # Zero out illegal moves and renormalise over the legal ones.
        probs_filtered = np.zeros_like(probs)
        probs_filtered[legal_indices] = probs[legal_indices]
        s = probs_filtered.sum()
        if s > 0:
            probs_filtered /= s

        top_indices = np.argsort(probs_filtered)[-3:][::-1]
        top_moves = [
            (legal_by_idx[int(i)], probs_filtered[i])
            for i in top_indices
            if probs_filtered[i] > 0
        ]
        return top_moves, value.item()

    def predict_async(self, board: chess.Board, callback):
        """Queue a prediction for ``board``, dropping any older pending requests."""
        while not self.queue.empty():
            try:
                self.queue.get_nowait()
            except queue.Empty:
                pass
        self.queue.put((board.fen(), callback))

    def shutdown(self):
        """Stop the worker loop and wait for the thread to exit."""
        self.running = False
        self.thread.join()
|
|
| |
| |
| |
class ChessApp:
    """Pygame front end: a click-to-move board with live model move arrows.

    Left-click a piece of the side to move to select it (its legal target
    squares are tinted), then click a target square to play the move. Pawns
    reaching the last rank always promote to a queen. Press R to reset the
    board. After every position change the model is queried asynchronously
    and its top predicted moves are drawn as arrows.
    """

    def __init__(self, model_checkpoint: str):
        if not HAS_DISPLAY:
            print("\nExiting because no GUI display is available.")
            sys.exit(1)

        pygame.init()
        self.screen = pygame.display.set_mode((WIDTH, HEIGHT))
        pygame.display.set_caption("GRPO Chess Agent - Real-Time Predictions")
        self.clock = pygame.time.Clock()
        self.board = chess.Board()
        self.selected_square = None
        self.valid_moves = []
        self.predicted_arrows = []
        self.prediction_value = 0.0
        # Incremented on every prediction request. Callbacks carrying an older
        # id belong to a position no longer on the board and are ignored.
        self._prediction_id = 0
        # Created once and reused, instead of reloading a font and allocating
        # full-window surfaces on every frame.
        self._info_font = pygame.font.Font(None, 36)
        self._arrow_surface = pygame.Surface((WIDTH, HEIGHT), pygame.SRCALPHA)
        self.piece_images = {}
        self._load_pieces()
        self.inference = ModelInference(model_checkpoint)
        self.running = True
        self.update_predictions()

    def _load_pieces(self):
        """Pre-render one surface per (piece_type, color)."""
        for piece_type in chess.PIECE_TYPES:
            for color in (chess.WHITE, chess.BLACK):
                piece = chess.Piece(piece_type, color)
                self.piece_images[(piece_type, color)] = create_piece_surface(piece)

    def square_to_xy(self, square: chess.Square) -> Tuple[int, int]:
        """Top-left pixel of ``square``, with rank 8 drawn at the top."""
        file_idx = chess.square_file(square)
        rank_idx = 7 - chess.square_rank(square)
        return file_idx * SQUARE_SIZE, rank_idx * SQUARE_SIZE

    def xy_to_square(self, x: int, y: int) -> Optional[chess.Square]:
        """Square under pixel ``(x, y)``, or None if outside the board."""
        file_idx = x // SQUARE_SIZE
        rank_idx = 7 - (y // SQUARE_SIZE)
        if 0 <= file_idx < 8 and 0 <= rank_idx < 8:
            return chess.square(file_idx, rank_idx)
        return None

    def draw_board(self):
        """Paint the checkered 8x8 background."""
        for row in range(8):
            for col in range(8):
                color = LIGHT if (row + col) % 2 == 0 else DARK
                rect = pygame.Rect(col * SQUARE_SIZE, row * SQUARE_SIZE, SQUARE_SIZE, SQUARE_SIZE)
                pygame.draw.rect(self.screen, color, rect)

    def draw_pieces(self):
        """Blit every piece currently on the board."""
        for square in chess.SQUARES:
            piece = self.board.piece_at(square)
            if piece:
                x, y = self.square_to_xy(square)
                self.screen.blit(self.piece_images[(piece.piece_type, piece.color)], (x, y))

    def draw_highlights(self):
        """Tint the selected square yellow and its legal targets green."""
        if self.selected_square is not None:
            x, y = self.square_to_xy(self.selected_square)
            s = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
            s.fill(HIGHLIGHT)
            self.screen.blit(s, (x, y))
        for move in self.valid_moves:
            x, y = self.square_to_xy(move.to_square)
            s = pygame.Surface((SQUARE_SIZE, SQUARE_SIZE), pygame.SRCALPHA)
            s.fill((0, 255, 0, 60))
            self.screen.blit(s, (x, y))

    def draw_arrows(self):
        """Draw one translucent arrow per predicted move.

        Likelier moves get thicker, more opaque arrows with larger heads.
        """
        half = SQUARE_SIZE // 2
        overlay = self._arrow_surface
        for move, prob in self.predicted_arrows:
            start_x, start_y = self.square_to_xy(move.from_square)
            end_x, end_y = self.square_to_xy(move.to_square)
            start = (start_x + half, start_y + half)
            end = (end_x + half, end_y + half)

            alpha = max(100, int(255 * prob))
            color = (ARROW_COLOR[0], ARROW_COLOR[1], ARROW_COLOR[2], alpha)
            width = max(3, int(12 * prob))

            # Clear the shared layer for each arrow so each one is blended onto
            # the screen separately, exactly as with a fresh surface per arrow.
            overlay.fill((0, 0, 0, 0))
            pygame.draw.line(overlay, color, start, end, width)

            angle = math.atan2(end[1] - start[1], end[0] - start[0])
            arrow_len = max(15, int(25 * prob))
            arrow_angle = math.pi / 6
            x2 = end[0] - arrow_len * math.cos(angle - arrow_angle)
            y2 = end[1] - arrow_len * math.sin(angle - arrow_angle)
            x3 = end[0] - arrow_len * math.cos(angle + arrow_angle)
            y3 = end[1] - arrow_len * math.sin(angle + arrow_angle)
            pygame.draw.polygon(overlay, color, [end, (x2, y2), (x3, y3)])

            self.screen.blit(overlay, (0, 0))

    def draw_info(self):
        """Draw the model's value estimate in a boxed label at the bottom-left."""
        text = f"Eval Value: {self.prediction_value:.3f} (W/B)"
        surf = self._info_font.render(text, True, TEXT_COLOR)

        bg_rect = surf.get_rect(topleft=(10, HEIGHT - 40))
        bg_rect.inflate_ip(10, 10)
        pygame.draw.rect(self.screen, (255, 255, 255), bg_rect)
        pygame.draw.rect(self.screen, (0, 0, 0), bg_rect, 2)

        self.screen.blit(surf, (15, HEIGHT - 35))

    def update_predictions(self):
        """Request predictions for the current board and drop stale arrows."""
        self._prediction_id += 1
        request_id = self._prediction_id
        # Arrows for the previous position would point at the wrong pieces.
        self.predicted_arrows = []

        def callback(top_moves, value):
            # Runs on the inference thread. Ignore results for a position that
            # a newer request has replaced. This is best effort: a result that
            # slips through is overwritten when the newer request completes.
            if request_id != self._prediction_id:
                return
            self.predicted_arrows = top_moves
            self.prediction_value = value

        self.inference.predict_async(self.board, callback)

    def _clear_selection(self):
        """Deselect any square and forget its legal targets."""
        self.selected_square = None
        self.valid_moves = []

    def _select_if_own_piece(self, square):
        """Select ``square`` if it holds a piece of the side to move, else clear."""
        piece = self.board.piece_at(square)
        if piece and piece.color == self.board.turn:
            self.selected_square = square
            self.valid_moves = [m for m in self.board.legal_moves if m.from_square == square]
        else:
            self._clear_selection()

    def handle_click(self, pos):
        """Handle a left click at pixel ``pos``: play a move or change selection."""
        square = self.xy_to_square(*pos)
        if square is None:
            return

        if self.selected_square is not None:
            move = chess.Move(self.selected_square, square)
            # Auto-promote to a queen when a pawn reaches the first/last rank.
            if (chess.square_rank(square) in (0, 7)
                    and self.board.piece_at(self.selected_square).piece_type == chess.PAWN):
                move = chess.Move(self.selected_square, square, promotion=chess.QUEEN)

            if move in self.board.legal_moves:
                self.board.push(move)
                self._clear_selection()
                self.update_predictions()
                return

        self._select_if_own_piece(square)

    def run(self):
        """Main loop: handle input and redraw at up to FPS frames/s, then exit."""
        try:
            while self.running:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT:
                        self.running = False
                    elif event.type == pygame.MOUSEBUTTONDOWN:
                        if event.button == 1:
                            self.handle_click(event.pos)
                    elif event.type == pygame.KEYDOWN:
                        if event.key == pygame.K_r:
                            self.board.reset()
                            self._clear_selection()
                            self.update_predictions()

                self.screen.fill((0, 0, 0))
                self.draw_board()
                self.draw_highlights()
                self.draw_arrows()
                self.draw_pieces()
                self.draw_info()

                pygame.display.flip()
                self.clock.tick(FPS)
        finally:
            # Stop the worker and release the window even if the loop raised.
            self.inference.shutdown()
            pygame.quit()
        sys.exit()
|
|
if __name__ == "__main__":
    # Usage: <script> [checkpoint_path]. A first argument starting with '-'
    # is not treated as a path, and the default checkpoint is used.
    cli_args = sys.argv[1:]
    checkpoint_path = (
        cli_args[0]
        if cli_args and not cli_args[0].startswith('-')
        else "./checkpoints/latest.pt"
    )

    app = ChessApp(checkpoint_path)
    # ChessApp itself exits when no display is available; this is a guard.
    if HAS_DISPLAY:
        app.run()