{ "cells": [ { "cell_type": "markdown", "id": "6c9fd40a", "metadata": {}, "source": [ "# The Continuous Thought Machine – Tutorial 05: Hugging Face [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SakanaAI/continuous-thought-machines/blob/main/examples/05_huggingface.ipynb) [![arXiv](https://img.shields.io/badge/arXiv-2505.05522-b31b1b.svg)](https://arxiv.org/abs/2505.05522)" ] }, { "cell_type": "markdown", "id": "1bd709d2", "metadata": {}, "source": [ "The CTM is now on Hugging Face! 🤗\n", "\n", "Specifically, we have uploaded the image classification CTM trained on ImageNet, and the large maze solving CTM. Additionally, we have uploaded the maze datsets (in small, medium, large and extralarge variants) to make working with this task more convienient!\n", "\n", "Everything can be found on Hugging Face [here](https://huggingface.co/collections/SakanaAI/continuous-thought-machines-68edd4bb94a7809e074468e7)!\n", "\n", "In this notebook, we walkthough how to load both of these models, as well as how to use the maze datasets." ] }, { "cell_type": "markdown", "id": "54763a48", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "markdown", "id": "33e0fc19", "metadata": {}, "source": [ "If running the notebook in Colab, set `USE_COLAB` to `True` to clone the repo." ] }, { "cell_type": "code", "execution_count": 18, "id": "81af1184", "metadata": {}, "outputs": [], "source": [ "USE_COLAB = False" ] }, { "cell_type": "code", "execution_count": 19, "id": "a46a5706", "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "if USE_COLAB:\n", " !git clone https://github.com/SakanaAI/continuous-thought-machines.git\n", " sys.path.append(\"./continuous-thought-machines\")\n", "else:\n", " sys.path.append(\"..\")" ] }, { "cell_type": "code", "execution_count": 20, "id": "eef2c91c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: mediapy in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (1.2.4)\n", "Requirement already satisfied: ipython in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from mediapy) (9.2.0)\n", "Requirement already satisfied: matplotlib in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from mediapy) (3.10.3)\n", "Requirement already satisfied: numpy in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from mediapy) (2.2.5)\n", "Requirement already satisfied: Pillow in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from mediapy) (11.2.1)\n", "Requirement already satisfied: decorator in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from ipython->mediapy) (5.2.1)\n", "Requirement already satisfied: ipython-pygments-lexers in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from ipython->mediapy) (1.1.1)\n", "Requirement already satisfied: jedi>=0.16 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from ipython->mediapy) (0.19.2)\n", "Requirement already satisfied: matplotlib-inline in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from ipython->mediapy) (0.1.7)\n", "Requirement already satisfied: pexpect>4.3 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from ipython->mediapy) (4.9.0)\n", "Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from 
ipython->mediapy) (3.0.51)\n", "Requirement already satisfied: pygments>=2.4.0 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from ipython->mediapy) (2.19.1)\n", "Requirement already satisfied: stack_data in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from ipython->mediapy) (0.6.3)\n", "Requirement already satisfied: traitlets>=5.13.0 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from ipython->mediapy) (5.14.3)\n", "Requirement already satisfied: wcwidth in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython->mediapy) (0.2.13)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from jedi>=0.16->ipython->mediapy) (0.8.4)\n", "Requirement already satisfied: ptyprocess>=0.5 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from pexpect>4.3->ipython->mediapy) (0.7.0)\n", "Requirement already satisfied: contourpy>=1.0.1 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from matplotlib->mediapy) (1.3.2)\n", "Requirement already satisfied: cycler>=0.10 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from matplotlib->mediapy) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from matplotlib->mediapy) (4.58.0)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from matplotlib->mediapy) (1.4.8)\n", "Requirement already satisfied: packaging>=20.0 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from matplotlib->mediapy) (25.0)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from matplotlib->mediapy) (3.2.3)\n", "Requirement already satisfied: python-dateutil>=2.7 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from matplotlib->mediapy) (2.9.0.post0)\n", "Requirement already satisfied: six>=1.5 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib->mediapy) (1.17.0)\n", "Requirement already satisfied: executing>=1.2.0 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from stack_data->ipython->mediapy) (2.2.0)\n", "Requirement already satisfied: asttokens>=2.1.0 in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from stack_data->ipython->mediapy) (3.0.0)\n", "Requirement already satisfied: pure_eval in /home/ciaran_sakana_ai/.conda/envs/ctm/lib/python3.12/site-packages (from stack_data->ipython->mediapy) (0.2.3)\n" ] } ], "source": [ "!pip install mediapy" ] }, { "cell_type": "code", "execution_count": 21, "id": "70b5973c", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import random\n", "from PIL import Image\n", "from torchvision import transforms\n", "import urllib\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from tqdm import tqdm\n", "from matplotlib import patheffects\n", "from scipy import ndimage\n", "import imageio\n", "import mediapy\n", "from torch.utils.data import DataLoader\n", "from datasets import load_dataset\n", "\n", "from models.ctm import ContinuousThoughtMachine as CTM\n", "from 
tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES\n", "from utils.losses import maze_loss\n", "from tasks.mazes.plotting import make_maze_gif" ] }, { "cell_type": "code", "execution_count": 22, "id": "22ba7fcb", "metadata": {}, "outputs": [], "source": [ "def set_seed(seed=42, deterministic=True):\n", " random.seed(seed)\n", " np.random.seed(seed)\n", " torch.manual_seed(seed)\n", " torch.cuda.manual_seed_all(seed)\n", " torch.backends.cudnn.deterministic = deterministic\n", " torch.backends.cudnn.benchmark = False" ] }, { "cell_type": "code", "execution_count": 23, "id": "ee1ef3fc", "metadata": {}, "outputs": [], "source": [ "set_seed(42) # ... the meaning of life is ..." ] }, { "cell_type": "code", "execution_count": 24, "id": "0cadd21d", "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { "cell_type": "markdown", "id": "74a8a152", "metadata": {}, "source": [ "### Loading a Pretrained Model" ] }, { "cell_type": "markdown", "id": "bb8085b0", "metadata": {}, "source": [ "To load the image classfication CTM we can simply call `from_pretrained` on the CTM class." ] }, { "cell_type": "code", "execution_count": null, "id": "1037dda9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using neuron select type: random-pairing\n", "Synch representation size action: 2048\n", "Synch representation size out: 8196\n" ] } ], "source": [ "model = CTM.from_pretrained(\"SakanaAI/ctm-imagenet\")\n", "model = model.to(device)\n", "model.eval();" ] }, { "cell_type": "markdown", "id": "b931c83a", "metadata": {}, "source": [ "Let's test the model by running inference with it. We download an image and convert it to a PyTorch Tensor." ] }, { "cell_type": "code", "execution_count": 32, "id": "f91afeee", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading goldfish.jpg from local directory...\n" ] } ], "source": [ "filename = \"goldfish.jpg\"\n", "\n", "if os.path.exists(filename):\n", " print(f\"Loading {filename} from local directory...\")\n", " image = Image.open(filename).convert(\"RGB\")\n", "else:\n", " print(f\"{filename} not found locally. Downloading...\")\n", " url = \"https://www.seahorseaquariums.com/image/cache/catalog/Categories%20-%20Freshewater/Coldwater%20Fish/Fantails/Fantail%20Goldfish-2000x2000.jpg\"\n", " urllib.request.urlretrieve(url, filename)\n", " image = Image.open(filename).convert(\"RGB\")\n", "target = 1 # Goldfish\n", "urllib.request.urlretrieve(url, filename)\n", "image = Image.open(filename).convert(\"RGB\")\n", "\n", "# Preprocess the image\n", "dataset_mean = [0.485, 0.456, 0.406]\n", "dataset_std = [0.229, 0.224, 0.225]\n", "transform = transforms.Compose([\n", " transforms.Resize(256),\n", " transforms.CenterCrop(224),\n", " transforms.ToTensor(),\n", " transforms.Normalize(mean=dataset_mean, std=dataset_std)\n", "])\n", "\n", "input_tensor = transform(image).unsqueeze(0).to(device)" ] }, { "cell_type": "markdown", "id": "560364c9", "metadata": {}, "source": [ "Run a forward pass with the CTM. For simplicity, we take the prediction at the last internal tick as the final prediction. Alternatively, we could use the prediction at the models most certain internal tick. See [tutorial 01: MNIST](https://github.com/SakanaAI/continuous-thought-machines/blob/main/examples/01_mnist.ipynb) for more details." 
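, "\n", "\n", "For example, given the `predictions` and `certainties` returned by the forward pass below, a minimal sketch (assuming, as in the plotting code later in this notebook, that index 1 of `certainties` holds the per-tick certainty) would be:\n", "\n", "```python\n", "# Sketch: pick the internal tick at which the model was most certain,\n", "# then take the prediction at that tick instead of the last one.\n", "most_certain_tick = certainties[0, 1].argmax().item()\n", "prediction_certain = predictions[0, :, most_certain_tick].argmax(dim=0)\n", "```"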
] }, { "cell_type": "code", "execution_count": 27, "id": "c527cf15", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Target Class: 1 = goldfish, Carassius auratus\n", "Predicted Class (final): 1 = goldfish, Carassius auratus\n" ] } ], "source": [ "with torch.no_grad():\n", " predictions, certainties, synchronization, pre_activations, post_activations, attention_tracking = model(input_tensor, track=True)\n", "# Get predictions\n", "prediction_last = predictions[0, :, -1].argmax(dim=0)\n", "IMAGENET_CLASS_LIST = list(IMAGENET2012_CLASSES.values())\n", "\n", "print(f\"Target Class: {target} = {IMAGENET_CLASS_LIST[target]}\")\n", "print(f\"Predicted Class (final): {prediction_last.item()} = {IMAGENET_CLASS_LIST[prediction_last.item()]}\")" ] }, { "cell_type": "markdown", "id": "ff4003a7", "metadata": {}, "source": [ "The pretrained CTM correctly classifies the image as a goldfish! 🐠\n", "\n", "Let's make a gif to vizualise the reasoning process." ] }, { "cell_type": "code", "execution_count": 33, "id": "dd17f979", "metadata": {}, "outputs": [], "source": [ "def make_gif(predictions, certainties, attention_tracking, ground_truth_target, out_path, dataset_mean, dataset_std):\n", "\n", " def find_island_centers(array_2d, threshold):\n", " \"\"\"\n", " Finds the center of mass of each island (connected component > threshold)\n", " in a 2D array, weighted by the array's values.\n", " Returns list of (y, x) centers and list of areas.\n", " \"\"\"\n", " binary_image = array_2d > threshold\n", " labeled_image, num_labels = ndimage.label(binary_image)\n", " centers = []\n", " areas = []\n", " # Calculate center of mass for each labeled island (label 0 is background)\n", " for i in range(1, num_labels + 1):\n", " island_mask = (labeled_image == i)\n", " total_mass = np.sum(array_2d[island_mask])\n", " if total_mass > 0:\n", " # Get coordinates for this island\n", " y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]\n", " # Calculate weighted average for center\n", " x_center = np.average(x_coords[island_mask], weights=array_2d[island_mask])\n", " y_center = np.average(y_coords[island_mask], weights=array_2d[island_mask])\n", " centers.append((round(y_center, 4), round(x_center, 4)))\n", " areas.append(np.sum(island_mask)) # Area is the count of pixels in the island\n", " return centers, areas\n", "\n", " interp_mode = 'nearest'\n", " figscale = 0.85\n", "\n", " class_labels = list(IMAGENET2012_CLASSES.values()) # Load actual class names\n", "\n", " # predictions: (B, Classes, Steps), attention_tracking: (Steps*B*Heads, SeqLen)\n", " n_steps = predictions.size(-1)\n", "\n", " # --- Reshape Attention ---\n", " # Infer feature map size from model internals (assuming B=1)\n", " h_feat, w_feat = model.kv_features.shape[-2:]\n", "\n", " n_heads = attention_tracking.shape[2] \n", " # Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1\n", " attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)\n", "\n", " # --- Setup for Plotting ---\n", " step_linspace = np.linspace(0, 1, n_steps) # For step colors\n", " # Define color maps\n", " cmap_spectral = sns.color_palette(\"Spectral\", as_cmap=True)\n", " cmap_attention = sns.color_palette('viridis', as_cmap=True)\n", "\n", " # Create output directory for this index\n", " index_output_dir = os.path.join(out_path, str(0))\n", " os.makedirs(index_output_dir, exist_ok=True)\n", "\n", " frames = [] # Store frames for GIF\n", " head_routes = {h: [] for h in range(n_heads)} # Store (y,x) 
path points per head\n", " head_routes[-1] = []\n", " route_colours_step = [] # Store colors for each step's path segments\n", "\n", " # --- Loop Through Each Step ---\n", " for step_i in tqdm(range(n_steps), desc=\"Processing steps\", unit=\"step\"):\n", "\n", " # --- Prepare Image for Display ---\n", " # Denormalize the input tensor for visualization\n", " data_img_tensor = input_tensor[0].cpu() # Get first item in batch, move to CPU\n", " mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)\n", " std_tensor = torch.tensor(dataset_std).view(3, 1, 1)\n", " data_img_denorm = data_img_tensor * std_tensor + mean_tensor\n", " # Permute to (H, W, C) and convert to numpy, clip to [0, 1]\n", " data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()\n", " data_img_np = np.clip(data_img_np, 0, 1)\n", " img_h, img_w = data_img_np.shape[:2]\n", "\n", " # --- Process Attention & Certainty ---\n", " # Average attention over last few steps (from original code)\n", " start_step = max(0, step_i - 5)\n", " attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)\n", " # Get certainties up to current step\n", " certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty\n", "\n", " # --- Calculate Attention Paths (using bilinear interp) ---\n", " # Interpolate attention to image size using bilinear for center finding\n", " attention_interp_bilinear = F.interpolate(\n", " torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float\n", " size=(img_h, img_w),\n", " mode=interp_mode,\n", " # align_corners=False\n", " ).squeeze(0) # Remove batch dim -> (Heads, H, W)\n", "\n", " # Normalize each head's map to [0, 1]\n", " # Deal with mean\n", " attn_mean = attention_interp_bilinear.mean(0)\n", " attn_mean_min = attn_mean.min()\n", " attn_mean_max = attn_mean.max()\n", " attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)\n", " centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)\n", "\n", " if centers: # If islands found\n", " largest_island_idx = np.argmax(areas)\n", " current_center = centers[largest_island_idx] # (y, x)\n", " head_routes[-1].append(current_center)\n", " elif head_routes[-1]: # If no center now, repeat last known center if history exists\n", " head_routes[-1].append(head_routes[-1][-1])\n", "\n", "\n", " attn_min = attention_interp_bilinear.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)\n", " attn_max = attention_interp_bilinear.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)\n", " attention_interp_bilinear = (attention_interp_bilinear - attn_min) / (attn_max - attn_min + 1e-6)\n", "\n", " # Store step color\n", " current_colour = list(cmap_spectral(step_linspace[step_i]))\n", " route_colours_step.append(current_colour)\n", "\n", " # Find island center for each head\n", " for head_i in range(n_heads):\n", " attn_head_np = attention_interp_bilinear[head_i].detach().cpu().numpy()\n", " # Keep threshold=0.7 based on original call\n", " centers, areas = find_island_centers(attn_head_np, threshold=0.7)\n", "\n", " if centers: # If islands found\n", " largest_island_idx = np.argmax(areas)\n", " current_center = centers[largest_island_idx] # (y, x)\n", " head_routes[head_i].append(current_center)\n", " elif head_routes[head_i]: # If no center now, repeat last known center if history exists\n", " head_routes[head_i].append(head_routes[head_i][-1])\n", " \n", " \n", "\n", " # --- 
Plotting Setup ---\n", " mosaic = [['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],\n", " ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],\n", " ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],\n", " ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],\n", " ['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay', 'head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],\n", " ['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay','head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],\n", " ['head_8', 'head_8_overlay', 'head_9', 'head_9_overlay','head_10', 'head_10_overlay', 'head_11', 'head_11_overlay'],\n", " ['head_12', 'head_12_overlay', 'head_13', 'head_13_overlay','head_14', 'head_14_overlay', 'head_15', 'head_15_overlay'],\n", " ['probabilities', 'probabilities','probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty'],\n", " ]\n", "\n", " img_aspect = data_img_np.shape[0] / data_img_np.shape[1]\n", " aspect_ratio = (8 * figscale, 9 * figscale * img_aspect) # W, H\n", " fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)\n", "\n", " for ax in axes.values():\n", " ax.axis('off')\n", "\n", " # --- Plot Certainty ---\n", " ax_cert = axes['certainty']\n", " ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)\n", " # Add background color based on prediction correctness at each step\n", " for ii in range(len(certainties_now)):\n", " is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor\n", " facecolor = 'limegreen' if is_correct else 'orchid'\n", " ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)\n", " # Mark the last point\n", " ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)\n", " ax_cert.axis('off')\n", " ax_cert.set_ylim([0.05, 1.05])\n", " ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit\n", "\n", " # --- Plot Probabilities ---\n", " ax_prob = axes['probabilities']\n", " # Get probabilities for the current step\n", " ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()\n", " k = 5 # Top k predictions\n", " topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)\n", " topk_indices = topk_indices.numpy()\n", " topk_probs = topk_probs.numpy()\n", "\n", " true_class_idx = ground_truth_target # Ground truth index\n", "\n", " # Determine bar colors (green if correct, blue otherwise - consistent with original)\n", " colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]\n", "\n", " # Plot horizontal bars (inverted range for top-down display)\n", " ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range\n", " ax_prob.set_xlim([0, 1])\n", " ax_prob.axis('off')\n", "\n", " # Add text labels for top classes\n", " for i, name_idx in enumerate(topk_indices):\n", " name = class_labels[name_idx] # Get name from index\n", " is_correct = name_idx == true_class_idx\n", " fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original\n", " text_str = f'{name[:40]}' # Truncate long names\n", " # Position text on the left side of the 
horizontal bars\n", " ax_prob.text(\n", " 0.01, # Small offset from left edge\n", " k - 1 - i, # Y-position corresponding to the bar\n", " text_str,\n", " #transform=ax_prob.transAxes, # Use data coordinates for Y\n", " verticalalignment='center',\n", " horizontalalignment='left',\n", " fontsize=8,\n", " color=fg_color,\n", " alpha=0.9, # Slightly more visible than 0.5\n", " path_effects=[\n", " patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke\n", " patheffects.Normal()\n", " ])\n", "\n", "\n", " # --- Plot Attention Heads & Overlays (using nearest interp) ---\n", " # Re-interpolate attention using nearest neighbor for visual plotting\n", " attention_interp_plot = F.interpolate(\n", " torch.from_numpy(attention_now).unsqueeze(0).float(),\n", " size=(img_h, img_w),\n", " mode=interp_mode, # 'nearest'\n", " ).squeeze(0)\n", "\n", " attn_mean = attention_interp_plot.mean(0)\n", " attn_mean_min = attn_mean.min()\n", " attn_mean_max = attn_mean.max()\n", " attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)\n", "\n", "\n", " # Normalize each head's map to [0, 1]\n", " attn_min_plot = attention_interp_plot.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)\n", " attn_max_plot = attention_interp_plot.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)\n", " attention_interp_plot = (attention_interp_plot - attn_min_plot) / (attn_max_plot - attn_min_plot + 1e-6)\n", " attention_interp_plot_np = attention_interp_plot.detach().cpu().numpy()\n", " \n", "\n", "\n", " \n", "\n", "\n", " for head_i in list(range(n_heads)) + [-1]:\n", " axname = f'head_{head_i}' if head_i != -1 else 'head_mean'\n", " if axname not in axes: continue # Skip if mosaic doesn't have this head\n", "\n", " ax = axes[axname]\n", " ax_overlay = axes[f'{axname}_overlay']\n", "\n", " # Plot attention heatmap\n", " this_attn = attention_interp_plot_np[head_i] if head_i != -1 else attn_mean\n", " img_to_plot = cmap_attention(this_attn)\n", " ax.imshow(img_to_plot)\n", " ax.axis('off')\n", "\n", " # Plot overlay: image + paths\n", " these_route_steps = head_routes[head_i]\n", " arrow_scale = 1.5 if head_i != -1 else 3\n", "\n", " if these_route_steps: # Only plot if path exists\n", " # Separate y and x coordinates\n", " y_coords, x_coords = zip(*these_route_steps)\n", " y_coords = np.array(y_coords)\n", " x_coords = np.array(x_coords)\n", "\n", " # Flip y-coordinates for correct plotting (imshow origin is top-left)\n", " # NOTE: Original flip seemed complex, simplifying to standard flip\n", " y_coords_flipped = img_h - 1 - y_coords\n", "\n", " # Show original image flipped vertically to match coordinate system\n", " ax_overlay.imshow(np.flipud(data_img_np), origin='lower')\n", "\n", " # Draw arrows for path segments\n", " # Arrow size scaling from original\n", " for i in range(len(these_route_steps) - 1):\n", " dx = x_coords[i+1] - x_coords[i]\n", " dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta\n", "\n", " # Draw white background arrow (thicker)\n", " ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,\n", " linewidth=1.6 * arrow_scale * 1.3,\n", " head_width=1.9 * arrow_scale * 1.3,\n", " head_length=1.4 * arrow_scale * 1.45,\n", " fc='white', ec='white', length_includes_head=True, alpha=1)\n", " # Draw colored foreground arrow\n", " ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,\n", " linewidth=1.6 * arrow_scale,\n", " head_width=1.9 * arrow_scale,\n", " head_length=1.4 * arrow_scale,\n", " fc=route_colours_step[i], 
ec=route_colours_step[i], # Use step color\n", " length_includes_head=True)\n", "\n", " else: # If no path yet, just show the image\n", " ax_overlay.imshow(np.flipud(data_img_np), origin='lower')\n", "\n", "\n", " # Set limits and turn off axes for overlay\n", " ax_overlay.set_xlim([0, img_w - 1])\n", " ax_overlay.set_ylim([0, img_h - 1])\n", " ax_overlay.axis('off')\n", " \n", "\n", " # --- Finalize and Save Frame ---\n", " fig.tight_layout(pad=0.1) # Adjust spacing\n", "\n", " # Render the plot to a numpy array\n", " canvas = fig.canvas\n", " canvas.draw()\n", " image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')\n", " image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB\n", "\n", " frames.append(image_numpy) # Add to list for GIF\n", "\n", " \n", "\n", " plt.close(fig) # Close figure to free memory\n", "\n", " # --- Save GIF ---\n", " gif_path = os.path.join(out_path, 'image_classification_prediction.gif')\n", " print(f\"Saving GIF to {gif_path}...\")\n", " mediapy.show_video(frames, width=400, codec=\"gif\")\n", " imageio.mimsave(gif_path, frames, fps=15, loop=0) # loop=0 means infinite loop\n", " pass" ] }, { "cell_type": "code", "execution_count": 34, "id": "d633b114", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Processing steps: 100%|██████████| 50/50 [00:40<00:00, 1.22step/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Saving GIF to 05_output/image_classification_prediction.gif...\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Make an output folder \n", "out_path = \"05_output\"\n", "os.makedirs(out_path, exist_ok=True)\n", "make_gif(predictions, certainties, attention_tracking, target, out_path, dataset_mean, dataset_std);" ] }, { "cell_type": "markdown", "id": "5a61d2bd", "metadata": {}, "source": [ "## Maze Datasets" ] }, { "cell_type": "markdown", "id": "ce28f52e", "metadata": {}, "source": [ "We will now showcase how to use the maze datasets from Hugging Face, as well as loading a pretrained maze solving CTM.\n", "\n", "We use [Hugging Face datasets](https://huggingface.co/docs/datasets/en/index) to load the maze data.\n", "\n", "## Dataset Structure\n", "\n", "Each dataset contains train and test splits with the following fields:\n", "\n", "- **`image`**: A PIL Image in RGB format representing the maze\n", " - **Red pixel** (255, 0, 0): Start position\n", " - **Green pixel** (0, 255, 0): End/goal position \n", " - **Blue pixel** (0, 0, 255): Valid path (walkable areas)\n", " - **Black/White**: Walls and obstacles\n", "\n", "- **`solution_path`**: A list of integers representing the sequence of moves to solve the maze\n", " - `0` = Move Up ↑\n", " - `1` = Move Down ↓\n", " - `2` = Move Left ←\n", " - `3` = Move Right →\n", " - `4` = Wait/Padding (used when solution is shorter than fixed length)\n", "\n", "## Dataset Sizes\n", "\n", "| Variant | Image Size | Solution Path Length |\n", "|---------|------------|---------------------|\n", "| small | 15×15 | 50 |\n", "| medium | 19×19| 50 |\n", "| large | 39×39 | 100 |\n", "| extralarge | 99×99 | 100 |\n", "\n", "When converted to tensors for training, images become shape `(3, H, W)` after applying `.permute(2, 0, 1)`." ] }, { "cell_type": "code", "execution_count": 30, "id": "f2bcd310", "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAv4AAAMWCAYAAACJBYLiAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAATq9JREFUeJzt3Xm4HWWZIPD3JiF7SAIEIhCCEGQTg4Agi4QGJHaz2oQQGTAsQWPYQquoPSOgjQgCiqDAIAxRxLFRaFDakRZERKBpFkFRaRaJI4psSdiTkKTmDye3c3KTe885t7Z7v9/veXgeUqlT31df1anz5qu33urIsiwLAACgXxtQdQcAAIDiCfwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABAj8qa2f/exn0dHRET/72c86lx177LGx+eabV9YnAPLT0dERZ599duef582bFx0dHTF//vzK+gT9mcCf5P3nf/5nnH766bHHHnvE0KFDu/3R2XzzzaOjo6PLf7Nnz+5zbQOQr1au6a+99lrMnTs3Nt100xgyZEhsu+22cfnll/fJtuk7BlXdAajavffeG5dccklst912se2228bDDz/c7fo77rhjfPzjH29Y9o53vKPPtQ1QN8ccc0zMmDEjhgwZUnVX2tLsNX358uUxderUeOCBB+Kkk06KrbbaKm699daYM2dOLFy4MP7xH/+xT7VN3yHwJ3mHHHJILFq0KEaNGhUXXnhhj8H3JptsEkcffXSfbxugbgYOHBgDBw6suhtta/aafuONN8Y999wTV199dRx//PEREfGxj30spk2bFv/0T/8Us2bNig033LDPtE3fIdWHBq+++mrMnTs3Nt988xgyZEhsuOGG8f73vz8eeuihznX22WefeOc73xm/+tWvYsqUKTF8+PCYNGlSfP/734+IiDvvvDN22223GDZsWGy99dZx2223NbTxhz/8IebMmRNbb711DBs2LNZff/044ogjKsvpXG+99WLUqFEtfWbp0qXx+uuvr/Hvnn/++Rg3blzss88+kWVZ5/Inn3wyRowYEUceeWRhbQM06+yzz46Ojo54/PHH4+ijj47Ro0fHuHHj4rOf/WxkWRZ//OMf49BDD4111103xo8fHxdddFGXbSxZsiTOOuusmDRpUgwZMiQmTJgQZ5xxRixZsqTLeqeffnqMGzcuRo0aFYccckg888wzXba3phz/1Z8DWGnzzTePY489tstnf/GLX8Spp54a48aNizFjxsRHP/rRWLp0aSxatCg+/OEPx9ixY2Ps2LFxxhlnNFyj89DsNf2uu+6KiIgZM2Y0LJ8xY0YsXrw4br755ogo5vek2bbpnwT+NJg9e3Zcfvnlcfjhh8dll10Wn/jEJ2LYsGHxu9/9rmG9hQsXxkEHHRS77bZbfOlLX4ohQ4bEjBkz4p//+Z9jxowZ8Xd/93dx3nnnxeuvvx7Tpk2LV199tfOz999/f9xzzz0xY8aMuOSSS2L27Nlx++23xz777BNvvPFG2bvcsp/+9KcxfPjwGDlyZGy++ebx1a9+teHvN9xww7j88svjzjvvjEsvvTQiIlasWBHHHntsjBo1Ki677LLC2gZo1ZFHHhkrVqyI8847L3bbbbc455xz4uKLL473v//9sckmm8T5558fkyZNik984hPx85//vPNzK1asiEMOOSQuvPDCOPjgg+PSSy+Nww47LL7yla80BKQREbNmzYqLL744DjjggDjvvPNinXXWiQMPPLCQ/TnllFPiiSeeiM997nNxyCGHxJVXXhmf/exn4+CDD47ly5fHueeeG3vttVdccMEFce211xbSh54sWbIkBg4cGIMHD25YPnz48IiIePDBByOimN+TZtumn8pgFaNHj85OOumkbteZMmVKFhHZd77znc5ljz32WBYR2YABA7J///d/71x+6623ZhGRXXPNNZ3L3njjjS7bvPfee7OIyL71rW91LrvjjjuyiMjuuOOOzmUzZ87MJk6c2PqONemCCy7IIiJ7+umn1/j3Bx98cHb++ednN910U3b11Vdn73vf+7KIyM4444wu637oQx/Khg8fnj3++OOd273ppptKaRugJ2eddVYWEdlHPvKRzmXLli3LNt1006yjoyM777zzOpcvXLgwGzZsWDZz5szOZddee202YMCA7K677mrY7hVXXJFFRHb33XdnWZZlDz/8cBYR2Zw5cxrWO+qoo7KIyM4666zOZddcc02X6+Dq66w0ceLEhv6s/OzUqVOzFStWdC7ffffds46Ojmz27Nld9nPKlCndDVGvdHdNv+iii7KI6DJ2n/70p7OIyA466KCG5Xn+nrTaNv2LGX8ajBkzJu67777485//3O16I0eObLhNuPXWW8eYMWNi2223jd12261z+cr///3vf9+5bNiwYZ3//9Zbb8VLL70UkyZNijFjxjSkFNXRD37wgzjjjDPi0EMPjeOPPz7uvPPOmDp1anz5y1/uctv6a1/7WowePTqmTZsWn/3sZ+OYY46JQw89tJS2AZo1a9aszv8fOHBg7LLLLpFlWZxwwgmdy8eMGRNbb711w7X8e9/7Xmy77baxzTbbxIsvvtj537777hsREXfccUdERPzoRz+KiIhTTz21od25c+cWsj8nnHBCdHR0dP55t91267I/K/dz1f0p01FHHRWjR4+O448/Pn7yk5/E/Pnz48orr+ycwX/zzTcb1s/z96TVtulfBP40+NKXvhSPPvpoTJgwIXbdddc4++yz13hh3HTTTRsurBERo0ePjgkTJnRZFvHX1KCV3nzzzTjzzDNjwoQJMWTIkNhggw1i3LhxsWjRonj55Zd7vQ8vv/xy/OUvf+n8b8GCBb3e5tp0dHTE6aefHsuWLWt430DEX/MtL7nkkvjVr34Vo0ePjksuuaS0tgGatdlmmzX8efTo0TF06NDYYIMNuixf9Vr+xBNPxG9+85sYN25cw38rK409//zzEfHX57oGDBgQW265ZcP2tt566yJ2Z437ExFr/H1adX/WpKjfk/Hjx8cPfvCDWLJkSRxwwAHx9re/PT75yU92pvOMHDmyYf08f09abZv+RVUfGkyfPj3e9773xb/8y7/Ev/3bv8UFF1wQ559/ftx4443xt3/7t53rra3qwtqWZ6s8lHTKKafENddcE3Pnzo3dd989Ro8eHR0dHTFjxoxYsWJFr/fhtNNOi29+85udf54yZUqhgfHKH5M1/SDceuutEfHXf/g888wzMWbMmNLaBmjGmq7bzVzLV6xYETvssEN8+ctfXuO6qwfaeVu+fPkal7fy+5T18HBvkb8ne++9d/z+97+PX//61/H666/H5MmTO++2r6lMc56/J622Tf8h8KeLt73tbTFnzpyYM2dOPP/887HTTjvFF77whYbAvze+//3vx8yZMxsqRCxevDgWLVqUy/bPOOOMhpKXY8eOzWW7a7Pyjsi4ceMalv/4xz+Oq666Ks4444y47rrrYubMmXHffffFoEH5fe3W1jZA0bbccst45JFHYr/99utyB3hVEy
dOjBUrVsRTTz3VMMv/n//5n021M3bs2C6/D0uXLo1nn322rX63oujfk4EDB8aOO+7Y+eeVVfD233//hvWK+D1ptm36F6k+dFq+fHmXVJsNN9wwNt544y6l2Xpj4MCBXWZZLr300rXO3rRqu+22i/3337/zv5133jmX7S5YsKBLH996660477zzYvDgwfE3f/M3ncsXLVoUs2bNil133TXOPffcuOqqq+Khhx6Kc889t/C2Acowffr0+NOf/hTf+MY3uvzdm2++2Vl2eOWk0erpKRdffHFT7Wy55ZYN1YQiIq688srcfjO6U9TvyZq88MILcf7558e73vWuhuA779+TVtqm/zHjT6dXX301Nt1005g2bVpMnjw5Ro4cGbfddlvcf//9a6zf3K6DDjoorr322hg9enRst912ce+998Ztt90W66+/fm5ttOLll1/uzG28++67I+KvD1KNGTMmxowZEyeffHJE/PXh2nPOOSemTZsWb3/722PBggXxne98Jx599NE499xzY/z48Z3bPO200+Kll16K2267LQYOHBgf+MAHYtasWXHOOefEoYceGpMnTy6sbYAyHHPMMXH99dfH7Nmz44477og999wzli9fHo899lhcf/31ceutt8Yuu+wSO+64Y3zoQx+Kyy67LF5++eXYY4894vbbb48nn3yyqXZmzZoVs2fPjsMPPzze//73xyOPPBK33nprl2cQ6qDZa3rEX9OGdt9995g0aVL85S9/iSuvvDJee+21uOWWW2LAgP+al83796SVtumHqiwpRL0sWbIk++QnP5lNnjw5GzVqVDZixIhs8uTJ2WWXXdaw3pQpU7Ltt9++y+cnTpyYHXjggV2WR0RDidCFCxdmxx13XLbBBhtkI0eOzKZOnZo99thjXUqzlVXO8+mnn84iYo3/rdrWAw88kB188MHZJptskg0ePDgbOXJkttdee2XXX399w/ZuvvnmLCKyiy66qGH5K6+8kk2cODGbPHlytnTp0kLaBmjWynKeL7zwQsPymTNnZiNGjOiy/pqu/UuXLs3OP//8bPvtt8+GDBmSjR07Ntt5552zz33uc9nLL7/cud6bb76ZnXrqqdn666+fjRgxIjv44IOzP/7xj02V81y+fHn2qU99Kttggw2y4cOHZ1OnTs2efPLJtZbzvP/++3u1n73R7DU9y7Ls9NNPz7bYYotsyJAh2bhx47Kjjjoqe+qppxrWKeL3pNm26Z86sizn19YBAAC1434OAAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAKafnNvR0dHkf0AoEl1e/2K3weAeujp98GMPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkYFDVHQCAlbIsq7oL9EcdHflsx/lJATryOj+bYMYfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABJQm6o+KjlQR0U/ae+8X7six77ocS+zQgMANMuMPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIGVd0BAMhbR0dHbtvKsiy3beXZLyrg+PH/5XldKJMZfwAASIDAHwAAEiDwBwCABAj8AQAgAQJ/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIwqOoOrOQ15rSjr74yuwxFf6eMPQD0LWb8AQAgAQJ/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABJQmzr+AEDz+vu7NPJ8F4mxap6x6t/M+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJGBQ1R1YKcuyqrsA1IhrAgDky4w/AAAkQOAPAAAJqE2qDwBQvo6Ojty21d9T9FIYqzz3MS91Hau+yIw/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkYVHUHACAVWZZV3YU+w1g1r45j1dHRUXUX1qiOY1UmM/4AAJCAJGb86/qvTnov9X+5V6kvf6+cNwCkyIw/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkQOAPAAAJEPgDAEACBlXdAQCos46Ojqq70GcYq+bVdayyLKu6C4Wq67iXxYw/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQALU8c9Bf695S9/U18/L1GstA0DezPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQgEFVdwAA8pZlWdVdgNJ0dHRU3YU+o47XhjKPnxl/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABAyqugP9QUdHR9Vd6JUsywrbdl8fm6L15fEp8rwBAPJnxh8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABCjnCUBt9OUSt9AKJZGb57qQHzP+AACQAIE/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkQOAPAAAJEPgDAEACBlXdgTJkWVZ1F3qlo6Oj6i7UVl8/tkVy3qxdRxgbANJjxh8AABKQxIw/AH1ETjfxsrw2RGXyvGvp7nD5jHnzyrxDb8YfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABAj8AQAgAQJ/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIwqOoOAMBKWWS5bKejoyOX7dCaLMvn+OXN+dA8Y9Wcup7rPTHjDwAACRD4AwBAAgT+AACQgCRy/IvOV+ureV70TK7j2vXlsckrj3xtOqLvjg0A/ZcZfwAASIDAHwAAEiDwBwCABAj8AQAgAQJ/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEjCo6g4AQJ1lWVZ1FwrV0dFRdRcK19+PIc1L4Xzvjhl/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAFe4EWhvDSle8YHACiLGX8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAAS
IDAHwAAEiDwBwCABAj8AQAgAQJ/AABIgDf3AkA3Ojo6qu7CGnnzd/nqei70d871/JjxBwCABAj8AQAgAQJ/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEjCo6g6s1NHRUXUXkmXs+yfHFQBYlRl/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEhAbar6AEAdZVmW27ZU22KlPM+rOqrrud7fx70nZvwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABAj8AQAgAQJ/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAGDqu7ASlmWVd2FZBn7tevo6Ki6C7XlvFk75w1QpjyvOXW8ttfxmlrHcWqGGX8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABAj8AQAgAQJ/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEjCo6g4AAKwuy7Kqu0Av1PX4dXR0VN2FSpnxBwCABAj8AQAgAQJ/AABIgMAfAAASkMTDvak/yFGlose+rg8PpaAvf6+cNwCkyIw/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkYFDVHQAA+oeOjo6qu0Av1fEYZllWdRf6DTP+AACQAIE/AAAkQOAPAAAJEPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAKSeIGXFz/0YwW/aKQvnzt9ue8AQP7M+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkQOAPAAAJSOIFXgCkxQvsmmes+r46HsOOgl+w2a46jlWZzPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAgT+AACQAIE/AAAkQOAPAAAJEPgDAEACBP4AAJCAQVV3YKWOjo6qu0AfVPhZ47wEAPoJM/4AAJAAgT8AACRA4A8AAAmoTY4/AHjeq3nGilQ41/Njxh8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABAj8AQAgAQJ/AABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABAyqugMA9G1ZllXdBQCa0HTg78IOAAB9l1QfAABIgMAfAAASIPAHAIAECPwBACABAn8AAEiAwB8AABIg8AcAgAQI/AEAIAECfwAASIDAHwAAEiDwBwCABAj8AQAgAQJ/AABIgMAfAAASIPCni3nz5kVHR0fMnz+/6q4AkICOjo44++yzq+4G9HsCf5J27rnnxnvf+94YN25cDB06NLbaaquYO3duvPDCC13WffLJJ2PatGkxduzYGD58eOy1115xxx139Mm2AVrxox/9SGCek5/85Cex1157xfDhw2Ps2LExbdq0NU60vfbaazF37tzYdNNNY8iQIbHtttvG5Zdf3qu2v/vd78ZOO+0UQ4cOjXHjxsUJJ5wQL774Ypf1nnvuuTjuuONiww03jGHDhsVOO+0U3/ve93rVNvUwqOoOQJUefPDB2HHHHWPGjBkxatSo+N3vfhff+MY34l//9V/j4YcfjhEjRkRExB//+MfYfffdY+DAgfHJT34yRowYEddcc00ccMABcfvtt8fee+/dp9oGaMWPfvSj+PrXvy7476VbbrklDj300Nhpp53ivPPOi1deeSW++tWvxl577RW//OUvY9y4cRERsXz58pg6dWo88MADcdJJJ8VWW20Vt956a8yZMycWLlwY//iP/9hy25dffnnMmTMn9ttvv/jyl78czzzzTHz1q1+NBx54IO67774YOnRoRES88sorsddee8Vzzz0Xp512WowfPz6uv/76mD59elx33XVx1FFH5TomlCyD1VxzzTVZRGRPP/10Ltt7/fXXc9lOWb7//e9nEZH97//9vzuXzZkzJxs0aFD22GOPdS57/fXXswkTJmQ77bRTv2gbYG1OOumkrNmQ4a233sqWLFnS0vYjIjvrrLPa6FlXr732Wi7bKcJ2222XTZo0qWF8Hn744WzAgAHZP/zDP3Quu/7667OIyK6++uqGzx9++OHZ0KFDs+eee66ldpcsWZKNGTMm23vvvbMVK1Z0Lv/hD3+YRUR2ySWXdC770pe+lEVEdvvtt3cuW758efae97wnGz9+fMvHlnqR6kNTbr755jjwwANj4403jiFDhsSWW24Z//RP/xTLly9vWG+fffaJd77znfHggw/G3nvvHcOHD++cmXjppZfimGOOiXXXXTfGjBkTM2fOjEceeSQ6Ojpi3rx5Ddt57LHHYtq0abHeeuvF0KFDY5dddokf/OAHpezr5ptvHhERixYt6lx21113xbvf/e7YeuutO5cNHz48DjnkkHjooYfiiSeeiIiIn/70pzFgwIA488wzG7b5ne98Jzo6Onq8TdubtgFW9ac//SmOP/742GijjWLIkCGx/fbbx//6X/+r8+/ffPPN2GabbWKbbbaJN998s3P5ggUL4m1ve1vssccesXz58jj22GPj61//ekT8NRd/5X8REfPnz4+Ojo648MIL4+KLL44tt9wyhgwZEr/97W9j6dKlceaZZ8bOO+8co0ePjhEjRsT73ve+ptIU//CHP8ScOXNi6623jmHDhsX6668fRxxxRJeUmJXPpN15550xZ86c2HDDDWPTTTft/Puvf/3rscUWW8SwYcNi1113jbvuuiv22Wef2GeffRq2s2TJkjjrrLNi0qRJMWTIkJgwYUKcccYZsWTJklaHfa0WLFgQv/3tb+ODH/xgDB48uHP55MmTY9ttt43vfve7ncvuuuuuiIiYMWNGwzZmzJgRixcvjptvvjkiIp5//vkYN25c7LPPPpFlWed6Tz75ZIwYMSKOPPLIiIh49NFHY9GiRXHkkUd2HruIiIMOOihGjhzZpe1x48bFvvvu27lswIABMX369PjLX/4Sd955Zx7DQUWk+tCUefPmxciRI+Mf/uEfYuTIkfHTn/40zjzzzHjllVfiggsuaFj3pZdeir/927+NGTNmxNFHHx0bbbRRrFixIg4++OD4j//4j/jYxz4W22yzTdx8880xc+bMLm395je/iT333DM22WST+PSnPx0jRoyI66+/Pg477LC44YYb4oMf/GCu+5ZlWbz00kuxbNmyeOKJJ+LTn/50DBw4sOGHYcmSJTF27Ngunx0+fHhE/DVtZ6uttop999035syZE1/84hfjsMMOi5122imeffbZOOWUU2L//feP2bNnF9Y2wErPPfdcvPe9742Ojo44+eSTY9y4cfF//s//iRNOOCFeeeWVmDt3bgwbNiy++c1vxp577hn//b//9/jyl78cEREn
nXRSvPzyyzFv3rwYOHBgfPSjH40///nP8ZOf/CSuvfbaNbZ3zTXXxOLFi+MjH/lIDBkyJNZbb7145ZVX4qqrrooPfehDceKJJ8arr74aV199dUydOjX+4z/+I3bccce19v/++++Pe+65J2bMmBGbbrppzJ8/Py6//PLYZ5994re//W3n9W+lOXPmxLhx4+LMM8+M119/PSL+mtpy8sknx/ve9744/fTTY/78+XHYYYfF2LFjG/5xsGLFijjkkEPiF7/4RXzkIx+JbbfdNn7961/HV77ylXj88cfjpptu6t3B+P9W/iNi2LBhXf5u+PDh8Zvf/Cb+8pe/xPjx42PJkiUxcODAhn8grFwv4q/X/RNPPDE23HDDuPzyy+OII46ISy+9NE499dRYsWJFHHvssTFq1Ki47LLLemx72LBh8ctf/jJWrFgRAwYMiCVLlqy1jyvbfv/739+LkaBSFd9xoIbWlOrzxhtvdFnvox/9aDZ8+PBs8eLFncumTJmSRUR2xRVXNKx7ww03ZBGRXXzxxZ3Lli9fnu27775ZRGTXXHNN5/L99tsv22GHHRq2u2LFimyPPfbIttpqqxz2sNGzzz6bRUTnf5tuumn2z//8zw3rHHzwwdmYMWOyV155pWH57rvvnkVEduGFF3Yue/3117NJkyZl22+/fbZ48eLswAMPzNZdd93sD3/4Q+FtA2RZlp1wwgnZ2972tuzFF19sWD5jxoxs9OjRDdf0z3zmM9mAAQOyn//859n3vve9LtfqLFt7qs/TTz+dRUS27rrrZs8//3zD3y1btqxLWsjChQuzjTbaKDv++OMblsdqqT5r+s259957s4jIvvWtb3UuW/l7tddee2XLli3rXL5kyZJs/fXXz97znvdkb731VufyefPmZRGRTZkypXPZtddemw0YMCC76667Gtq74oorsojI7r777i59acfy5cuzMWPGZPvtt1/D8hdffDEbMWJEFhHZAw88kGVZll100UVZRHTp06c//eksIrKDDjqoYfmHPvShbPjw4dnjjz+eXXDBBVlEZDfddFPn37/wwgtZR0dHdsIJJzR87rHHHuv8/Vl5rpxyyinZgAEDsvnz5zesO2PGjCwispNPPrl3A0GlpPrQlFX/9f/qq6/Giy++GO973/vijTfeiMcee6xh3SFDhsRxxx3XsOzHP/5xrLPOOnHiiSd2LhswYECcdNJJDestWLAgfvrTn8b06dM723nxxRfjpZdeiqlTp8YTTzwRf/rTn3Ldt/XWWy9+8pOfxA9/+MP4/Oc/HxtssEG89tprDet87GMf67xN+stf/jIef/zxmDt3bjzwwAMREQ23yYcPHx7z5s2L3/3ud7H33nvHv/7rv8ZXvvKV2GyzzQpvGyDLsrjhhhvi4IMPjizLOq+jL774YkydOjVefvnleOihhzrXP/vss2P77bePmTNnxpw5c2LKlClx6qmnttTm4Ycf3vlg6kqrzlivWLEiFixYEMuWLYtddtmlof01WfU356233oqXXnopJk2aFGPGjFnjZ0888cQYOHBg558feOCBeOmll+LEE0+MQYP+K7nhv/23/9blDur3vve92HbbbWObbbZpGKuVqS55VVAbMGBAfPSjH43bb789PvOZz8QTTzwRDz74YEyfPj2WLl0aEf91PT/qqKNi9OjRcfzxx8dPfvKTmD9/flx55ZWdM/irX/e/9rWvxejRo2PatGnx2c9+No455pg49NBDO/9+gw02iOnTp8c3v/nNuOiii+L3v/993HXXXXHkkUfGOuus07DNWbNmxcCBA2P69Olxzz33xFNPPRVf/OIX41/+5V/W2DZ9TNX/8qB+1jTj/+ijj2aHHXZYtu666zbMUEdEduedd3auN2XKlGyLLbboss0DDjgg22yzzbosf+SRRxpm/O+7774u21/9v4ceemitfX/++eezZ599tvO/V199teX9v/vuu7OIyH74wx82LL/00ks7Z2UiIps0aVLnQ1Bf+cpXumxn5QzZ1KlTS28bSNdzzz3X43X0xhtvbPjM/fffn0VENnTo0Oz3v/99l232NOP/+c9/fo19mTdvXrbDDjtk66yzTkP7b3/72xvWizXM+H/2s5/NNt1006yjo6Phs8cdd1zneit/r37+8583bO873/lOFhHZT3/60y59eve7390w47/tttt2O1annnrqGvcty7Ls1VdfbfjNWf2ux+qWLFmSnXDCCdmAAQM6t3/AAQdks2fPziIi++Uvf9m57p133pltttlmneutu+662Te/+c0sIrJDDz20y7ZX3q3ZaKONsoULF3b5+0WLFmWHHHJIw74dffTR2d///d9nEdHwme9973vZ+uuv37ne+PHjs8svvzyLiOy0007rdh+pNzn+9GjRokUxZcqUWHfddePzn/98bLnlljF06NB46KGH4lOf+lSsWLGiYf015QY2a+W2PvGJT8TUqVPXuM6kSZPW+vn3vOc98Yc//KHzz2eddVbL5ef22GOPeNvb3hbXXXddHHTQQZ3LTz755DjuuOPiV7/6VQwePDh23HHHuPrqqyMi4h3veEfDNpYsWRI/+9nPIiLiqaeeijfeeKNLTmpRbQNpW3kdPfroo9f4HFVExLve9a6GP996660REbF48eJ44okn4u1vf3tLba7puv/tb387jj322DjssMPik5/8ZGy44YYxcODA+OIXvxhPPfVUt9s75ZRT4pprrom5c+fG7rvvHqNHj46Ojo6YMWNGl9+ctbXfrBUrVsQOO+zQ+YzD6iZMmLDWz1544YXxuc99rvPPEydO7Pbll4MHD46rrroqvvCFL8Tjjz8eG220UbzjHe+Io446KgYMGNDw+7b33nvH73//+/j1r38dr7/+ekyePDn+/Oc/R8Sar/srj+HChQvjmWeeiTFjxjT8/ejRo+Pmm2+O//t//2/Mnz8/Jk6cGBMnTow99tgjxo0b17D+tGnT4pBDDolHHnkkli9fHjvttFPnb5rfnL5N4E+Pfvazn8VLL70UN954Y0PN+KeffrrpbUycODHuuOOOLgHwk08+2bDeFltsERER66yzTuy///4t9/W6665ruA25cnutWrx4cbz88stdlo8YMSJ23333zj/fdtttMWzYsNhzzz0b1jvrrLPid7/7XVx44YXxqU99Kj796U/HJZdcUkrbQNrGjRsXo0aNiuXLlzd1Hf3Vr34Vn//85+O4446Lhx9+OGbNmhW//vWvY/To0Z3rrFoJplnf//73Y4sttogbb7yx4fNnnXVWU5+dOXNmXHTRRZ3LFi9e3FDxrDsTJ06MiL/+xvzN3/xN5/Jly5bF/PnzG/7hs+WWW8YjjzwS++23X8v7+eEPfzj22muvzj83+w+QjTbaKDbaaKOI+GvN/p/97Gex2267xciRIxvWGzhwYMND0LfddltERJfj+uMf/ziuuuqqOOOMM+K6666LmTNnxn333deQ5rTSZptt1pl6umjRonjwwQfj8MMP77Le4MGD4z3veU+PbdO3yPGnRyvzJrNVSoUtXbq0M9ewGVOnTo233norvvGNb3Q
uW7FiRWeJuJU23HDD2GeffeJ//s//Gc8++2yX7azprbar2nPPPWP//ffv/K+7wP/111+PN954o8vyG264IRYuXBi77LJLt23dc889ceONN8YJJ5zQ8AN53333xYUXXhhz586Nj3/84/HJT34yvva1rzWUQCuqbYCBAwfG4YcfHjfccEM8+uijXf5+1evoW2+9Fccee2xsvPHG8dWvfjXmzZsXzz33XJx++ukNn1n5QsFmA++V/Yho/O2477774t57723qs6t+LiLi0ksv7VJCem122WWXWH/99eMb3/hGLFu2rHP5ddddFwsXLmxYd/r06fGnP/2p4fdppTfffLOzStCabLHFFg2/Oe1MxFx44YXx7LPPxsc//vFu13vhhRfi/PPPj3e9610NwfeiRYti1qxZseuuu8a5554bV111VTz00ENx7rnn9tj2Zz7zmVi2bFmX4726J554Iq644oo46KCDzPj3cWb86dEee+wRY8eOjZkzZ8app54aHR0dce2113a5KHfnsMMOi1133TU+/vGPx5NPPhnbbLNN/OAHP4gFCxZERONs0te//vXYa6+9YocddogTTzwxtthii3juuefi3nvvjWeeeSYeeeSRXPbriSeeiP333z+OPPLI2GabbWLAgAHxwAMPxLe//e3YfPPN47TTTutc9w9/+ENMnz49DjnkkBg/fnz85je/iSuuuCLe9a53NVxcFy9eHDNnzoytttoqvvCFL0RExOc+97n44Q9/GMcdd1z8+te/jhEjRhTSNsBK5513Xtxxxx2x2267xYknnhjbbbddLFiwIB566KG47bbbOq+955xzTjz88MNx++23x6hRo+Jd73pXnHnmmfE//sf/iGnTpsXf/d3fRUTEzjvvHBERp556akydOjUGDhzYpcb86g466KC48cYb44Mf/GAceOCB8fTTT8cVV1wR2223XZciBmv67LXXXhujR4+O7bbbLu6999647bbbYv31129q/wcPHhxnn312nHLKKbHvvvvG9OnTY/78+TFv3rzYcsstG35zjjnmmLj++utj9uzZcccdd8See+4Zy5cvj8ceeyyuv/76uPXWW3ucjGnWt7/97bjhhhti7733jpEjR8Ztt90W119/fcyaNavLrPuUKVNi9913j0mTJsVf/vKXuPLKK+O1116LW265JQYM+K9529NOOy1eeumluO2222LgwIHxgQ98IGbNmhXnnHNOHHrooTF58uSI+Os58eijj8Zuu+0WgwYNiptuuin+7d/+Lc4555yGmf2IiO222y6OOOKI2GyzzeLpp5+Oyy+/PNZbb7244oorchkHKlTtIwbU0Zoe7r377ruz9773vdmwYcOyjTfeODvjjDOyW2+9NYuI7I477uhcb8qUKdn222+/xu2+8MIL2VFHHZWNGjUqGz16dHbsscd2Psz63e9+t2Hdp556Kvvwhz+cjR8/PltnnXWyTTbZJDvooIOy73//+7nt5wsvvJB95CMfybbZZptsxIgR2eDBg7Otttoqmzt3bvbCCy80rLtgwYLs0EMPzcaPH58NHjw4e/vb35596lOf6lJi8/TTT88GDhyY3XfffQ3LH3jggWzQoEHZxz72scLaBljVc889l5100knZhAkTsnXWWScbP358tt9++2VXXnlllmVZ9uCDD2aDBg3KTjnllIbPLVu2LHvPe96Tbbzxxp0PfC5btiw75ZRTsnHjxnU+bJtl//Vw7wUXXNCl/RUrVmTnnntuNnHixGzIkCHZu9/97uyWW27JZs6cmU2cOLFh3Vjt4d6FCxdmxx13XLbBBhtkI0eOzKZOnZo99thj2cSJE7OZM2d2rrfy9+r+++9f4xhccsklne3vuuuu2d13353tvPPO2Qc+8IGG9ZYuXZqdf/752fbbb58NGTIkGzt2bLbzzjtnn/vc57KXX365meFuyn333Zftvffe2dixY7OhQ4dmkydPzq644oqGt+mudPrpp2dbbLFFNmTIkGzcuHHZUUcdlT311FMN69x8881ZRGQXXXRRw/JXXnklmzhxYjZ58uRs6dKlWZZl2S233JLtuuuu2ahRo7Lhw4dn733ve7Prr79+jf2cMWNGNmHChGzw4MHZxhtvnM2ePbvltwVTTx1Z1sK0LeTspptuig9+8IPxi1/8Qq46AIVasWJFjBs3Lv7+7/9+jak90N/J8ac0q9f+Xb58eVx66aWx7rrrxk477VRRrwDojxYvXtwlJfVb3/pWLFiwoOHt6JASOf6U5pRTTok333wzdt9991iyZEnceOONcc8998S5557bq1JsALC6f//3f4/TTz89jjjiiFh//fXjoYceiquvvjre+c53xhFHHFF196ASAn9Ks++++8ZFF10Ut9xySyxevDgmTZoUl156aZx88slVdw2AfmbzzTePCRMmxCWXXBILFiyI9dZbLz784Q/Heeed1/lGYUiNHH8AAEiAHH8AAEiAwB8AABIg8AcAgAQ0/XDvqm+5A6A6dXs0y+8DQD309Ptgxh8AABIg8AcAgASo4w9AruqWilS11VOhyhifVdssu70826xi7KrQX45XFYoYu740Pq2mWprxBwCABAj8AQAgAVJ9AChMK7ehV7+dnle1oFW3W0YFolbSAorYx7q3mddxrns1qe7Ou+7Grrv96s3Ytdtmu8r4Pvckr/Fpd5u90e750xMz/gAAkIBcZ/zr/PADvdPMv3Adf3riPOpZ3WcxAei7zPgDAEAC5PgDUJp27+gUdSeoqPJ/7bZRRp5xu/tc1N2ovPLdu/tc3e8k5pV/3+x+9rReuyUy2+1rGcerlWcOivjO9qSsu71m/AEAIAECfwAASIBUHwAqkVeJw54UkT5TRmnEvManirSg7uSVYtGb/ar6Ifq8xqCV7XYnr9KsRSjrWBX1vV2bVtKt8mTGHwAAEiDwBwCABAj8AQAgAaXn+Of5Ap+qc/T6izLLnDlm5CGv64jzse8oK/e7uzaLystuR1HtFZUjXkTJzt6sm8fnWtHT8xpl5Oa30l6dz+e8+lZGWdu89ivP3yoz/gAAkACBPwAAJEDgDwAACVDHH4BaaCWPtexa4t21X4c288oZb0URbZRdS73IPpT9DoSelJVD3qyqj3V3z11U3beI4q5rZvwBACABAn8AAEiAVB8AaqGKlJ1W1Kl/dSjHmFcbZaRtlVHqs6hyr3U7lmWXHq1qu3m0UYfzZ3Vm/AEAIAH9fsa/TjM0VajzQ0p1aCtPeb6cjnLU8Xh4qRgARTHjDwAACej3M/4A9E1Vl9Sr292XupVjbLf9MvLJi9pudyUgW2mjbseyjPzyojIQyr5O5FUqtqdxLeo8MOMPAAAJEPgDAEACpPoAUIlWUghaub1fRPpFFaroayupLN3Jq3/tbqeKcqd5pe/0JhWqbudwd5rta1Fj126bZaXoFDE+EWb8AQAgCQJ/AABIgMAfAAASIMc/qi+D1o488/j64v6X3ecy8yb74vGoozqOY1/Kvy1KEcelN/ncdTtP2s3r7e4ZiLzKH7byuaJKWa762aJKa5axn61opQTk2j7XkyJKZP
ZmHPP6zub1fEDVzwbl+SyHGX8AAEiAwB8AABIg8AcAgATI8QegNGXXTy+j/Yhy8rCL0JvxyGs/i3gGI69a62U9l9Nf3iuQx+eK2m5vnmMo4zwo61wz4w8AAAkQ+AMAQAKk+gBQiaLSXMpIn+lNmku7t/TL2K+8Sn1WURqxjHSivI5BFed+1WPQl7/v3bVX5+/zmpjxBwCABAj8AQAgAVJ9mtRX39zaX98UWsc3F5c51v31uPZlVVdnAYCeCPwBqIWqy//llbublyLKXPamzaK2W0TJzp62WXWpy1aUkbdfRcnOZj/bm3K9RT1Pkoeqyv5K9QEAgAQI/AEAIAFSfQCopTJKJxZxu703/S47NaFuZSV7s912P5fXuJbxvFhvUphaUUbJzqqfi6q6NGtZx3J1ZvwBACABAn8AAEiAwB8AABIgxx+AwuSVp9pdmb4qyvIVtV+rKiMvvayxK6NkZ9XlGVeX1/MjfekcKeM5lKrL/vbmc3U4R3MN/Kt+UKNqZe5/6mPdjDq+5CsvdetPnpo5bs3sf17bAYD+QqoPAAAkQOAPAAAJkOMPQGGqSKfqLie4qHzhMvazv7TRm/abPV5V78fq8qpvX8R7J1pVxthWvZ+9uU7U7dxbnRl/AABIgMAfAAASINUHgNL0tTKU7bRZh5J9eSl7X8o6P9otB9vdunUv61hEm71Jgam6P2W0UcdrgRl/AABIgMAfAAASkGuqT27VEaLeT0SvTR1v6eT1oqO6yfOp+brtf389ZhH1r3awJn11rAFgdXL8AchVs//Aq6LEYSvbrXtZxbLVoTRru5/r7njV4fjk1Yeyv3vdrVuH70jZ41rWdnpzzZPqAwAACRD4AwBAAqT6AJCrdm9Dl5EaUIdnNvLoQxmlEXuS1xuRWzmWZZT7bDddJS/9tc2qy2D2Zpt5nc/t9iHPVCgz/gAAkACBPwAAJEDgDwAACZDjD0BhepOb2myudU9tFFEqMLf31hSU397d2FWRw55Xm3mV91xVFWUn61Dqstk2uzt/mt1GXn2p43bavTb1NK5FfTdLD/zrUC+3HXXrd9kP/5S5/1X8oJalzONW9v735+PWjP784jUA+gepPgAAkACBPwAAJECOPwClKbsOe55tlpGGJh2se1UfyyrOu7xqv1fdn+705hmMIurm90Zvnkcqgxl/AABIgMAfAAASINUHgEq0ctu7inKMVd+yr0NaQJ0VVQq1lTaLOEd6c9zL7k+7pT7X9Nlm2y/qe7nqdqu4NrWiN/tpxh8AABIg8AcAgASUnupTx4oFZVYJqOP+56WZfeuLLwIrW5n9bvdWa9Ht9UQKBAC0To4/AJXozT8Ei/jHaE/brHoyoer2q9DKPld9LKs4PnXOLy9rPJrN+e9Nf4p6jiiP9lttT6oPAAAkQOAPAAAJkOoDQCXq/qxG3fpXRH9S2Me+tt3elMisc5pJGeuWkZKT1xuRezMevUlbMuMPAAAJEPgDAEACBP4AAJAAOf4A1EIZ5f/aLQ9ZhbqNR1HbqVtpzVZyr8t4N0nV5+Hqyj5nevPcQt2+73mW5WxXroF/3R4SylOZX7zUXwSWl/58PtZR3V7yVUfOSQCqJNUHAAASINUHgFrK6w5JXmkU7b6xNK8yhq28mbZdeZU4LKKN3ujumJSRrtNTf9rdTtV6c65397kySmRWvZ2q7m6b8QcAgAQI/AEAIAECfwAASIAcfwAKU7cyj3WoPlV1+cp2846rKHPZrqLy5NstD9lTf5od994c5+76U8azLmUoYj9W325vvgd1eH7DjD8AACRA4A8AAAnINdWn6lt77arDrZdU1OEtkUXpzy9e66v9LlMVJQEBoBVy/AEoTBn13XvKoy2jJngZ9e/b3W5v8vaLGLu65T13134V70vIK4+/N/I6f9pdNy9FvK+hr5PqAwAACRD4AwBAAqT6AFCYVtIEuksBKeM2fU9ttFvGMK+ygXV+Rqpufatb6dG8zq0826x6u+2WyOxuO/TMjD8AACRA4A8AAAkQ+AMAQALk+AOQq7xKFbbbRl7bqXPucN3KKFZR9nJ1ZZSy7O65i6rLkhbVhzrs16ra/c7WbT+qYsYfAAASUPqMf9n/4qrzjE1/05//NV23Wco6amaM+uL+59ln1yMAqiTVB4BcNZv+UIcSkP3lH2Ptvu21N2+Jzas/rehLxyuvsW1lO3Uu1VpU++2W2U2VVB8AAEiAwB8AABIg8AcAgATI8QegNHUu9dkXHz5fm7zy+Ks+JmVvs6g+tDKueX1Hqih525fK7Nbh/OlOUf0z4w8AAAkQ+AMAQAIE/gAAkIDSc/zzrKNat/ys/lwjtsycurLHMa/zqD8f/5THqC/2uU6qGL9W8qc7or1c+Kr1pb6uroz69j1ds+rwDokittPsNot6lqOIbRYV6/Xm/CmizbL6Y8YfAAASIPAHAIAEKOcJQGHKSMnsVVpHx6r/W6/00dVVkd5a9vHrKT2kiNSfKlJSihrXMsYnr1Kx3Wn3mHTXRlnpX6202Z2i0vnM+AMAQAIE/gAAkACBPwAAJECOPwClWTVvNa+Sgq18ruoSh62oQ1/70hg0+7mePttuf+pwvLprs+z2W2mzipLRRY1PEc8n5MmMPwAAJCDXGf/+/JKfur0sLM/+9MXjVvbxqNvxz4sX6uWnji+nA4BVSfUBoJaa/cdkUSXzikgH6Ule6QdF/EM8r232puxl2SU7y5rQaLad7kpr5tVGT202u80qymd214cqSsV2p6qJIKk+AACQAIE/AAAkQOAPAAAJkOMPQGHqUBavuxKiebTX03a7U0UZziIKOtQhnzuvNoooy9nKmPcmj7/d/pRRSrcvFYDI63iVcU62yow/AAAkQOAPAAAJEPgDAEACcs3x76svp8mr317gUx7j2LM88yl9R2hFX6qH3V/aLOrZhXZVnc+dV/3/otosQ2/60+6YFJHfnpc6PIdSxnsXemLGHwAAEiDwBwCABCjnCUCuWinzuLbPFaVuKTrtttfTNotIE+iuzbzGtYr0kDqkpJRRTrNdVZekLKokcBXpaEUcr1b3w4w/AAAkQOAPAAAJEPgDAEAC5PgDUEtF5OBW8cxBd2UV8yibWKRmnzMoqj9V5GGX0WbZJW/z1GwfelM+s4jvRX8qa9ubfTHjDwAACSh9xr/sf63m1V6Z/1LM8yVHdZgdaFUd+9zMeNex3z2p40u+mtEXxzqi7/YbgP5Bqg8AhSmjFF8rbfQmvaCocpZ5bKeM9utQmrXq8pV5vZ139c8VkaJSVNpLEf3Jq43etNmsnvpWxHUrT1J9AAAgAQJ/AABIgMAfAAASIMcfgMK0kmNbRYnDKkon5lUOsYjPVr3Pvdlmu8cyr1z4vPL/63Y+t/u5nva/iO9BUXnzRY1Bu20o5wkAAHRL4A8AAAkQ+AMAQAJKz/Gvqm5pd/riS77yVLcXL/XVcWyGsc5HHWs817FPddCbcalbbfOya3DXoaZ+EftZVN5z3dpst/3uavzX4VhW3Z+qv3s9KeJ5gDzPSTP+AACQAIE/AAAkQDlPAHJVt1vd7W636vSQolRd6
rOIFK6ettvdulWUgKwiNbCV/uRxjpT1/cnrWDY7PmWVbS3qemjGHwAAEiDwBwCABAj8AQAgAXL8AchVszmvRZWkLKvUZbPb6S4nOK984Xb1pszkqnqTr9yuupUjrmJ82j3v8mqjKN19R8pos4o22n0mpNXxMeMPAAAJKH3Gv+zKB3V7YVJfbKvZ9vrrWJetbmMd0TfP7Tq+UKtuM5QApEWqDwCFKSqVpewSfnl9rhVVTHCUUbKzO628PbkOyj5/8koZyqt8Zn9V1D4WVX61FVJ9AAAgAQJ/AABIgMAfAAASIMcfgFoquwRkb8qAFpETXMWzC3n2oVlVlF+tuohAK+dPGc+aVF3qs26FD/Isn9lOG0Uy4w8AAAkQ+AMAQAIE/gAAkIBa5viXUYe2KnXLY+ur6jiOffF8zFMdj0lP8uxz6se/Wd2NU1G1zcuouZ+XIurfV1Gbv4xnIKp4NqA7/an2e7Nt9qa9Ir6XRb2foO7fr1aY8QcAgAQI/AEAIAG1TPUBoH8oqiRlXm20oox2yi5huroy0h2KGseqS0v2ZnyabbPu53oV50ER26lDemCz/WmVGX8AAEiAwB8AABIg8AcAgATI8QegMFXk/NahrGKd21t1u1WUXGyljXbXLWKbrX62iPKZeX0ur+3k9dxHWcerXf2pXLMZfwAASIDAHwAAEiDVp4b60y2lopQ9RnV7u2cdOW9ZkzJKHPamD3XaZhVtlNWHIspVFpWG1EoqVLP9aUVv+l71G4rrVvK13WNZxflTFjP+AACQAIE/AAAkQOAPAAAJkOMPQGHqUKavjOdPqi6rWIYqjuXq22k3h3z1z1Wdx1/GWHY3dmXsR1558j1tt9ljWdR3q4rzpzfM+AMAQAIE/gAAkACBPwAAJECOPwClqaKuddltllUvvF11qDffV9vP6xmDVj7bynbaVUTd/Craz6uNVp4taVdV14nSA/+qLxT9hXHsWbNj1F9+BKuW+v4DQN1J9QEAgARI9QGgElXcbauinF4R+1mHsoDtllwsqs0i9NReUak/zcqr3Gm76xZVsrMM7abz9OWywxFm/AEAIAkCfwAASIDAHwAAEiDHH4BKtFLiMK+c6FZyoIt6dqDq0oB55Z63ux91qABWRg57K89klH1OVPHsQh2+e82qotRmu+PT6v6b8QcAgAQI/AEAIAG5pvpUXZqpSH113/pqv+vGOOYj9XGsQ4pDGbo7znmlmeRVmrDdNssum9hTm91tt4y0iaL2uRVllDvN642uRY1zu+d6EZ/rSdXlM4tqp9ntVJH+FWHGHwAAkiDwBwCABAj8AQAgAcp5ApCrMnJ3i8itLqMsX5GfzUMdxqAIZfSnihKQ7WqllG6rny16m0WNXV7XrbKf+2iVGX8AAEiAwB8AABIg8AcAgATI8QegMFXXPa9qu+22Ufa7Lqqokb76OVDF+z3KqNleRP37osaujL52pzf7UcT508p1q4zxyVOugX/ZO1HGCZ+CMve/ji9wqtvxr+MY5aUvjnXd+gwA7ZLqAwAACZDqA0AtFXG3Ja/Uo1a0kibQV/tTtzvwrfQnr1KSRbWZVzpREftZdWnNvD7XG2WlNOXFjD8AACRA4A8AAAkQ+AMAQALk+ANQS2XnjZdVMrTs/rSbI97Tdooocdib7VRRKraIZxKKOibdfS4vrexHs32oW6nPVtqsY1U4M/4AAJAAgT8AACSg9FSfsl+YU4fSSf1Bf37RkXMkH3UbxyrKApKvqt/yW7drWhnndFFlFPNqc9V1i9hmKlo5Xv2lZGde+9Gb7dThXDPjDwAACRD4AwBAAgT+AACQAOU8AaiFuj1PUbf+rCqvMpdFlUrsTUnKdtVtO3VT5/HJ6zws6rxrt391PJfM+AMAQAIE/gAAkACBPwAAJECOPwCVKKqmdX+p49+bWvRF1OrvTh3GvIj9yitnvIpzqZVnO/Kqjd/ueVDUcyjttlmUdp9HyLOvfTrwr/qiTOv66jErs991fBiobsetDkFGUep4/AHoH6T6AABAAvr0jD8AfUsRdzTKunNTxt2YupVK7E4Z495uSkpR+9xKf6o4lnUrsdpuG3VOR2tFFSlMPTHjDwAACRD4AwBAAgT+AACQADn+AJQmr9zdqqsf1bEi1Kra7V9eZR5bkVcpybzWraJqWN1KgXbXfrt9bTffvW4lXnvS7PlT1bXQjD8AACRA4A8AAAmoZapP1bdwU1PmePfnY9vMvlV9y3Z1dTwe/SUVZE3qdvyr1lNaSR2P4dq00tdV97O/7mO72+npO1JEH6ooS5rXuq2o2372pe10p4rfrN60acYfAAASIPAHAIAECPwBACABtczxB6B/yCv/tW7baWW7ZeXulq3dZxXqXp6xiDb60nHtSbP7UtY+1718Zrt9KGr8zPgDAEACBP4AAJAAgT8AACRAjj8AuSqidnbd6nr3phZ9Hm30J+3uZ93OiarbKEvVx6tudfPLeEdDnudPLQP//vQQTH9Rty9sf1b3hxh7w3kEANWR6gMAAAmo5Yw/AH1Xs3dkiioBWdR2q77TVHX7RaU3FVEKdfXPVZ1+VdSxK2o/m7V6e0WcIz2Nebvbbfe866m9dj9bxOfWxIw/AAAkQOAPAAAJEPgDAEAC5PgDUJiiyhiuut3e5DXnlbNddUnI7vrTm74W8bxGGaVQW+lDK2PXbht5KepZlyK0kptfh5KYdci/z6PNnpjxBwCABAj8AQAgAVJ9AChMFSUoiyrRWfbt/qJSCJptI8/t5JU+06y6nXe9KbtZxvlTRFnOokq8rq7dcyuvUrHdrZuXfv/m3mZVnb/WX1Rd37goVde8XpO6jVEd9ecxquM5CUA6pPoAAEACBP4AAJCAPp3qA0C9VV3msjfbKSPHvpXtVpE7XLcSh92pIk0wr+NVRgnRqtN68/pcFePabBt5Kqq/ZvwBACABAn8AAEiAVB8AClP3tIVWFNFmFSkxebVZdYnDurVRxfEqqj9ll+wso69lfa6I8qJ5nttm/AEAIAECfwAASEC/T/Vp5vZIM7eY+uqLd+r2dsf+PNZ9td99UannUbPfIccfgJrr94E/AOVqd8KhivzpVpTRh7JLdva1Eofttr/6fjbbv7qVCK17m3Uu31tF2d/uPtvKswKr6833VqoPAAAkQOAPAAAJEPgDAEAC5PgDkKtm80+rqP3em/zcPD63ur763EAVbRSlDn1v9vmWnp5VqPp8LmKbVRyfvvRegVaZ8QcAgAQI/AEAIAFSfQAoTCupCEWlveSVepTXZ4sYg1Y+V0bpxnbLZ9ZB1SUp65BGlscY9CaVr4rzpd3+VFFetDfniMA/R33pwtaqMl+YVMdxLPNFaHXc/zKVmc+Z61gnftwAqD+pPgAAkACBPwAAJECqDwCVKCqtq90Urp760+x2e5PfXnYpwN4cg7qXLWxXK7neZaS4FjHOVZTSbbe9vjTmq2+3jinQZvwBACABAn8AAEiAVB8ASlPnEpB1q6hVRTnGvlROdHWrnj959aeVVI2qS7r2tJ122yzqe9nu8Vpduyl47W6n1e2220ZR6VZm/AEAIAEC
fwAASIBUnxzVsVJB3W5d13GM8lLmWJc9jmW+wK1Mdexz3b6zAPQfAn8AClNFScEy37Rd5HaKGruiyhgWsZ3ulFGOsTfbbTeHvaic8bzabPdzdStrW9T1pbvtlvFsR0+k+gAAQAIE/gAAkACBPwAAJECOPwCFKaO+e1kPaVdd477dfOo65BWXkYtexL7U4WH7Zo9lUW22+7m8+lq3d0RU1Ye8mPEHAIAECPwBACABUn0AKE0Vt/+L0Jv9qNP7I6roS1FtVlFCtC+VP22lvTqkh63UU1+rTi9aXbNjV9V1QODfpKp/ZFaX5wmT177VbYzyVKcf6mbV8aVbZZ4jdTwf++J5BED/IdUHAAASIPAHAIAESPUBoBKtpGMVlSZVdapj3VLSquhPEedBGeVDe2qzTnnyVbXZne76U0RZ2560e42p+hrSKjP+AACQAIE/AAAkQKoPALXQyq32dm/Ll5F+UcZ+FFVusOqUqtXbL+ptzmWXVexNG0WMQVFlSVuRV9pWu+dz3c+too6JGX8AAEiAwB8AABJQeqpPHV9gU+ab3pppq25vxevv6lbpoD9L/VwDgCrJ8QcgV3UrbZlXacA82lu9zb5UbrDuqigB2e52unuWoYrnAYraTrvt9aVSqGWcW3nus1QfAABIgMAfAAASIPAHAIAEyPEHIFdF1NjPS1H5/nnlJLeb/9/sNnujitrvZddob6WNVtYtKi+9imPQl96B0J2+/L3sTX/M+AMAQAIE/gAAkACpPgAUpje3xcu4pV5Eeb0q3lfRyn4UUUK0r5aRbFXV/au6/Z7U+Tyo23tkWin7m6dcA/+6n5Brk9eLt/qzvnhs+/Mxq+PxqGOfAID/ItUHAAASIPAHAIAEyPEHoDBFlLnsSRX59+2muhVV+rTZ/hQ1Hv21hGhv1m1XK9+hOqW4FlUKte7Xgu60ez3Ms69m/AEAIAECfwAASIBUHwAKU1QqQitvD61bacA8+lPGuOapiNSjMkpH9qY/7b5dtYqSmN29nTcvZaT2VfH25O705twqKvXHjD8AACRA4A8AAAnINdWnmVsP/fllWXXsd5l9KrOtPG/Z1e241fGWfd3GqD9fR7wIDYCiyPEHoBJ1LyW5urxybptdtzf/CGy6Pz200ex2elqvjDKGdZw0WZuqS0n2tI9l5Mq3+wxEu22UJa82ixofqT4AAJAAgT8AACRA4A8AAAmQ4w9AadrNVa2iDnoZec6t1OPvrs22+7N6e6ttp+yHzetQwz4vVZyz7SqqZnzVyhjXqo9dq8z4AwBAAgT+AACQAKk+ABSmDmkCRfShL5cJrFtax6r9aaX9OqTzFNFGHfpaxDnSl78zq2olPa+OzPgDAEAC+vSMf5lv76zbwxt160+zUj5mEeU8dNiKst/u2xePf54vW+prM0MA9C9m/AEAIAF9esYfgPqp4922lYrqW53LXpZRlrQVvcnjr9sdwbqVi1x1fKo+R+rwXaviLnt3z6w0+3dFMuMPAAAJEPgDAEACpPoAkKvu0g3KeMC5qHKVzd6KL6okZSttlF2ys4rjXFT5zL5UJrTqNKC89OZ8LeJcr0O6YlHjbsYfAAASIPAHAIAECPwBACABpef41/EFNnXsE91zzHrWn8eoP+9bSnqTZ1xGSb++VDqyipKdZTzLUNR28lJEf6oq89isql+OWNS5XvbvShVlSSPM+AMAQBIE/gAAkACBPwAAJEAdfwAqUVR97nZryveUR9tuDnARddjzar8sRdXcL0IrbZT9HEjdj+Wq+1HWswrt9qeI9nrz2bKe4THjDwAACRD4AwBAAqT6AFCaIm5n96asZBllMNtNPWplm2V9NoU22t1uXmlbVZV5bKeNolKPumuzqGtBXus2+7miUh17YsYfAAASkOuMf91eMhFRzz7RPccsbf35+PfnfQOg/sz4AwBAAuT4A1CYupW+a7e9onTXZlFlJetQErIIdSjX2O52ym6zp/WKOEfyOtfzUkWbVTwvsToz/gAAkACBPwAAJKAja/LegofSAOqhbqkafh8A6qGn3wcz/gAAkACBPwAAJEDgDwAACWg6xx8AAOi7zPgDAEACBP4AAJAAgT8AACRA4A8AAAkQ+AMAQAIE/gAAkACBPwAAJEDgDwAACRD4AwBAAv4fdIkmvpLUi8EAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sizes = [\"small\", \"medium\", \"large\", \"extralarge\"]\n", "datasets = {}\n", "\n", "for size in sizes:\n", " repo_name = f\"SakanaAI/mazes-{size}\"\n", " datasets[size] = load_dataset(repo_name)\n", "\n", "fig, axes = plt.subplots(2, 2, figsize=(8, 8))\n", "axes = axes.flatten()\n", "\n", "for idx, size in enumerate(sizes):\n", " \n", " # Get first example from train set\n", " example = datasets[size]['train'][0]\n", " image = example['image'] \n", " axes[idx].imshow(image)\n", " \n", " axes[idx].set_title(f\"{size} - {image.size[0]}x{image.size[1]}\")\n", " axes[idx].axis('off')\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"maze_sizes_comparison.png\", dpi=50, bbox_inches='tight')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "7ac76d3f", "metadata": {}, "source": [ "Let's examine a maze dataset in more detail. First let's display a small maze." ] }, { "cell_type": "code", "execution_count": 31, "id": "164e4fc3", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZQAAAGVCAYAAADZmQcFAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAABv5JREFUeJzt3MFO7EYURVFu5P//5ZtphF4UUDYu2r3WGMkHdzVbNWF2dz8A4H/66/QAAJ5BUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQOL66g/OzE/uAOAX+8o/VXFDASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASBxnR7w2e6engCJmbnlOb4z3/ekz+au3+Ur3FAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJA4jo94LOZOT2BN7C7pye8nDu+mz6X1+aGAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkDiOj3gs909PQF4OH9nfoYbCgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgCJ6/SAE2bm9AQO293TE/iDp3033+2cuaEAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAxHV6wJPt7ukJkHnSeZ6Z0xMeyQ0FgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgMR1esCTzczpCZndveU5T3pnd3nSO7vrnPEz3FAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASFynB5ywu6cnpGbm9ISX87QzcAfn7Pvm473emRsKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASFynB5wwM7c8Z3dveQ6/013n7Eme9s724+f/BszH73lnbigAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJK7TA+Cfdvf0hJfjnfFbuKEAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkLhOD/hsZk5P4A98LjgD/Bc3FAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJC4Tg/4bHdPT+APfC7fNzOnJ7wc5+z7ftM5c0MBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBIHGdHnDCzJyewL+467PZ3Vuew/c87bv5bufMDQWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAxHV6wAm7e3oCp838+COeds6e9vvQc0MBICEoACQEBYCEoACQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoAiev0gM9m5vQE3sAtp8xZ5s24oQCQEBQAEoICQEJQAEgICgAJQQEgISgAJAQFgISgAJAQFAASggJAQlAASAgKAAlBASAhKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoA
CQEBYCEoACQuL76g7v7kzsAeHFuKAAkBAWAhKAAkBAUABKCAkBCUABICAoACUEBICEoACT+Br2pXUePWJuoAAAAAElFTkSuQmCC", "text/plain": [ "
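{ "cell_type": "markdown", "id": "b7d41c2e", "metadata": {}, "source": [ "As a quick sanity check, we can inspect how these datasets are structured (a minimal sketch; the split names and example counts are whatever `load_dataset` reports). Each example carries an `image` and a `solution_path`, which we use throughout the rest of this notebook." ] }, { "cell_type": "code", "execution_count": null, "id": "c8e52d3f", "metadata": {}, "outputs": [], "source": [ "# Inspect the splits and per-example features of the small maze dataset\n", "print(datasets[\"small\"])\n", "\n", "# Each example has an 'image' (a PIL image) and a 'solution_path' (a list of integer action ids)\n", "print(datasets[\"small\"][\"train\"].features)" ] },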
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure(figsize=(5, 5))\n", "plt.imshow(datasets[\"small\"][\"train\"][0][\"image\"])\n", "plt.axis('off')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "42e5eabc", "metadata": {}, "source": [ "Next let's display the solution path for this maze. We can see that the solution path is the actions required to traverse from the red pixel to the green pixel, padded by wait actions (W) to make all solution paths the same length." ] }, { "cell_type": "code", "execution_count": 54, "id": "b008d5eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[3, 3, 0, 0, 0, 0, 3, 3, 3, 3, 0, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]\n", "['→', '→', '↑', '↑', '↑', '↑', '→', '→', '→', '→', '↑', '↑', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W', 'W']\n" ] } ], "source": [ "DIRECTION_MAP = {\n", " 0: \"↑\",\n", " 1: \"↓\", \n", " 2: \"←\",\n", " 3: \"→\",\n", " 4: \"W\"\n", "}\n", "print(datasets[\"small\"][\"train\"][0][\"solution_path\"])\n", "print([DIRECTION_MAP[step] for step in datasets[\"small\"][\"train\"][0][\"solution_path\"]])" ] }, { "cell_type": "markdown", "id": "8abe19f1", "metadata": {}, "source": [ "Now lets test the pretrained large maze solving CTM!" ] }, { "cell_type": "code", "execution_count": 55, "id": "feba19b7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using neuron select type: first-last\n", "Synch representation size action: 528\n", "Synch representation size out: 2080\n" ] } ], "source": [ "# Load the model\n", "model = CTM.from_pretrained(f\"SakanaAI/ctm-maze-large\")\n", "model = model.to(device)\n", "model.eval();" ] }, { "cell_type": "code", "execution_count": 56, "id": "6bc38435", "metadata": {}, "outputs": [], "source": [ "def collate_fn(batch):\n", " \"\"\"Custom collate function for DataLoader\"\"\"\n", " images = []\n", " solutions = []\n", " \n", " for item in batch:\n", " # Convert image to tensor\n", " image_array = np.array(item['image']).astype(np.float32) / 255.0\n", " image_tensor = torch.from_numpy(image_array).permute(2, 0, 1) \n", " images.append(image_tensor)\n", " solutions.append(torch.tensor(item['solution_path'], dtype=torch.long))\n", " \n", " images = torch.stack(images)\n", " solutions = torch.stack(solutions)\n", " \n", " return images, solutions\n", "\n", "test_loader = DataLoader(\n", " datasets[\"large\"]['test'],\n", " batch_size=64,\n", " shuffle=True,\n", " num_workers=0,\n", " collate_fn=collate_fn\n", ")" ] }, { "cell_type": "code", "execution_count": 57, "id": "a968945e", "metadata": {}, "outputs": [], "source": [ "def evaluate_batch(model, dataloader, device, max_batches=None):\n", " \"\"\"Evaluate model on batches of data\"\"\"\n", " model.eval()\n", " \n", " all_targets = []\n", " all_predictions_most_certain = []\n", " all_losses = []\n", "\n", " with torch.no_grad():\n", " for batch_idx, (images, targets) in enumerate(tqdm(dataloader, desc=\"Evaluating\")):\n", " if max_batches and batch_idx >= max_batches:\n", " break\n", " \n", " images = images.to(device)\n", " targets = targets.to(device)\n", "\n", " # Run inference\n", " predictions_raw, certainties, _ = model(images)\n", " predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, 
{ "cell_type": "code", "execution_count": 57, "id": "a968945e", "metadata": {}, "outputs": [], "source": [ "def evaluate_batch(model, dataloader, device, max_batches=None):\n", "    \"\"\"Evaluate the model on batches of data.\"\"\"\n", "    model.eval()\n", "\n", "    all_targets = []\n", "    all_predictions_most_certain = []\n", "    all_losses = []\n", "\n", "    with torch.no_grad():\n", "        for batch_idx, (images, targets) in enumerate(tqdm(dataloader, desc=\"Evaluating\")):\n", "            if max_batches is not None and batch_idx >= max_batches:\n", "                break\n", "\n", "            images = images.to(device)\n", "            targets = targets.to(device)\n", "\n", "            # Run inference\n", "            predictions_raw, certainties, _ = model(images)\n", "            # Reshape to (batch, path_length, 5 actions, internal ticks)\n", "            predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))\n", "\n", "            # Calculate the loss and each example's most certain internal tick\n", "            loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)\n", "            all_losses.append(loss.item())\n", "\n", "            # Gather the predicted actions at each example's most certain tick\n", "            pred_at_certain = predictions.argmax(2)[\n", "                torch.arange(predictions.size(0), device=predictions.device),\n", "                :,\n", "                where_most_certain\n", "            ]\n", "\n", "            all_targets.append(targets.cpu().numpy())\n", "            all_predictions_most_certain.append(pred_at_certain.cpu().numpy())\n", "\n", "    # Concatenate all batches\n", "    all_targets = np.concatenate(all_targets)\n", "    all_predictions_most_certain = np.concatenate(all_predictions_most_certain)\n", "\n", "    # Calculate metrics\n", "    step_accuracy = (all_targets == all_predictions_most_certain).mean()\n", "    maze_solve_rate = (all_targets == all_predictions_most_certain).all(axis=1).mean()\n", "    avg_loss = np.mean(all_losses)\n", "\n", "    return step_accuracy, maze_solve_rate, avg_loss" ] }, { "cell_type": "code", "execution_count": 58, "id": "3e765556", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Evaluating:   1%|▏         | 1/79 [00:00<00:11,  6.94it/s]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Evaluating:  25%|██▌       | 20/79 [00:02<00:08,  6.90it/s]\n" ] } ], "source": [ "step_accuracy, maze_solve_rate, avg_loss = evaluate_batch(model, test_loader, device, max_batches=20)" ] }, { "cell_type": "code", "execution_count": 59, "id": "076927e5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step Accuracy: 99.45%\n", "Maze Solve Rate: 95.31%\n", "Average Loss: 0.0151\n" ] } ], "source": [ "print(f\"Step Accuracy: {step_accuracy*100:.2f}%\")\n", "print(f\"Maze Solve Rate: {maze_solve_rate*100:.2f}%\")\n", "print(f\"Average Loss: {avg_loss:.4f}\")" ] }, { "cell_type": "markdown", "id": "de596452", "metadata": {}, "source": [ "The model correctly predicted over 99% of the individual steps, and fully solved over 95% of the mazes!" ] },
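{ "cell_type": "markdown", "id": "f2c8d4a1", "metadata": {}, "source": [ "We can also decode an individual prediction back into arrows using `DIRECTION_MAP` from earlier and compare it against the ground truth. The sketch below simply reads the prediction at the final internal tick; the evaluation above instead selects each example's most certain tick." ] }, { "cell_type": "code", "execution_count": null, "id": "a3b5c7d9", "metadata": {}, "outputs": [], "source": [ "# Decode the predicted action sequence for the first maze in a test batch\n", "images, targets = next(iter(test_loader))\n", "with torch.no_grad():\n", "    predictions_raw, certainties, _ = model(images.to(device))\n", "\n", "# Reshape to (batch, path_length, 5 actions, internal ticks), as in evaluate_batch\n", "predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))\n", "\n", "# Argmax over actions at the final internal tick for the first maze in the batch\n", "pred_path = predictions[0, :, :, -1].argmax(-1).cpu().tolist()\n", "print(\"Predicted:\", [DIRECTION_MAP[a] for a in pred_path])\n", "print(\"Target:   \", [DIRECTION_MAP[a] for a in targets[0].cpu().tolist()])" ] },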
] }, { "cell_type": "code", "execution_count": 62, "id": "c3ebd8ba", "metadata": {}, "outputs": [], "source": [ "def create_maze_gif_visualization(model, testloader, device, log_dir):\n", " model.eval()\n", " with torch.no_grad():\n", " inputs_viz, targets_viz = next(iter(testloader))\n", " inputs_viz = inputs_viz.to(device)\n", " targets_viz = targets_viz.to(device)\n", " \n", " batch_index_to_viz = 1\n", " \n", " predictions_raw, certainties, _, pre_activations, post_activations, attention_tracking = model(inputs_viz, track=True)\n", " \n", " # Reshape predictions\n", " predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))\n", " \n", " # Reshape attention tracking for visualization\n", " att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])\n", " attention_tracking = attention_tracking.reshape(attention_tracking.shape[0], attention_tracking.shape[1], -1, att_shape[0], att_shape[1])\n", " \n", " maze_input = inputs_viz[batch_index_to_viz].detach().cpu().numpy()\n", " \n", " maze_predictions = predictions[batch_index_to_viz].detach().cpu().numpy()\n", " maze_targets = targets_viz[batch_index_to_viz].detach().cpu().numpy()\n", " maze_attention = attention_tracking[:, batch_index_to_viz] if attention_tracking.ndim > 2 else attention_tracking\n", "\n", " # Generate the maze GIF - saves gif to log_dir/prediction.gif\n", " make_maze_gif(\n", " maze_input,\n", " maze_predictions,\n", " maze_targets,\n", " maze_attention,\n", " log_dir\n", " )" ] }, { "cell_type": "code", "execution_count": 63, "id": "581f0007", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\n", "Processing frames for maze plotting: 0%| | 0/75 [00:00" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "create_maze_gif_visualization(model, test_loader, device, \"05_output\")\n", "\n", "reader = imageio.get_reader(\"05_output/prediction.gif\")\n", "frames = [reader.get_data(i) for i in range(min(len(reader), 100))]\n", "mediapy.show_video(frames, width=400, codec=\"gif\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.10" } }, "nbformat": 4, "nbformat_minor": 5 }