ciaran-regan-ie committed on
Commit a9685d2 · 1 Parent(s): 69b35a9

quick maze task

Files changed (3)
  1. .gitignore +1 -0
  2. examples/03_mazes.ipynb +515 -0
  3. tasks/mazes/plotting.py +2 -2
.gitignore CHANGED
@@ -17,4 +17,5 @@ data/*
 examples/*
 !examples/01_mnist.ipynb
 !examples/02_inference.ipynb
+!examples/03_mazes.ipynb
 checkpoints
examples/03_mazes.ipynb ADDED
@@ -0,0 +1,515 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "04a72c0e",
+ "metadata": {},
+ "source": [
+ "# The Continuous Thought Machine – Tutorial 03: Mazes [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SakanaAI/continuous-thought-machines/blob/main/examples/03_mazes.ipynb) [![arXiv](https://img.shields.io/badge/arXiv-2505.05522-b31b1b.svg)](https://arxiv.org/abs/2505.05522)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b05cf27b",
+ "metadata": {},
+ "source": [
+ "### Maze Solving\n",
+ "\n",
+ "In Section 4 of the [technical report](https://arxiv.org/pdf/2505.05522), we showcase how a CTM can be used to solve 2D mazes. Typically in the literature, maze solving is framed as a form of binary classification: by ensuring the output space matches the dimensions of the input space, a model can classify, for each pixel in the input image, whether that pixel belongs to the path through the maze. While this approach has seen success, it sidesteps the need to think in a more natural fashion. We seek to design a more challenging task, where a more human-like solution is required.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bc2f769e",
+ "metadata": {},
+ "source": [
+ "### Task Details\n",
+ "\n",
+ "Instead of classifying whether each pixel in the maze is or is not along the solution path, we constrain the output space such that the model must output a plan: an entire trajectory of actions corresponding to the steps an agent must take to go from the start position to the goal position. Specifically, at each internal tick of the CTM, the model produces a sequence of actions. We can then use a cross-entropy loss which compares these actions to the ground-truth trajectory required to solve the maze. A small shape sketch follows below."
+ ]
+ },
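+ {
+ "cell_type": "markdown",
+ "id": "f3a9c1d0",
+ "metadata": {},
+ "source": [
+ "To make the tensor shapes concrete, here is a minimal sketch (added for illustration; the sizes below are assumptions, not the tutorial's settings). For a batch of `B` mazes, a route of length `L`, 5 possible actions, and `T` internal ticks, predictions have shape `(B, L, 5, T)` and targets have shape `(B, L)`; cross-entropy is applied per step and per tick."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e7b2d4f5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "# Illustrative sizes (assumed): batch, route length, actions, internal ticks\n",
+ "B, L, A, T = 2, 50, 5, 50\n",
+ "logits = torch.randn(B, L, A, T)        # per-tick action logits\n",
+ "actions = torch.randint(0, A, (B, L))   # ground-truth action indices\n",
+ "\n",
+ "# Flatten batch/route and repeat targets across ticks, as the real loss below will do\n",
+ "per_tick_losses = nn.CrossEntropyLoss(reduction='none')(\n",
+ "    logits.flatten(0, 1),                               # (B*L, A, T)\n",
+ "    actions.flatten(0, 1).unsqueeze(-1).expand(-1, T)   # (B*L, T)\n",
+ ")\n",
+ "print(per_tick_losses.shape)  # torch.Size([100, 50]): one loss per step per tick\n"
+ ]
+ },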
+ {
+ "cell_type": "markdown",
+ "id": "ec34151e",
+ "metadata": {},
+ "source": [
+ "### Setup\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c257dbd3",
+ "metadata": {},
+ "source": [
+ "In addition to installing some dependencies, we also clone the CTM repo (assuming this tutorial is being run in Colab) so that we can access the base CTM model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "c1ccfdcf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install gdown\n",
+ "!pip install mediapy\n",
+ "!git clone https://github.com/SakanaAI/continuous-thought-machines.git\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6397bf56",
+ "metadata": {},
+ "source": [
+ "Download the training and test data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "634a61d8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!gdown \"https://drive.google.com/uc?id=1Z8FFnZ7pZcu7DfoSyfy-ghWa08lgYl1V\"\n",
+ "!unzip \"small-mazes.zip\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1ab57a96",
+ "metadata": {},
+ "source": [
+ "Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "24ffe416",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append(\"./continuous-thought-machines\")\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from tqdm.auto import tqdm\n",
+ "import os\n",
+ "\n",
+ "# From CTM repo\n",
+ "from models.ctm import ContinuousThoughtMachine as CTM\n",
+ "from data.custom_datasets import MazeImageFolder\n",
+ "from tasks.mazes.plotting import make_maze_gif\n",
+ "from tasks.image_classification.plotting import plot_neural_dynamics"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4407a4a8",
+ "metadata": {},
+ "source": [
+ "Prepare the data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "cc28def0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading mazes: 100%|██████████| 9000/9000 [00:15<00:00, 572.59it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Solving all mazes...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading mazes: 100%|██████████| 1000/1000 [00:01<00:00, 643.73it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Solving all mazes...\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_root = './small-mazes'\n",
+ "train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=50)\n",
+ "test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=50)\n",
+ "\n",
+ "trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0, drop_last=True)\n",
+ "testloader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=True, num_workers=1, drop_last=False)"
+ ]
+ },
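+ {
+ "cell_type": "markdown",
+ "id": "b8c7d6e5",
+ "metadata": {},
+ "source": [
+ "Before moving on, it can help to sanity-check one batch (this cell is an added illustration): the loader yields an image batch and, per maze, a length-50 sequence of target action indices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d4e5f6a7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Peek at a single batch to confirm shapes before training\n",
+ "inputs, targets = next(iter(trainloader))\n",
+ "print(inputs.shape)   # maze images, batch-first\n",
+ "print(targets.shape)  # (batch, route_length) action targets\n"
+ ]
+ },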
+ {
+ "cell_type": "markdown",
+ "id": "0e164024",
+ "metadata": {},
+ "source": [
+ "Next, let's define the loss for the maze task.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "63e75f71",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def maze_loss(predictions, certainties, targets, curriculum_lookahead=5, use_most_certain=True):\n",
+ "    \"\"\"\n",
+ "    Computes the maze loss with an auto-extending curriculum.\n",
+ "\n",
+ "    Predictions are of shape: (B, route_length, class, internal_ticks),\n",
+ "        where classes are in [0,1,2,3,4] for [Up, Down, Left, Right, Wait].\n",
+ "    Certainties are of shape: (B, 2, internal_ticks),\n",
+ "        where the inner dimension (2) is [normalised_entropy, 1-normalised_entropy].\n",
+ "    Targets are of shape: (B, route_length).\n",
+ "\n",
+ "    curriculum_lookahead: how far to look ahead in the auto-curriculum.\n",
+ "\n",
+ "    use_most_certain selects either the most certain internal tick or the final one. For baselines,\n",
+ "    the final tick proved the only usable option.\n",
+ "    \"\"\"\n",
+ "    # Predictions reshaped to: [B*route_length, 5, internal_ticks]\n",
+ "    predictions_reshaped = predictions.flatten(0,1)\n",
+ "    # Targets reshaped to: [B*route_length, internal_ticks]\n",
+ "    targets_reshaped = torch.repeat_interleave(targets.unsqueeze(-1),\n",
+ "                                               predictions.size(-1), -1).flatten(0,1).long()\n",
+ "\n",
+ "    # Losses are of shape [B, route_length, internal_ticks]\n",
+ "    losses = nn.CrossEntropyLoss(reduction='none')(predictions_reshaped, targets_reshaped)\n",
+ "    losses = losses.reshape(predictions[:,:,0].shape)\n",
+ "\n",
+ "    # Below is the code for the auto-curriculum:\n",
+ "    # find how far along the route the model is already correct, then always\n",
+ "    # include curriculum_lookahead further steps in the loss\n",
+ "    iscorrects = (predictions.argmax(2) == targets.unsqueeze(-1)).cumsum(1)\n",
+ "    correct_mask = (iscorrects == torch.arange(1, iscorrects.size(1)+1, device=iscorrects.device).reshape(1, -1, 1))\n",
+ "    correct_mask[:,0,:] = 1\n",
+ "    upto_where = correct_mask.cumsum(1).argmax(1).max(-1)[0]+curriculum_lookahead\n",
+ "    loss_mask = torch.zeros_like(losses)\n",
+ "    for bi in range(predictions.size(0)):\n",
+ "        loss_mask[bi, :upto_where[bi]] = 1\n",
+ "\n",
+ "    # Reduce losses along the route dimension\n",
+ "    # Will now be of shape [B, internal_ticks]\n",
+ "    losses = (losses * loss_mask).sum(1)/(loss_mask.sum(1))\n",
+ "\n",
+ "    loss_index_1 = losses.argmin(dim=1)\n",
+ "    loss_index_2 = certainties[:,1].argmax(-1)\n",
+ "    if not use_most_certain:\n",
+ "        loss_index_2[:] = -1\n",
+ "\n",
+ "    batch_indexer = torch.arange(predictions.size(0), device=predictions.device)\n",
+ "    loss_minimum_ce = losses[batch_indexer, loss_index_1]\n",
+ "    loss_selected = losses[batch_indexer, loss_index_2]\n",
+ "\n",
+ "    loss = ((loss_minimum_ce + loss_selected)/2).mean()\n",
+ "    return loss, loss_index_2, upto_where.detach().cpu().numpy()"
+ ]
+ },
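+ {
+ "cell_type": "markdown",
+ "id": "c9d8e7f6",
+ "metadata": {},
+ "source": [
+ "As a quick sanity check (an added sketch, not part of the training loop), we can call `maze_loss` on random tensors with the shapes documented above; the sizes are assumptions for illustration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f1e2d3c4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Dummy inputs with the shapes documented in maze_loss (sizes are illustrative)\n",
+ "B, L, A, T = 4, 50, 5, 50\n",
+ "dummy_predictions = torch.randn(B, L, A, T)\n",
+ "dummy_certainties = torch.rand(B, 2, T)\n",
+ "dummy_targets = torch.randint(0, A, (B, L))\n",
+ "\n",
+ "loss, where_most_certain, upto_where = maze_loss(dummy_predictions, dummy_certainties, dummy_targets)\n",
+ "print(loss.item(), where_most_certain.shape, upto_where)\n"
+ ]
+ },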
+ {
+ "cell_type": "markdown",
+ "id": "89cb8dd7",
+ "metadata": {},
+ "source": [
+ "We define a helper function to update the training progress bar with key metrics:\n",
+ "\n",
+ "- Train Loss & Test Loss: standard loss values during training\n",
+ "- Train Acc (Step): the average accuracy of individual directional predictions across all steps in the maze trajectories\n",
+ "- Statistics on the internal ticks:\n",
+ "    - Average internal tick at which the model is most certain\n",
+ "    - Standard deviation of these certainty peaks\n",
+ "    - Range (min ↔ max) of internal ticks where peak certainty occurs\n",
+ "    - The average route position up to which the loss is taken (the auto-curriculum horizon)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "fb57caee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def make_pbar_desc(train_loss, train_accuracy_finegrained, test_loss, optimizer, where_most_certain, upto_where):\n",
+ "    \"\"\"A helper function to create a description for the tqdm progress bar\"\"\"\n",
+ "    pbar_desc = f'Train Loss={train_loss if isinstance(train_loss, float) else train_loss.item():0.3f}. Train Acc(step)={train_accuracy_finegrained:0.3f}. Test Loss={test_loss if isinstance(test_loss, float) else test_loss.item():0.3f}. LR={optimizer.param_groups[-1][\"lr\"]:0.6f}.'\n",
+ "    pbar_desc += f' Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d}). Upto: {sum(upto_where) / len(upto_where):0.2f}.'\n",
+ "    return pbar_desc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "02de7c62",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train(model, trainloader, testloader, device='cpu', training_iterations=10000, test_every=1000, lr=1e-4, log_dir='./logs'):\n",
+ "\n",
+ "    os.makedirs(log_dir, exist_ok=True)\n",
+ "\n",
+ "    model.train()\n",
+ "    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
+ "    iterator = iter(trainloader)\n",
+ "\n",
+ "    train_losses = []\n",
+ "    test_losses = []\n",
+ "\n",
+ "    with tqdm(total=training_iterations) as pbar:\n",
+ "        for stepi in range(training_iterations):\n",
+ "\n",
+ "            # Cycle through the training data indefinitely\n",
+ "            try:\n",
+ "                inputs, targets = next(iterator)\n",
+ "            except StopIteration:\n",
+ "                iterator = iter(trainloader)\n",
+ "                inputs, targets = next(iterator)\n",
+ "\n",
+ "            inputs, targets = inputs.to(device), targets.to(device)\n",
+ "\n",
+ "            optimizer.zero_grad()\n",
+ "\n",
+ "            predictions_raw, certainties, _ = model(inputs)\n",
+ "\n",
+ "            # Reshape: (B, SeqLength, 5, Ticks)\n",
+ "            predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))\n",
+ "\n",
+ "            # Compute loss\n",
+ "            train_loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, use_most_certain=True)\n",
+ "\n",
+ "            train_accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()\n",
+ "\n",
+ "            train_loss.backward()\n",
+ "            optimizer.step()\n",
+ "\n",
+ "            # Store the scalar, not the tensor, to avoid holding onto the graph\n",
+ "            train_losses.append(train_loss.item())\n",
+ "\n",
+ "            if stepi % test_every == 0 or stepi == 0:\n",
+ "                model.eval()\n",
+ "                with torch.no_grad():\n",
+ "                    test_loss_per_batch = []\n",
+ "                    for inputs, targets in testloader:\n",
+ "                        inputs = inputs.to(device)\n",
+ "                        targets = targets.to(device)\n",
+ "\n",
+ "                        predictions_raw, certainties, _ = model(inputs)\n",
+ "                        predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))\n",
+ "\n",
+ "                        test_loss, _, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)\n",
+ "                        test_loss_per_batch.append(test_loss.item())\n",
+ "\n",
+ "                    test_loss = sum(test_loss_per_batch) / len(test_loss_per_batch)\n",
+ "                    test_losses.append(test_loss)\n",
+ "\n",
+ "                    create_maze_gif_visualization(model, testloader, device, log_dir)\n",
+ "                model.train()\n",
+ "\n",
+ "            pbar_desc = make_pbar_desc(train_loss, train_accuracy_finegrained, test_loss, optimizer, where_most_certain, upto_where)\n",
+ "            pbar.set_description(pbar_desc)\n",
+ "            pbar.update(1)\n",
+ "\n",
+ "    return train_losses, test_losses\n",
+ "\n",
+ "def create_maze_gif_visualization(model, testloader, device, log_dir):\n",
+ "    \"\"\"\n",
+ "    Create a GIF visualization of maze solving with attention tracking\n",
+ "    \"\"\"\n",
+ "\n",
+ "    model.eval()\n",
+ "    with torch.no_grad():\n",
+ "        # Get a test batch\n",
+ "        inputs_viz, targets_viz = next(iter(testloader))\n",
+ "        inputs_viz = inputs_viz.to(device)\n",
+ "        targets_viz = targets_viz.to(device)\n",
+ "\n",
+ "        batch_index_to_viz = 0\n",
+ "\n",
+ "        predictions_raw, certainties, _, pre_activations, post_activations, attention_tracking = model(inputs_viz, track=True)\n",
+ "\n",
+ "        # Reshape predictions\n",
+ "        predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))\n",
+ "\n",
+ "        # Reshape attention tracking for visualization\n",
+ "        att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])\n",
+ "        attention_tracking = attention_tracking.reshape(attention_tracking.shape[0], attention_tracking.shape[1], -1, att_shape[0], att_shape[1])\n",
+ "\n",
+ "        plot_neural_dynamics(post_activations, 100, log_dir, axis_snap=True)\n",
+ "\n",
+ "        # Create maze GIF with attention visualization\n",
+ "        maze_input = (inputs_viz[batch_index_to_viz].detach().cpu().numpy() + 1) / 2\n",
+ "        maze_predictions = predictions[batch_index_to_viz].detach().cpu().numpy()\n",
+ "        maze_targets = targets_viz[batch_index_to_viz].detach().cpu().numpy()\n",
+ "        maze_attention = attention_tracking[:, batch_index_to_viz] if attention_tracking.ndim > 2 else attention_tracking\n",
+ "\n",
+ "        # Generate the maze GIF\n",
+ "        make_maze_gif(\n",
+ "            maze_input,\n",
+ "            maze_predictions,\n",
+ "            maze_targets,\n",
+ "            maze_attention,\n",
+ "            log_dir\n",
+ "        )\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2c180995",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using neuron select type: random-pairing\n",
+ "Synch representation size action: 32\n",
+ "Synch representation size out: 32\n",
+ "Model parameters: 3,304,684\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processing frames for maze plotting: 100%|██████████| 50/50 [00:05<00:00, 8.52it/s]\n",
+ "Train Loss=1.458. Train Acc(step)=0.197. Test Loss=1.613. LR=0.000100. Where_certain=31.31+-11.91 (13<->49). Upto: 6.44.: 1%| | 565/100000 [03:16<8:28:46, 3.26it/s] "
+ ]
+ }
+ ],
+ "source": [
+ "# Set device\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "\n",
+ "# Define the model\n",
+ "model = CTM(\n",
+ "    iterations=50,              # Number of internal ticks (thinking steps)\n",
+ "    d_model=512,                # Model dimension\n",
+ "    d_input=128,                # Input dimension\n",
+ "    heads=8,                    # Attention heads\n",
+ "    n_synch_out=32,             # Output synchronization neurons\n",
+ "    n_synch_action=32,          # Action synchronization neurons\n",
+ "    synapse_depth=8,            # Synapse network depth\n",
+ "    memory_length=15,           # Memory length\n",
+ "    deep_nlms=True,             # Use deep memory\n",
+ "    memory_hidden_dims=16,      # Memory hidden dimensions\n",
+ "    backbone_type='resnet34-2', # Feature extractor\n",
+ "    out_dims=50 * 5,            # Output dimensions (route_length * 5 directions)\n",
+ "    prediction_reshaper=[50, 5],# Reshape to [route_length, directions]\n",
+ "    dropout=0.1,\n",
+ "    do_layernorm_nlm=False,\n",
+ "    positional_embedding_type='none'\n",
+ ").to(device)\n",
+ "\n",
+ "# Initialize model parameters with a dummy forward pass\n",
+ "sample_batch = next(iter(trainloader))\n",
+ "dummy_input = sample_batch[0][:1].to(device)\n",
+ "with torch.no_grad():\n",
+ "    _ = model(dummy_input)\n",
+ "\n",
+ "print(f'Model parameters: {sum(p.numel() for p in model.parameters()):,}')\n",
+ "\n",
+ "# Train the model\n",
+ "log_dir = './maze_training_logs'\n",
+ "train_losses, test_losses = train(\n",
+ "    model=model,\n",
+ "    trainloader=trainloader,\n",
+ "    testloader=testloader,\n",
+ "    device=device,\n",
+ "    training_iterations=100000,\n",
+ "    lr=1e-4,\n",
+ "    log_dir=log_dir\n",
+ ")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "atm",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
tasks/mazes/plotting.py CHANGED
@@ -88,7 +88,7 @@ def draw_path(x, route, valid_only=False, gt=False, cmap=None):
 
     return x
 
-def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location):
+def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location, verbose=True):
     """
     Expect inputs, predictions, targets as numpy arrays
     """
@@ -130,7 +130,7 @@ def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location):
     cmap_viridis = plt.get_cmap('viridis')
     step_linspace = np.linspace(0, 1, predictions.shape[-1]) # For sampling colours
     with tqdm(total=predictions.shape[-1], initial=0, leave=True, position=1, dynamic_ncols=True) as pbar:
-        pbar.set_description('Processing frames for maze plotting')
+        if verbose: pbar.set_description('Processing frames for maze plotting')
         for stepi in np.arange(0, predictions.shape[-1], 1):
            fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
            for ax in axes.values():
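
The `verbose` flag added to `make_maze_gif` defaults to `True`, so existing call sites (including the notebook above) keep their progress-bar description; pass `verbose=False` to silence it. A minimal usage sketch, assuming arrays shaped as in the notebook's visualization helper:

    # Quiet mode, e.g. when calling repeatedly inside a training loop
    make_maze_gif(maze_input, maze_predictions, maze_targets, maze_attention, log_dir, verbose=False)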