diff --git "a/examples/01_mnist.ipynb" "b/examples/01_mnist.ipynb" new file mode 100644--- /dev/null +++ "b/examples/01_mnist.ipynb" @@ -0,0 +1,834 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "04a72c0e", + "metadata": {}, + "source": [ + "# The Continuous Thought Machine – Tutorial 01: MNIST [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SakanaAI/continuous-thought-machines/blob/main/examples/01_mnist.ipynb) [![arXiv](https://img.shields.io/badge/arXiv-XXXX.XXXXX-b31b1b.svg)](https://arxiv.org/abs/)" + ] + }, + { + "cell_type": "markdown", + "id": "d88fc6d1", + "metadata": {}, + "source": [ + "Modern deep learning models ignore time as a core computational element. In contrast, the **Continuous Thought Machine (CTM)** introduces internal recurrence and neural synchronization to model *thinking as a temporal process*.\n", + "\n", + "### Key Ideas\n", + "\n", + "- **Internal Ticks**: The CTM runs over a self-generated temporal axis (independent of input), which we via as a dimension over which though can unfold.\n", + "- **Neuron-Level Models**: Each neuron has a private MLP that processes its own history of pre-activations over time.\n", + "- **Synchronization as Representation**: CTMs compute neuron-to-neuron synchronization over time and use these signals for attention and output.\n", + "\n", + "### Why It Matters\n", + "\n", + "- Enables **interpretable, dynamic reasoning**\n", + "- Supports **adaptive compute** (e.g. more ticks for harder tasks)\n", + "- Works across tasks: classification, reasoning, memory, RL—*without changing the core mechanisms*.\n", + "\n", + "----" + ] + }, + { + "cell_type": "markdown", + "id": "b05cf27b", + "metadata": {}, + "source": [ + "### MNIST Classification\n", + "\n", + "In this tutorial we walk through a simple example; training a CTM to classify MNIST digits. We cover:\n", + "- Defining the model\n", + "- Constructing the loss function\n", + "- Training\n", + "- Building vizualization" + ] + }, + { + "cell_type": "markdown", + "id": "c257dbd3", + "metadata": {}, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "a7bfbfe0", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "from scipy.special import softmax\n", + "import math\n", + "from torchvision import datasets, transforms\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "import imageio\n", + "import mediapy\n" + ] + }, + { + "cell_type": "markdown", + "id": "5ee6aa63", + "metadata": {}, + "source": [ + "We start by defining some helper classes, which we will use in the CTM.\n", + "\n", + "Of note is the SuperLinear class, which implements N unique linear transformations. This SuperLinear class will be used for the neuron-level models." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "50c9ac82",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Identity(nn.Module):\n",
+ "    \"\"\"Identity Module.\"\"\"\n",
+ "    def __init__(self):\n",
+ "        super().__init__()\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        return x\n",
+ "\n",
+ "class Squeeze(nn.Module):\n",
+ "    \"\"\"Squeeze Module.\"\"\"\n",
+ "    def __init__(self, dim):\n",
+ "        super().__init__()\n",
+ "        self.dim = dim\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        return x.squeeze(self.dim)\n",
+ "\n",
+ "class SuperLinear(nn.Module):\n",
+ "    \"\"\"SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.\"\"\"\n",
+ "    def __init__(self, in_dims, out_dims, N):\n",
+ "        super().__init__()\n",
+ "        self.in_dims = in_dims\n",
+ "        self.register_parameter('w1', nn.Parameter(\n",
+ "            torch.empty((in_dims, out_dims, N)).uniform_(\n",
+ "                -1/math.sqrt(in_dims + out_dims),\n",
+ "                1/math.sqrt(in_dims + out_dims)\n",
+ "            ), requires_grad=True)\n",
+ "        )\n",
+ "        self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        # x: (B, N, in_dims) -> one linear map per neuron -> (B, N, out_dims)\n",
+ "        out = torch.einsum('BDM,MHD->BDH', x, self.w1) + self.b1\n",
+ "        out = out.squeeze(-1)\n",
+ "        return out"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b0eb50ea",
+ "metadata": {},
+ "source": [
+ "Next, we define a helper function `compute_normalized_entropy`. We will use this function inside the CTM to compute the certainty of the model at each internal tick as `certainty = 1 - normalized entropy`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "4eedd9ee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def compute_normalized_entropy(logits, reduction='mean'):\n",
+ "    \"\"\"Computes the normalized entropy for the certainty-based loss.\"\"\"\n",
+ "    preds = F.softmax(logits, dim=-1)\n",
+ "    log_preds = torch.log_softmax(logits, dim=-1)\n",
+ "    entropy = -torch.sum(preds * log_preds, dim=-1)\n",
+ "    num_classes = preds.shape[-1]\n",
+ "    max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))\n",
+ "    normalized_entropy = entropy / max_entropy\n",
+ "    if len(logits.shape) > 2 and reduction == 'mean':\n",
+ "        normalized_entropy = normalized_entropy.flatten(1).mean(-1)\n",
+ "    return normalized_entropy"
+ ]
+ },
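+ {
+ "cell_type": "markdown",
+ "id": "7f3a1c9b",
+ "metadata": {},
+ "source": [
+ "As a quick, purely illustrative sanity check (not part of the training pipeline, and with arbitrary sizes), we can pass random data through a `SuperLinear` layer and compute certainties for hand-picked logits:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8e2d4b6a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 4 neurons (N), each with a length-5 history, batch of 2 -- arbitrary demo sizes.\n",
+ "sl = SuperLinear(in_dims=5, out_dims=3, N=4)\n",
+ "x = torch.randn(2, 4, 5)  # (batch, neurons, history)\n",
+ "print(sl(x).shape)  # torch.Size([2, 4, 3]): a separate 5 -> 3 map per neuron\n",
+ "\n",
+ "# Certainty = 1 - normalized entropy: peaked logits score near 1, uniform logits score 0.\n",
+ "print(1 - compute_normalized_entropy(torch.tensor([[10.0, 0.0, 0.0]])))  # close to 1\n",
+ "print(1 - compute_normalized_entropy(torch.zeros(1, 3)))  # exactly 0\n"
+ ]
+ },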
image).\n", + "- **Synapses**: A communication layer allowing neurons to interact.\n", + "- **Trace Processor**: A neuron-level model that operates on each neuron's temporal activation trace.\n", + "- **Synchronization Buffers**: For tracking decay.\n", + "- **Learned Initial States**: Starting activations and traces for the system.\n", + "\n", + "---\n", + "\n", + "## Forward Pass Mechanics\n", + "\n", + "At each internal tick `stepi`, the CTM executes the following procedure:\n", + "\n", + "1. **Initialize recurrent state**:\n", + " - `state_trace`: Memory trace per neuron.\n", + " - `activated_state`: Current post-activations.\n", + " - `decay_alpha_out`, `decay_beta_out`: Values for calculating synchronization.\n", + "\n", + "2. **Featurize input**:\n", + " - Use the CNN backbone to extract key-value attention pairs `kv`.\n", + "\n", + "3. **Internal Loop** (for each tick `stepi`):\n", + " 1. Compute `synchronisation_action` from `n_synch_action` neurons.\n", + " 2. Generate attention query `q` from this synchronization.\n", + " 3. Perform multi-head cross-attention over `kv`.\n", + " 4. Concatenate attention output with the current neuron activations.\n", + " 5. Update neurons via the synaptic module.\n", + " 6. Append new activation to the trace window.\n", + " 7. Update neuron states using the `trace_processor`.\n", + " 8. Compute `synchronisation_out` from `n_synch_out` neurons.\n", + " 9. Project to the output space via `output_projector`.\n", + " 10. Compute prediction certainty from normalized entropy.\n", + "\n", + "This inner loop is repeated for the configured number of internal ticks. The CTM emits **predictions and certainties at every internal tick**.\n", + "\n", + "> For detailed mathematical formulation of the synchronization mechanism, please refer to the technical report." 
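+ {
+ "cell_type": "markdown",
+ "id": "3c8f5d21",
+ "metadata": {},
+ "source": [
+ "Before reading the full implementation, it can help to see the synchronization recurrence in isolation. The standalone sketch below uses made-up sizes, and `torch.rand` stands in for the learned `decay_params`; it accumulates exponentially decayed outer products of activations exactly as `compute_synchronisation` does inside the model: `alpha_t = r * alpha_{t-1} + z_i * z_j`, `beta_t = r * beta_{t-1} + 1`, with representation `alpha_t / sqrt(beta_t)`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9b7e2f40",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Standalone sketch of the synchronization recurrence (illustrative sizes only).\n",
+ "n_synch, B = 4, 2\n",
+ "i, j = torch.triu_indices(n_synch, n_synch)  # n*(n+1)/2 = 10 unique neuron pairs\n",
+ "r = torch.exp(-torch.rand(i.numel()))        # stand-in for exp(-decay_params)\n",
+ "\n",
+ "alpha = beta = None\n",
+ "for t in range(5):                           # 5 pretend internal ticks\n",
+ "    z = torch.randn(B, n_synch)              # pretend post-activations\n",
+ "    pair = (z.unsqueeze(2) * z.unsqueeze(1))[:, i, j]\n",
+ "    alpha = pair if alpha is None else r * alpha + pair\n",
+ "    beta = torch.ones_like(pair) if beta is None else r * beta + 1\n",
+ "synch = alpha / beta.sqrt()\n",
+ "print(synch.shape)  # torch.Size([2, 10])\n"
+ ]
+ },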
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "f357853f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class ContinuousThoughtMachine(nn.Module):\n",
+ "    def __init__(self,\n",
+ "                 iterations,\n",
+ "                 d_model,\n",
+ "                 d_input,\n",
+ "                 memory_length,\n",
+ "                 heads,\n",
+ "                 n_synch_out,\n",
+ "                 n_synch_action,\n",
+ "                 out_dims,\n",
+ "                 memory_hidden_dims,\n",
+ "                 ):\n",
+ "        super().__init__()\n",
+ "\n",
+ "        # --- Core Parameters ---\n",
+ "        self.iterations = iterations\n",
+ "        self.d_model = d_model\n",
+ "        self.d_input = d_input\n",
+ "        self.memory_length = memory_length\n",
+ "        self.n_synch_out = n_synch_out\n",
+ "        self.n_synch_action = n_synch_action\n",
+ "        self.out_dims = out_dims\n",
+ "        self.memory_hidden_dims = memory_hidden_dims\n",
+ "\n",
+ "        # --- Input Processing ---\n",
+ "        self.backbone = nn.Sequential(\n",
+ "            nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),\n",
+ "            nn.BatchNorm2d(d_input),\n",
+ "            nn.ReLU(),\n",
+ "            nn.MaxPool2d(2, 2),\n",
+ "            nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),\n",
+ "            nn.BatchNorm2d(d_input),\n",
+ "            nn.ReLU(),\n",
+ "            nn.MaxPool2d(2, 2),\n",
+ "        )\n",
+ "        self.attention = nn.MultiheadAttention(self.d_input, heads, batch_first=True)\n",
+ "        self.kv_proj = nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input))\n",
+ "        self.q_proj = nn.LazyLinear(self.d_input)\n",
+ "\n",
+ "        # --- Core CTM Modules ---\n",
+ "        self.synapses = nn.Sequential(\n",
+ "            nn.LazyLinear(d_model * 2),\n",
+ "            nn.GLU(),\n",
+ "            nn.LayerNorm(d_model)\n",
+ "        )\n",
+ "        self.trace_processor = nn.Sequential(\n",
+ "            SuperLinear(in_dims=memory_length, out_dims=2 * memory_hidden_dims, N=d_model),\n",
+ "            nn.GLU(),\n",
+ "            SuperLinear(in_dims=memory_hidden_dims, out_dims=2, N=d_model),\n",
+ "            nn.GLU(),\n",
+ "            Squeeze(-1)\n",
+ "        )\n",
+ "\n",
+ "        # --- Start States ---\n",
+ "        self.register_parameter('start_activated_state', nn.Parameter(\n",
+ "            torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))),\n",
+ "            requires_grad=True\n",
+ "        ))\n",
+ "\n",
+ "        self.register_parameter('start_trace', nn.Parameter(\n",
+ "            torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length))),\n",
+ "            requires_grad=True\n",
+ "        ))\n",
+ "\n",
+ "        # --- Synchronisation ---\n",
+ "        self.synch_representation_size_action = (self.n_synch_action * (self.n_synch_action+1))//2\n",
+ "        self.synch_representation_size_out = (self.n_synch_out * (self.n_synch_out+1))//2\n",
+ "\n",
+ "        for synch_type, size in [('action', self.synch_representation_size_action), ('out', self.synch_representation_size_out)]:\n",
+ "            print(f\"Synch representation size {synch_type}: {size}\")\n",
+ "\n",
+ "        self.set_synchronisation_parameters('out', self.n_synch_out)\n",
+ "        self.set_synchronisation_parameters('action', self.n_synch_action)\n",
+ "\n",
+ "        # --- Output Processing ---\n",
+ "        self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))\n",
+ "\n",
+ "    def set_synchronisation_parameters(self, synch_type: str, n_synch: int):\n",
+ "        left, right = self.initialize_left_right_neurons(synch_type, self.d_model, n_synch)\n",
+ "        synch_representation_size = self.synch_representation_size_action if synch_type == 'action' else self.synch_representation_size_out\n",
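+ "        # Record which neurons participate in this synchronization pairing, and\n",
+ "        # give each neuron pair a learnable decay rate (used as r = exp(-decay)).\n",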
+ "        self.register_buffer(f'{synch_type}_neuron_indices_left', left)\n",
+ "        self.register_buffer(f'{synch_type}_neuron_indices_right', right)\n",
+ "        self.register_parameter(f'decay_params_{synch_type}', nn.Parameter(torch.zeros(synch_representation_size), requires_grad=True))\n",
+ "\n",
+ "    def initialize_left_right_neurons(self, synch_type, d_model, n_synch):\n",
+ "        # 'out' synchronization uses the first n_synch neurons; 'action' uses the last n_synch.\n",
+ "        if synch_type == 'out':\n",
+ "            neuron_indices_left = neuron_indices_right = torch.arange(0, n_synch)\n",
+ "        elif synch_type == 'action':\n",
+ "            neuron_indices_left = neuron_indices_right = torch.arange(d_model-n_synch, d_model)\n",
+ "        return neuron_indices_left, neuron_indices_right\n",
+ "\n",
+ "    def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, synch_type):\n",
+ "        B = activated_state.size(0)\n",
+ "        if synch_type == 'action':\n",
+ "            n_synch = self.n_synch_action\n",
+ "            decay_params = self.decay_params_action\n",
+ "            selected_left = selected_right = activated_state[:, -n_synch:]\n",
+ "        elif synch_type == 'out':\n",
+ "            n_synch = self.n_synch_out\n",
+ "            decay_params = self.decay_params_out\n",
+ "            selected_left = selected_right = activated_state[:, :n_synch]\n",
+ "\n",
+ "        # Pairwise products z_i * z_j over the upper triangle of the outer product.\n",
+ "        outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1)\n",
+ "        i, j = torch.triu_indices(n_synch, n_synch)\n",
+ "        pairwise_product = outer[:, i, j]\n",
+ "        # Per-pair decay rates; the clamp keeps r = exp(-decay) within (0, 1].\n",
+ "        r = torch.exp(-torch.clamp(decay_params, 0, 15)).unsqueeze(0).repeat(B, 1)\n",
+ "\n",
+ "        if decay_alpha is None or decay_beta is None:\n",
+ "            decay_alpha = pairwise_product\n",
+ "            decay_beta = torch.ones_like(pairwise_product)\n",
+ "        else:\n",
+ "            decay_alpha = r * decay_alpha + pairwise_product\n",
+ "            decay_beta = r * decay_beta + 1\n",
+ "\n",
+ "        synchronisation = decay_alpha / (torch.sqrt(decay_beta))\n",
+ "        return synchronisation, decay_alpha, decay_beta\n",
+ "\n",
+ "    def compute_features(self, x):\n",
+ "        input_features = self.backbone(x)\n",
+ "        kv = self.kv_proj(input_features.flatten(2).transpose(1, 2))\n",
+ "        return kv\n",
+ "\n",
+ "    def compute_certainty(self, current_prediction):\n",
+ "        ne = compute_normalized_entropy(current_prediction)\n",
+ "        current_certainty = torch.stack((ne, 1-ne), -1)\n",
+ "        return current_certainty\n",
+ "\n",
+ "    def forward(self, x, track=False):\n",
+ "        B = x.size(0)\n",
+ "        device = x.device\n",
+ "\n",
+ "        # --- Tracking Initialization ---\n",
+ "        pre_activations_tracking = []\n",
+ "        post_activations_tracking = []\n",
+ "        synch_out_tracking = []\n",
+ "        synch_action_tracking = []\n",
+ "        attention_tracking = []\n",
+ "\n",
+ "        # --- Featurise Input Data ---\n",
+ "        kv = self.compute_features(x)\n",
+ "\n",
+ "        # --- Initialise Recurrent State ---\n",
+ "        state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1)  # Shape: (B, d_model, memory_length)\n",
+ "        activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1)  # Shape: (B, d_model)\n",
+ "\n",
+ "        # --- Storage for outputs per iteration ---\n",
+ "        predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)\n",
+ "        certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)\n",
+ "\n",
+ "        decay_alpha_action, decay_beta_action = None, None\n",
+ "        _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, synch_type='out')\n",
+ "\n",
+ "        # --- Recurrent Loop ---\n",
+ "        for stepi in range(self.iterations):\n",
+ "\n",
+ "            # --- Calculate Synchronisation for Input Data Interaction ---\n",
+ "            synchronisation_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, synch_type='action')\n",
+ "\n",
+ "            # --- Interact with Data via Attention ---\n",
+ "            q = self.q_proj(synchronisation_action).unsqueeze(1)\n",
+ "            attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)\n",
+ "            attn_out = attn_out.squeeze(1)\n",
+ "            pre_synapse_input = torch.cat((attn_out, activated_state), dim=-1)\n",
+ "\n",
+ "            # --- Apply Synapses ---\n",
+ "            state = self.synapses(pre_synapse_input)\n",
+ "            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)\n",
+ "\n",
+ "            # --- Activate ---\n",
+ "            activated_state = self.trace_processor(state_trace)\n",
+ "\n",
+ "            # --- Calculate Synchronisation for Output Predictions ---\n",
+ "            synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, synch_type='out')\n",
+ "\n",
+ "            # --- Get Predictions and Certainties ---\n",
+ "            current_prediction = self.output_projector(synchronisation_out)\n",
+ "            current_certainty = self.compute_certainty(current_prediction)\n",
+ "\n",
+ "            predictions[..., stepi] = current_prediction\n",
+ "            certainties[..., stepi] = current_certainty\n",
+ "\n",
+ "            # --- Tracking ---\n",
+ "            if track:\n",
+ "                pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())\n",
+ "                post_activations_tracking.append(activated_state.detach().cpu().numpy())\n",
+ "                attention_tracking.append(attn_weights.detach().cpu().numpy())\n",
+ "                synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())\n",
+ "                synch_action_tracking.append(synchronisation_action.detach().cpu().numpy())\n",
+ "\n",
+ "        # --- Return Values ---\n",
+ "        if track:\n",
+ "            return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)\n",
+ "        return predictions, certainties, synchronisation_out"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5a049b6a",
+ "metadata": {},
+ "source": [
+ "## Certainty-Based Loss Function\n",
+ "\n",
+ "The CTM produces outputs at each internal tick, so the question arises: **how do we optimize the model across this internal temporal dimension?**\n",
+ "\n",
+ "Our answer is a simple but effective **certainty-based loss** that encourages the model to reason meaningfully across time. Instead of relying on the final tick alone, we aggregate loss from two key internal ticks:\n",
+ "\n",
+ "1. The tick where the **prediction loss** is lowest.\n",
+ "2. The tick where the **certainty** (1 - normalized entropy) is highest.\n",
+ "\n",
+ "We then take the **average of the losses** at these two points.\n",
+ "\n",
+ "This approach encourages the CTM to both make accurate predictions and express high confidence in them, while supporting adaptive, interpretable computation over time.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "0f463eb9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_loss(predictions, certainties, targets, use_most_certain=True):\n",
+ "    \"\"\"Average the cross-entropy at the minimum-loss tick and at the most certain tick.\n",
+ "\n",
+ "    If use_most_certain is False, the final tick is used in place of the most certain one.\n",
+ "    \"\"\"\n",
+ "    # losses has shape (B, iterations): one cross-entropy value per internal tick.\n",
+ "    losses = nn.CrossEntropyLoss(reduction='none')(predictions,\n",
+ "                                                   torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1))\n",
+ "\n",
+ "    loss_index_1 = losses.argmin(dim=1)\n",
+ "    loss_index_2 = certainties[:,1].argmax(-1)\n",
+ "    if not use_most_certain:\n",
+ "        loss_index_2[:] = -1\n",
+ "\n",
+ "    batch_indexer = torch.arange(predictions.size(0), device=predictions.device)\n",
+ "    loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()\n",
+ "    loss_selected = losses[batch_indexer, loss_index_2].mean()\n",
+ "\n",
+ "    loss = (loss_minimum_ce + loss_selected)/2\n",
+ "    return loss, loss_index_2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "e54afe0f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def calculate_accuracy(predictions, targets, where_most_certain):\n",
+ "    \"\"\"Calculate the accuracy based on the prediction at the most certain internal tick.\"\"\"\n",
+ "    B = predictions.size(0)\n",
+ "    device = predictions.device\n",
+ "\n",
+ "    predictions_at_most_certain_internal_tick = predictions.argmax(1)[torch.arange(B, device=device), where_most_certain].detach().cpu().numpy()\n",
+ "    accuracy = (targets.detach().cpu().numpy() == predictions_at_most_certain_internal_tick).mean()\n",
+ "\n",
+ "    return accuracy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "c1371279",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def prepare_data():\n",
+ "    transform = transforms.Compose([\n",
+ "        transforms.ToTensor(),\n",
+ "    ])\n",
+ "    train_data = datasets.MNIST(root=\"./data\", train=True, download=True, transform=transform)\n",
+ "    test_data = datasets.MNIST(root=\"./data\", train=False, download=True, transform=transform)\n",
+ "    trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=1)\n",
+ "    testloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True, num_workers=1, drop_last=False)\n",
+ "    return trainloader, testloader"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "a492b058",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train(model, trainloader, testloader, iterations, device):\n",
+ "\n",
+ "    test_every = 100\n",
+ "\n",
+ "    optimizer = torch.optim.AdamW(params=list(model.parameters()), lr=0.0001, eps=1e-8)\n",
+ "\n",
+ "    model.train()\n",
+ "\n",
+ "    with tqdm(total=iterations, initial=0, dynamic_ncols=True) as pbar:\n",
+ "        test_loss = None\n",
+ "        test_accuracy = None\n",
+ "        for stepi in range(iterations):\n",
+ "\n",
+ "            # Draw a fresh shuffled batch each optimization step.\n",
+ "            inputs, targets = next(iter(trainloader))\n",
+ "            inputs, targets = inputs.to(device), targets.to(device)\n",
+ "            predictions, certainties, _ = model(inputs, track=False)\n",
+ "            train_loss, where_most_certain = get_loss(predictions, certainties, targets)\n",
+ "            train_accuracy = calculate_accuracy(predictions, targets, where_most_certain)\n",
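+ "\n",
+ "            # Backpropagate: gradients flow through every internal tick via the\n",
+ "            # recurrence, even though the loss reads out only two ticks per example.\n",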
+ " optimizer.zero_grad()\n", + " train_loss.backward()\n", + " optimizer.step()\n", + "\n", + " if stepi % test_every == 0:\n", + " model.eval()\n", + " with torch.inference_mode():\n", + " all_test_predictions = []\n", + " all_test_targets = []\n", + " all_test_where_most_certain = []\n", + " all_test_losses = []\n", + "\n", + " for inputs, targets in testloader:\n", + " inputs, targets = inputs.to(device), targets.to(device)\n", + " predictions, certainties, _ = model(inputs, track=False)\n", + " test_loss, where_most_certain = get_loss(predictions, certainties, targets)\n", + " all_test_losses.append(test_loss.item())\n", + "\n", + " all_test_predictions.append(predictions)\n", + " all_test_targets.append(targets)\n", + " all_test_where_most_certain.append(where_most_certain)\n", + "\n", + " all_test_predictions = torch.cat(all_test_predictions, dim=0)\n", + " all_test_targets = torch.cat(all_test_targets, dim=0)\n", + " all_test_where_most_certain = torch.cat(all_test_where_most_certain, dim=0)\n", + "\n", + " test_accuracy = calculate_accuracy(all_test_predictions, all_test_targets, all_test_where_most_certain)\n", + " test_loss = sum(all_test_losses) / len(all_test_losses)\n", + " model.train()\n", + "\n", + " pbar.set_description(f'Train Loss: {train_loss:.3f}, Train Accuracy: {train_accuracy:.3f} Test Loss: {test_loss:.3f}, Test Accuracy: {test_accuracy:.3f}')\n", + " pbar.update(1)\n", + "\n", + " return model" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "2d1658a9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Synch representation size action: 136\n", + "Synch representation size out: 136\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/2000 [00:00" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "model.eval()\n", + "with torch.inference_mode():\n", + " inputs, targets = next(iter(testloader))\n", + " inputs = inputs.to(device)\n", + "\n", + " predictions, certainties, (synch_out_tracking, synch_action_tracking), \\\n", + " pre_activations_tracking, post_activations_tracking, attention = model(inputs, track=True)\n", + "\n", + " filename = \"mnist_output.gif\"\n", + "\n", + " make_gif(\n", + " predictions.detach().cpu().numpy(),\n", + " certainties.detach().cpu().numpy(),\n", + " targets.detach().cpu().numpy(),\n", + " pre_activations_tracking,\n", + " post_activations_tracking,\n", + " attention,\n", + " inputs.detach().cpu().numpy(),\n", + " filename\n", + " )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "atm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}