diff --git "a/examples/04_parity.ipynb" "b/examples/04_parity.ipynb" new file mode 100644 --- /dev/null +++ "b/examples/04_parity.ipynb" @@ -0,0 +1,681 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "04a72c0e", + "metadata": {}, + "source": [ + "# The Continuous Thought Machine – Tutorial 04: Parity [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SakanaAI/continuous-thought-machines/blob/main/examples/04_parity.ipynb) [![arXiv](https://img.shields.io/badge/arXiv-2505.05522-b31b1b.svg)](https://arxiv.org/abs/2505.05522)" + ] + }, + { + "cell_type": "markdown", + "id": "b05cf27b", + "metadata": {}, + "source": [ + "### Parity\n", + "\n", + "The parity of a binary sequence, given by the sign of the product of its elements, can reasonably be predicted by an RNN when the data is fed sequentially: the model need only maintain an internal state, flipping a ‘switch’ whenever a negative number is encountered. When the entire sequence is provided at once, however, the task is significantly more challenging.\n", + "\n", + "In Section 8 of the [technical report](https://arxiv.org/pdf/2505.05522), we showcase how a CTM can be trained to do exactly this. In particular, we feed the CTM a binary sequence, and train the model to predict the cumulative parity at each position along the sequence." + ] + }, + { + "cell_type": "markdown", + "id": "0dbffa93", + "metadata": {}, + "source": [ + "### Tutorial Overview" + ] + }, + { + "cell_type": "markdown", + "id": "07c2bbea", + "metadata": {}, + "source": [ + "In this tutorial, we walk through how we trained the CTM, using sequences of length 16." 
+ ] + }, + { + "cell_type": "markdown", + "id": "c2272f18", + "metadata": {}, + "source": [ + "### Setup" + ] + }, + { + "cell_type": "markdown", + "id": "c257dbd3", + "metadata": {}, + "source": [ + "In addition to installing some dependencies, we also clone the CTM repo (assuming this tutorial is being run in Colab), so that we can access the base CTM model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1ccfdcf", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install gdown\n", + "!pip install mediapy\n", + "!git clone https://github.com/SakanaAI/continuous-thought-machines.git\n" + ] + }, + { + "cell_type": "markdown", + "id": "1ab57a96", + "metadata": {}, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "24ffe416", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"./continuous-thought-machines\")\n", + "\n", + "import os\n", + "import torch\n", + "import torch.nn as nn\n", + "import numpy as np\n", + "import random\n", + "from torch.utils.data import Dataset\n", + "from tqdm.auto import tqdm\n", + "import matplotlib.pyplot as plt\n", + "from IPython.display import display, clear_output\n", + "import imageio\n", + "import mediapy\n", + "\n", + "# From CTM repo\n", + "from models.ctm import ContinuousThoughtMachine as CTM\n", + "from tasks.parity.plotting import make_parity_gif\n", + "from tasks.parity.utils import reshape_attention_weights, reshape_inputs" + ] + }, + { + "cell_type": "markdown", + "id": "82620e4e", + "metadata": {}, + "source": [ + "Set a seed for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "604415b7", + "metadata": {}, + "outputs": [], + "source": [ + "def set_seed(seed=42, deterministic=True):\n", + " random.seed(seed)\n", + " np.random.seed(seed)\n", + " torch.manual_seed(seed)\n", + " torch.cuda.manual_seed_all(seed)\n", + " torch.backends.cudnn.deterministic = deterministic\n", + 
" torch.backends.cudnn.benchmark = False" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "b2ba7f7b", + "metadata": {}, + "outputs": [], + "source": [ + "set_seed(42)" + ] + }, + { + "cell_type": "markdown", + "id": "4407a4a8", + "metadata": {}, + "source": [ + "### Data" + ] + }, + { + "cell_type": "markdown", + "id": "e271bc4c", + "metadata": {}, + "source": [ + "We define a dataset to create the parity sequences for training and testing. Each sample is a sequence of length `sequence_length`, where we randomly place -1s and 1s at each position. We calculate the target sequence (of the same length) as the parity up to and including that position, with 1s corresponding to negative parity (an odd number of -1s so far) and 0s corresponding to positive parity (an even number of -1s so far)." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "830313fb", + "metadata": {}, + "outputs": [], + "source": [ + "class ParityDataset(Dataset):\n", + " def __init__(self, sequence_length=64, length=100000):\n", + " self.sequence_length = sequence_length\n", + " self.length = length\n", + "\n", + " def __len__(self):\n", + " return self.length\n", + "\n", + " def __getitem__(self, idx):\n", + " vector = 2 * torch.randint(0, 2, (self.sequence_length,)) - 1\n", + " vector = vector.float()\n", + " negatives = (vector == -1).to(torch.long)\n", + " cumsum = torch.cumsum(negatives, dim=0)\n", + " target = (cumsum % 2 != 0).to(torch.long)\n", + " return vector, target" + ] + }, + { + "cell_type": "markdown", + "id": "39ec4d00", + "metadata": {}, + "source": [ + "We set the parity sequence length to `grid_size ** 2 = 16`, and prepare the train and test loaders. We use a `batch_size` of 64." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "cc28def0", + "metadata": {}, + "outputs": [], + "source": [ + "grid_size = 4\n", + "parity_sequence_length = grid_size ** 2\n", + "\n", + "train_data = ParityDataset(sequence_length=parity_sequence_length, length=100000)\n", + "test_data = ParityDataset(sequence_length=parity_sequence_length, length=10000)\n", + "\n", + "trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0)\n", + "testloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)" + ] + }, + { + "cell_type": "markdown", + "id": "e94a149b", + "metadata": {}, + "source": [ + "We can visualise what these inputs and targets look like. In the input, white squares are 1s and black squares are -1s. In the target, white squares (1s) mark negative cumulative parity and black squares (0s) mark positive cumulative parity." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "33503097", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input: [-1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0]\n", + "Target: [1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%matplotlib inline\n", + "sample_inputs, sample_targets = next(iter(trainloader))\n", + "sample_input = sample_inputs[0,:].reshape(grid_size, grid_size)\n", + "sample_target = sample_targets[0,:].reshape(grid_size, grid_size)\n", + "\n", + "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 3))\n", + "\n", + "# Plot the input\n", + "ax1.imshow(sample_input.flip(0), cmap='gray')\n", + "for i in range(grid_size+1):\n", + " ax1.axhline(i-0.5, color='black', linewidth=2,)\n", + " ax1.axvline(i-0.5, color='black', linewidth=2)\n", + "ax1.set_xlim(-0.5, grid_size-0.5)\n", + "ax1.set_ylim(-0.5, grid_size-0.5)\n", + "ax1.set_xticks([])\n", + "ax1.set_yticks([])\n", + "ax1.set_title('Input')\n", + "\n", + "# Plot the target\n", + "ax2.imshow(sample_target.flip(0), cmap='gray')\n", + "for i in range(grid_size+1):\n", + " ax2.axhline(i-0.5, color='black', linewidth=2)\n", + " ax2.axvline(i-0.5, color='black', linewidth=2)\n", + "ax2.set_xlim(-0.5, grid_size-0.5)\n", + "ax2.set_ylim(-0.5, grid_size-0.5)\n", + "ax2.set_xticks([])\n", + "ax2.set_yticks([])\n", + "ax2.set_title('Target')\n", + "\n", + "plt.tight_layout()\n", + "print(f\"Input: {sample_inputs[0].tolist()}\")\n", + "print(f\"Target: {sample_targets[0].tolist()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "069121b9", + "metadata": {}, + "source": [ + "### Loss Function" + ] + }, + { + "cell_type": "markdown", + "id": "88d91d78", + "metadata": {}, + "source": [ + "Next we define the loss function. First, for all internal ticks of the CTM, we calculate the cross-entropy loss for all positions along the output sequence. Then, as with the other experiments, we only use the loss at two specific internal ticks: where the loss is the lowest and where the model is most certain. We use advanced indexing into the losses tensor to extract these losses, and then average them." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63e75f71", + "metadata": {}, + "outputs": [], + "source": [ + "def parity_loss(predictions, certainties, targets, use_most_certain=True):\n", + " \"\"\"\n", + " Computes the parity loss.\n", + "\n", + " Predictions are of shape: (B, parity_sequence_length, class, internal_ticks),\n", + " where the 2 classes are [0, 1] for [positive parity, negative parity]\n", + " Certainties are of shape: (B, 2, internal_ticks), \n", + " where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy]\n", + " Targets are of shape: [B, parity_sequence_length]\n", + "\n", + " use_most_certain will select either the most certain point or the final point. For baselines,\n", + " the final point proved the only usable option. \n", + " \"\"\"\n", + "\n", + " # Losses are of shape [B, parity_sequence_length, internal_ticks]\n", + " losses = nn.CrossEntropyLoss(reduction='none')(predictions.flatten(0,1), torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1).flatten(0,1).long()).reshape(predictions[:,:,0].shape)\n", + "\n", + " # Average the loss over the parity sequence dimension\n", + " losses = losses.mean(1)\n", + "\n", + " loss_index_1 = losses.argmin(dim=1)\n", + " loss_index_2 = certainties[:,1].argmax(-1)\n", + " if not use_most_certain:\n", + " loss_index_2[:] = -1\n", + " \n", + " batch_indexer = torch.arange(predictions.size(0), device=predictions.device)\n", + " loss_minimum_ce = losses[batch_indexer, loss_index_1].mean()\n", + " loss_selected = losses[batch_indexer, loss_index_2].mean()\n", + "\n", + " loss = (loss_minimum_ce + loss_selected)/2\n", + " return loss, loss_index_2" + ] + }, + { + "cell_type": "markdown", + "id": "a712a9a9", + "metadata": {}, + "source": [ + "### Training" + ] + }, + { + "cell_type": "markdown", + "id": "89cb8dd7", + "metadata": {}, + "source": [ + "We define some helper functions for making the progress bar look pretty, and to display the training 
curves." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fb57caee", + "metadata": {}, + "outputs": [], + "source": [ + "def make_pbar_desc(train_loss, train_accuracy, test_loss, test_accuracy, lr, where_most_certain):\n", + " \"\"\"A helper function to create a description for the tqdm progress bar\"\"\"\n", + " pbar_desc = f'Train Loss={train_loss:0.3f}. Train Acc={train_accuracy:0.3f}. Test Loss={test_loss:0.3f}. Test Acc={test_accuracy:0.3f}. LR={lr:0.6f}.'\n", + " pbar_desc += f' Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d}).'\n", + " return pbar_desc" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "19208f41", + "metadata": {}, + "outputs": [], + "source": [ + "def update_training_curve_plot(fig, ax1, ax2, train_losses, test_losses, train_accuracies, test_accuracies, steps):\n", + " clear_output(wait=True)\n", + " \n", + " # Plot loss\n", + " ax1.clear()\n", + " ax1.plot(range(len(train_losses)), train_losses, 'b-', alpha=0.7, label=f'Train Loss: {train_losses[-1]:.3f}')\n", + " ax1.plot(steps, test_losses, 'r-', marker='o', label=f'Test Loss: {test_losses[-1]:.3f}')\n", + " ax1.set_title('Loss')\n", + " ax1.set_xlabel('Step')\n", + " ax1.set_ylabel('Loss')\n", + " ax1.legend()\n", + " ax1.grid(True, alpha=0.3)\n", + "\n", + " # Plot accuracy\n", + " ax2.clear()\n", + " ax2.plot(range(len(train_accuracies)), train_accuracies, 'b-', alpha=0.7, label=f'Train Accuracy: {train_accuracies[-1]:.3f}')\n", + " ax2.plot(steps, test_accuracies, 'r-', marker='o', label=f'Test Accuracy: {test_accuracies[-1]:.3f}')\n", + " ax2.set_title('Accuracy')\n", + " ax2.set_xlabel('Step')\n", + " ax2.set_ylabel('Accuracy')\n", + " ax2.legend()\n", + " ax2.grid(True, alpha=0.3)\n", + "\n", + " plt.tight_layout()\n", + " display(fig)" + ] + }, + { + "cell_type": "markdown", + "id": "71541842", + "metadata": 
{}, + "source": [ + "We then write the function to train the CTM." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "02de7c62", + "metadata": {}, + "outputs": [], + "source": [ + "def train(model, trainloader, testloader, device='cpu', training_iterations=10000, test_every=1000, lr=1e-4, log_dir='./logs'):\n", + "\n", + " os.makedirs(log_dir, exist_ok=True)\n", + " \n", + " model.train()\n", + " optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n", + " iterator = iter(trainloader)\n", + " \n", + " train_losses = []\n", + " test_losses = []\n", + " train_accuracies = []\n", + " test_accuracies = []\n", + " steps = []\n", + " \n", + " plt.ion()\n", + " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n", + "\n", + " with tqdm(total=training_iterations) as pbar:\n", + " for stepi in range(training_iterations):\n", + "\n", + " try:\n", + " inputs, targets = next(iterator)\n", + " except StopIteration:\n", + " iterator = iter(trainloader)\n", + " inputs, targets = next(iterator)\n", + " \n", + " inputs, targets = inputs.to(device), targets.to(device)\n", + " \n", + " optimizer.zero_grad()\n", + " \n", + " predictions_raw, certainties, _ = model(inputs)\n", + "\n", + " # Reshape: (B, SeqLength, C, T)\n", + " predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 2, predictions_raw.size(-1))\n", + " \n", + " # Compute loss\n", + " train_loss, where_most_certain = parity_loss(predictions, certainties, targets, use_most_certain=True)\n", + " train_accuracy = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()\n", + "\n", + " train_losses.append(train_loss.item())\n", + " train_accuracies.append(train_accuracy)\n", + "\n", + " train_loss.backward()\n", + " optimizer.step()\n", + " optimizer.zero_grad()\n", + "\n", + " if stepi % test_every == 0 or stepi == 0:\n", + " model.eval()\n", + " with torch.no_grad():\n", + " all_test_predictions = 
[]\n", + " all_test_targets = []\n", + " all_test_where_most_certain = []\n", + " all_test_losses = []\n", + "\n", + " for inputs, targets in testloader:\n", + " inputs = inputs.to(device)\n", + " targets = targets.to(device)\n", + " \n", + " predictions_raw, certainties, where_most_certain = model(inputs)\n", + " predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 2, predictions_raw.size(-1))\n", + " \n", + " test_loss, where_most_certain = parity_loss(predictions, certainties, targets, use_most_certain=True)\n", + " all_test_losses.append(test_loss.item())\n", + " all_test_predictions.append(predictions)\n", + " all_test_targets.append(targets)\n", + " all_test_where_most_certain.append(where_most_certain)\n", + "\n", + " all_test_predictions = torch.cat(all_test_predictions, dim=0)\n", + " all_test_targets = torch.cat(all_test_targets, dim=0)\n", + " all_test_where_most_certain = torch.cat(all_test_where_most_certain, dim=0)\n", + "\n", + " test_accuracy = (all_test_predictions.argmax(2)[torch.arange(all_test_predictions.size(0), device=predictions.device), :, all_test_where_most_certain] == all_test_targets).float().mean().item()\n", + " test_loss = sum(all_test_losses) / len(all_test_losses)\n", + "\n", + " test_losses.append(test_loss)\n", + " test_accuracies.append(test_accuracy)\n", + " steps.append(stepi)\n", + "\n", + " create_maze_gif_visualization(model, testloader, device, log_dir)\n", + "\n", + " model.train()\n", + "\n", + " update_training_curve_plot(fig, ax1, ax2, train_losses, test_losses, train_accuracies, test_accuracies, steps)\n", + "\n", + " pbar_desc = make_pbar_desc(train_loss=train_loss.item(), train_accuracy=train_accuracy, test_loss=test_loss, test_accuracy=test_accuracy, lr=optimizer.param_groups[-1][\"lr\"], where_most_certain=where_most_certain)\n", + " pbar.set_description(pbar_desc)\n", + " pbar.update(1)\n", + "\n", + " plt.ioff()\n", + " plt.close(fig)\n", + " return model\n", + "\n", + "def 
create_maze_gif_visualization(model, testloader, device, log_dir):\n", + " model.eval()\n", + " with torch.no_grad():\n", + " inputs_viz, targets_viz = next(iter(testloader))\n", + " inputs_viz = inputs_viz.to(device)\n", + " targets_viz = targets_viz.to(device)\n", + "\n", + " predictions_raw, certainties, _, pre_activations, post_activations, attention = model(inputs_viz, track=True)\n", + " \n", + " # Reshape predictions\n", + " predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 2, predictions_raw.size(-1))\n", + " \n", + " attention = reshape_attention_weights(attention)\n", + " inputs = reshape_inputs(inputs_viz, 50, grid_size=grid_size)\n", + "\n", + " # Generate the parity GIF\n", + " make_parity_gif(\n", + " predictions.detach().cpu().numpy(),\n", + " certainties.detach().cpu().numpy(),\n", + " targets_viz.detach().cpu().numpy(),\n", + " pre_activations,\n", + " post_activations,\n", + " attention,\n", + " inputs,\n", + " f'{log_dir}/prediction.gif',\n", + " )\n", + " \n", + " predictions_raw, certainties, _ = model(inputs_viz)\n", + " predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 2, predictions_raw.size(-1))\n" + ] + }, + { + "cell_type": "markdown", + "id": "67e6c8fc", + "metadata": {}, + "source": [ + "### Initializing the CTM" + ] + }, + { + "cell_type": "markdown", + "id": "7a9cac52", + "metadata": {}, + "source": [ + "Next we initialize the CTM. There are three important arguments to highlight for this task, which differ from, for example, the image classification task.\n", + "\n", + "- `backbone_type = 'parity_backbone'`: the backbone type `'parity_backbone'`, which is defined in the CTM repo, is a learned embedding layer which embeds the binary values in the input sequence.\n", + "- `positional_embedding_type = 'custom-rotational-1d'`: a positional embedding for each position in the parity sequence. 
These positional embeddings are added to the embedding vectors (produced by the backbone) during the forward pass.\n", + "- `prediction_reshaper = [parity_sequence_length, 2]`: the CTM has an optional argument `prediction_reshaper`. This is required when the output of the model is a sequence. For instance, it is required here where the output is a sequence of parities, or in the maze task where the output is a sequence of actions. This prediction reshaper is used in each internal tick of the CTM when the certainty of the model's output is computed. Generally, the prediction reshaper should have the form `[SEQUENCE_LENGTH, NUM_CLASS]`." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "2c180995", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using neuron select type: random-pairing\n", + "Synch representation size action: 256\n", + "Synch representation size out: 256\n", + "Model parameters: 652,106\n" + ] + } + ], + "source": [ + "# Set device\n", + "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "\n", + "# Define the model\n", + "model = CTM(\n", + " iterations = 50,\n", + " d_model = 256,\n", + " d_input = 32,\n", + " heads = 8,\n", + " n_synch_out = 256,\n", + " n_synch_action = 256,\n", + " synapse_depth = 8,\n", + " memory_length = 25,\n", + " deep_nlms = True,\n", + " memory_hidden_dims = 16,\n", + " backbone_type = 'parity_backbone',\n", + " out_dims = parity_sequence_length * 2,\n", + " prediction_reshaper = [parity_sequence_length, 2],\n", + " dropout = 0.0,\n", + " do_layernorm_nlm = False,\n", + " positional_embedding_type = 'custom-rotational-1d'\n", + ").to(device)\n", + "\n", + "# Initialize model parameters with dummy forward pass\n", + "sample_batch = next(iter(trainloader))\n", + "dummy_input = sample_batch[0][:1].to(device)\n", + "with torch.no_grad():\n", + " _ = model(dummy_input)\n", + "\n", + "print(f'Model parameters: {sum(p.numel() for p in 
model.parameters()):,}')" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a89daf13", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Train Loss=0.007. Train Acc=0.997. Test Loss=0.015. Test Acc=0.994. LR=0.000100. Where_certain=43.95+-3.85 (33<->49).: 100%|██████████| 20000/20000 [1:55:52<00:00, 2.88it/s] \n" + ] + } + ], + "source": [ + "model = train(model=model, trainloader=trainloader, testloader=testloader, device=device, training_iterations=20000, lr=1e-4, log_dir='./parity_logs')" + ] + }, + { + "cell_type": "markdown", + "id": "ca9cbd9b", + "metadata": {}, + "source": [ + "Visualise a gif of a solution" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "6ba0b9e4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "reader = imageio.get_reader(\"parity_logs/prediction.gif\")\n", + "frames = [reader.get_data(i) for i in range(min(len(reader), 100))]\n", + "mediapy.show_video(frames, width=400, codec=\"gif\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "atm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}