LukeDarlow committed on
Commit 68b32f4 · 0 Parent(s):

Welcome to the CTM. This is the first commit of the public repo. Enjoy!

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +18 -0
  2. README.md +134 -0
  3. data/custom_datasets.py +324 -0
  4. examples/01_mnist.ipynb +0 -0
  5. models/README.md +7 -0
  6. models/constants.py +10 -0
  7. models/ctm.py +552 -0
  8. models/ctm_qamnist.py +205 -0
  9. models/ctm_rl.py +192 -0
  10. models/ctm_sort.py +126 -0
  11. models/ff.py +75 -0
  12. models/lstm.py +244 -0
  13. models/lstm_qamnist.py +184 -0
  14. models/lstm_rl.py +96 -0
  15. models/modules.py +692 -0
  16. models/resnet.py +374 -0
  17. models/utils.py +122 -0
  18. requirements.txt +15 -0
  19. tasks/image_classification/README.md +29 -0
  20. tasks/image_classification/analysis/README.md +12 -0
  21. tasks/image_classification/analysis/run_imagenet_analysis.py +972 -0
  22. tasks/image_classification/imagenet_classes.py +1007 -0
  23. tasks/image_classification/plotting.py +494 -0
  24. tasks/image_classification/scripts/train_cifar10.sh +286 -0
  25. tasks/image_classification/scripts/train_imagenet.sh +38 -0
  26. tasks/image_classification/train.py +685 -0
  27. tasks/image_classification/train_distributed.py +799 -0
  28. tasks/mazes/README.md +10 -0
  29. tasks/mazes/analysis/README.md +10 -0
  30. tasks/mazes/analysis/run.py +407 -0
  31. tasks/mazes/plotting.py +198 -0
  32. tasks/mazes/scripts/train_ctm.sh +35 -0
  33. tasks/mazes/train.py +698 -0
  34. tasks/mazes/train_distributed.py +782 -0
  35. tasks/parity/README.md +16 -0
  36. tasks/parity/analysis/make_blog_gifs.py +263 -0
  37. tasks/parity/analysis/run.py +269 -0
  38. tasks/parity/plotting.py +896 -0
  39. tasks/parity/scripts/train_ctm_100_50.sh +46 -0
  40. tasks/parity/scripts/train_ctm_10_5.sh +46 -0
  41. tasks/parity/scripts/train_ctm_1_1.sh +46 -0
  42. tasks/parity/scripts/train_ctm_25_10.sh +46 -0
  43. tasks/parity/scripts/train_ctm_50_25.sh +46 -0
  44. tasks/parity/scripts/train_ctm_75_25.sh +46 -0
  45. tasks/parity/scripts/train_lstm_1.sh +39 -0
  46. tasks/parity/scripts/train_lstm_10.sh +39 -0
  47. tasks/parity/scripts/train_lstm_100.sh +39 -0
  48. tasks/parity/scripts/train_lstm_10_certain.sh +40 -0
  49. tasks/parity/scripts/train_lstm_25.sh +39 -0
  50. tasks/parity/scripts/train_lstm_25_certain.sh +40 -0
.gitignore ADDED
@@ -0,0 +1,18 @@
1
+ */__pycache__
2
+ logs
3
+ .DS_Store
4
+ *.png
5
+ *.pdf
6
+ *.gif
7
+ *.out
8
+ *.pyc
9
+ *.env
10
+ *.pt
11
+ *.mp4
12
+ .vscode*
13
+ *outputs*
14
+ data/*
15
+ !assets/*.gif
16
+ !data/custom_datasets.py
17
+ examples/*
18
+ !examples/01_mnist.ipynb
README.md ADDED
@@ -0,0 +1,134 @@
1
+ # 🕰️ The Continuous Thought Machine
2
+
3
+ 📚 [PAPER: Technical Report](https://pub.sakana.ai/ctm/paper) | 📝 [Blog](https://sakana.ai/ctm/) | 🕹️ [Interactive Website](https://pub.sakana.ai/ctm)
4
+
5
+ ![Activations](assets/activations.gif)
6
+
7
+ We present the Continuous Thought Machine (CTM), a model designed to unfold and then leverage neural activity as the underlying mechanism for observation and action. The CTM has two core innovations:
8
+
9
+ 1. Neuron-level temporal processing, where each neuron uses unique weight parameters to process a history of incoming signals, enabling fine-grained temporal dynamics.
10
+
11
+ 2. Neural synchronisation, employed as a direct latent representation for modulating data and producing outputs, thus directly encoding information in the timing of neural activity.
12
+
13
+ We demonstrate the CTM's strong performance and versatility across a range of challenging tasks, including ImageNet classification, solving 2D mazes, sorting, parity computation, question-answering, and RL tasks.
14
+
15
+ We provide all necessary code to reproduce our results and invite others to build upon and use CTMs in their own work.
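+
+ To give a feel for the interface, below is a minimal sketch of constructing a CTM and running a forward pass. The hyperparameters are purely illustrative (not the settings used in the paper; see the scripts under `tasks/` for those), and it assumes the repository root is on your Python path.
+
+ ```python
+ import torch
+ from models.ctm import ContinuousThoughtMachine
+
+ # Illustrative configuration only; NOT the paper's settings.
+ model = ContinuousThoughtMachine(
+     iterations=10,                  # number of internal 'thought' ticks (T in the paper)
+     d_model=128,                    # width of the internal latent space
+     d_input=64,                     # width of the attention/input projection
+     heads=4,
+     n_synch_out=32,                 # neurons used for the output synchronisation representation
+     n_synch_action=32,              # neurons used for the action (attention query) synchronisation
+     synapse_depth=1,                # 1 = single-layer synapse model; >1 uses the U-Net variant
+     memory_length=10,               # history length seen by each neuron-level model (M in the paper)
+     deep_nlms=True,
+     memory_hidden_dims=16,
+     do_layernorm_nlm=False,
+     backbone_type='none',           # feed features directly; resnet backbones are also supported
+     positional_embedding_type='none',
+     out_dims=10,                    # e.g. a 10-way classification head
+ )
+
+ x = torch.randn(2, 3, 32, 32)       # dummy batch
+ predictions, certainties, synchronisation = model(x)
+ print(predictions.shape)            # (batch, out_dims, iterations): one prediction per internal tick
+ ```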
16
+
17
+ ## Repo structure
18
+ ```
19
+ ├── tasks
20
+ │   ├── image_classification
21
+ │   │   ├── train.py # Training code for image classification (cifar, imagenet)
22
+ │   │   ├── imagenet_classes.py # Helper for imagenet class names
23
+ │   │   ├── plotting.py # Plotting utils specific to this task
24
+ │   │   └── analysis
25
+ │   │       ├── run_imagenet_analysis.py # ImageNet eval and visualisation code
26
+ │   │       └── outputs/ # Folder for outputs of analysis
27
+ │   ├── mazes
28
+ │   │   ├── train.py # Training code for solving 2D mazes (by way of a route; see paper)
29
+ │   │   ├── plotting.py # Plotting utils specific to this task
30
+ │   │   └── analysis
31
+ │   │       ├── run.py # Maze analysis code
32
+ │   │       └── outputs/ # Folder for outputs of analysis
33
+ │   ├── sort
34
+ │   │   ├── train.py # Training code for sorting
35
+ │   │   └── utils.py # Sort specific utils (e.g., CTC decode)
36
+ │   ├── parity
37
+ │   │   ├── train.py # Training code for parity task
38
+ │   │   ├── utils.py # Parity-specific helper functions
39
+ │   │   ├── plotting.py # Plotting utils specific to this task
40
+ │   │   ├── scripts/
41
+ │   │   │   └── *.sh # Training scripts for different experimental setups
42
+ │   │   └── analysis/
43
+ │   │       └── run.py # Entry point for parity analysis
44
+ │   ├── qamnist
45
+ │   │   ├── train.py # Training code for the QAMNIST task (question-and-answer MNIST)
46
+ │   │   ├── utils.py # QAMNIST-specific helper functions
47
+ │   │   ├── plotting.py # Plotting utils specific to this task
48
+ │   │   ├── scripts/
49
+ │   │   │   └── *.sh # Training scripts for different experimental setups
50
+ │   │   └── analysis/
51
+ │   │       └── run.py # Entry point for QAMNIST analysis
52
+ │   └── rl
53
+ │      ├── train.py # Training code for RL environments
54
+ │      ├── utils.py # RL-specific helper functions
55
+ │      ├── plotting.py # Plotting utils specific to this task
56
+ │      ├── envs.py # Custom RL environment wrappers
57
+ │      ├── scripts/
58
+ │      │   ├── 4rooms/
59
+ │      │   │   └── *.sh # Training scripts for MiniGrid-FourRooms-v0 environment
60
+ │      │   ├── acrobot/
61
+ │      │   │   └── *.sh # Training scripts for Acrobot-v1 environment
62
+ │      │   └── cartpole/
63
+ │      │       └── *.sh # Training scripts for CartPole-v1 environment
64
+ │      └── analysis/
65
+ │          └── run.py # Entry point for RL analysis
66
+ ├── data # This is where data will be saved and downloaded to
67
+ │   └── custom_datasets.py # Custom datasets (e.g., Mazes, Sort)
68
+ ├── models
69
+ │   ├── ctm.py # Main model code, used for: image classification, solving mazes, sort
70
+ │   ├── ctm_*.py # Other model code, standalone adjustments for other tasks
71
+ │   ├── ff.py # feed-forward (simple) baseline code (e.g., for image classification)
72
+ │   ├── lstm.py # LSTM baseline code (e.g., for image classification)
73
+ │   ├── lstm_*.py # Other baseline code, standalone adjustments for other tasks
74
+ │   ├── modules.py # Helper modules, including Neuron-level models and the Synapse UNET
75
+ │   ├── utils.py # Helper functions (e.g., synch decay)
76
+ │   └── resnet.py # Wrapper for ResNet featuriser
77
+ ├── utils
78
+ │   ├── housekeeping.py # Helper functions for keeping things neat
79
+ │   ├── losses.py # Loss functions for various tasks (mostly with reshaping stuff)
80
+ │   └── schedulers.py # Helper wrappers for learning rate schedulers
81
+ └── checkpoints
82
+    └── imagenet, mazes, ... # Checkpoint directories (see google drive link for files)
83
+
84
+ ```
85
+
86
+ ## Setup
87
+ To set up the environment using conda:
88
+
89
+ ```
90
+ conda create --name=ctm python=3.12
91
+ conda activate ctm
92
+ pip install -r requirements.txt
93
+ ```
94
+
95
+ If there are issues with PyTorch versions, the following can be run:
96
+ ```
97
+ pip uninstall torch
98
+ pip install torch --index-url https://download.pytorch.org/whl/cu121
99
+ ```
100
+
101
+ ## Model training
102
+ Each task has its own (set of) training code. See, for instance, [tasks/image_classification/train.py](tasks/image_classification/train.py). We have set it up like this to prioritise ease of use over clinical efficiency. This code is for researchers, and we hope sharing it in this way fosters collaboration and learning.
103
+
104
+ While we have provided reasonable defaults in the argparsers of each training setup, scripts to replicate the setups in the paper can typically be found in the accompanying script folders. If you simply want to dive in, run the following as a module (set up like this to make it easy to run many high-level training scripts from the top directory):
105
+
106
+ ```
107
+ python -m tasks.image_classification.train
108
+ ```
109
+ For debugging in VSCode, this configuration example might be helpful to you:
110
+ ```
111
+ {
112
+ "name": "Debug: train image classifier",
113
+ "type": "debugpy",
114
+ "request": "launch",
115
+ "module": "tasks.image_classification.train",
116
+ "console": "integratedTerminal",
117
+ "justMyCode": false
118
+ }
119
+ ```
120
+
121
+
122
+ ## Running analyses
123
+
124
+ We also provide analysis and plotting code to replicate many of the plots in our paper. See `tasks/.../analysis/*` for more details. We additionally provide some data (e.g., the mazes we generated for training) and checkpoints (see [here](#checkpoints-and-data)).
125
+
126
+
127
+ ## Checkpoints and data
128
+ You can download the data and checkpoints from here: https://drive.google.com/drive/folders/1f4N0ndIDrRvac5fUnWof33KWhvz8iqo_?usp=drive_link
129
+
130
+ Checkpoints go in the `checkpoints` folder. For instance, when properly populated, the checkpoints folder will have the maze checkpoint in `checkpoints/mazes/...`
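+
+ As a quick sanity check after downloading, something like the following can be used to inspect a checkpoint. The file name and the structure of the saved object here are assumptions; see the relevant training script under `tasks/` for exactly what gets saved.
+
+ ```python
+ import torch
+
+ # Hypothetical path; substitute whichever checkpoint file you downloaded.
+ ckpt = torch.load('checkpoints/mazes/checkpoint.pt', map_location='cpu')
+ print(type(ckpt))
+ if isinstance(ckpt, dict):
+     print(list(ckpt.keys()))  # e.g. a model state dict plus any training metadata that was saved
+ ```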
131
+
132
+
133
+
134
+
data/custom_datasets.py ADDED
@@ -0,0 +1,324 @@
1
+ import torch
2
+ from torchvision.datasets import ImageFolder
3
+ from torch.utils.data import Dataset
4
+ import random
5
+ import numpy as np
6
+ from tqdm.auto import tqdm
7
+ from PIL import Image
8
+ from datasets import load_dataset
9
+
10
+ class SortDataset(Dataset):
11
+ def __init__(self, N):
12
+ self.N = N
13
+ def __len__(self):
14
+ return 10000000
15
+ def __getitem__(self, idx):
16
+ data = torch.zeros(self.N).normal_()
17
+ ordering = torch.argsort(data)
18
+ inputs = data
19
+ return (inputs), (ordering)
20
+
21
+ class QAMNISTDataset(Dataset):
22
+ """A QAMNIST dataset that includes plus and minus operations on MNIST digits."""
23
+ def __init__(self, base_dataset, num_images, num_images_delta, num_repeats_per_input, num_operations, num_operations_delta):
24
+ self.base_dataset = base_dataset
25
+
26
+ self.num_images = num_images
27
+ self.num_images_delta = num_images_delta
28
+ self.num_images_range = self._calculate_num_images_range()
29
+
30
+ self.operators = ["+", "-"]
31
+ self.num_operations = num_operations
32
+ self.num_operations_delta = num_operations_delta
33
+ self.num_operations_range = self._calculate_num_operations_range()
34
+
35
+ self.num_repeats_per_input = num_repeats_per_input
36
+
37
+ self.current_num_digits = num_images
38
+ self.current_num_operations = num_operations
39
+
40
+ self.modulo_base = 10
41
+
42
+ self.output_range = [0, 9]
43
+
44
+ def _calculate_num_images_range(self):
45
+ min_val = self.num_images - self.num_images_delta
46
+ max_val = self.num_images + self.num_images_delta
47
+ assert min_val >= 1, f"Minimum number of images must be at least 1, got {min_val}"
48
+ return [min_val, max_val]
49
+
50
+ def _calculate_num_operations_range(self):
51
+ min_val = self.num_operations - self.num_operations_delta
52
+ max_val = self.num_operations + self.num_operations_delta
53
+ assert min_val >= 1, f"Minimum number of operations must be at least 1, got {min_val}"
54
+ return [min_val, max_val]
55
+
56
+ def set_num_digits(self, num_digits):
57
+ self.current_num_digits = num_digits
58
+
59
+ def set_num_operations(self, num_operations):
60
+ self.current_num_operations = num_operations
61
+
62
+ def _get_target_and_question(self, targets):
63
+ question = []
64
+ equations = []
65
+ num_digits = self.current_num_digits
66
+ num_operations = self.current_num_operations
67
+
68
+ # Select the initial digit
69
+ selection_idx = np.random.randint(num_digits)
70
+ first_digit = targets[selection_idx]
71
+ question.extend([selection_idx] * self.num_repeats_per_input)
72
+ # Set current_value to the initial digit (mod is applied in each operation)
73
+ current_value = first_digit % self.modulo_base
74
+
75
+ # For each operation, build an equation line
76
+ for _ in range(num_operations):
77
+ # Choose the operator ('+' or '-')
78
+ operator_idx = np.random.randint(len(self.operators))
79
+ operator = self.operators[operator_idx]
80
+ encoded_operator = -(operator_idx + 1) # -1 for '+', -2 for '-'
81
+ question.extend([encoded_operator] * self.num_repeats_per_input)
82
+
83
+ # Choose the next digit
84
+ selection_idx = np.random.randint(num_digits)
85
+ digit = targets[selection_idx]
86
+ question.extend([selection_idx] * self.num_repeats_per_input)
87
+
88
+ # Compute the new value with immediate modulo reduction
89
+ if operator == '+':
90
+ new_value = (current_value + digit) % self.modulo_base
91
+ else: # operator is '-'
92
+ new_value = (current_value - digit) % self.modulo_base
93
+
94
+ # Build the equation string for this step
95
+ equations.append(f"({current_value} {operator} {digit}) mod {self.modulo_base} = {new_value}")
96
+ # Update current value for the next operation
97
+ current_value = new_value
98
+
99
+ target = current_value
100
+ question_readable = "\n".join(equations)
101
+ return target, question, question_readable
102
+
103
+ def __len__(self):
104
+ return len(self.base_dataset)
105
+
106
+ def __getitem__(self, idx):
107
+ images, targets = [],[]
108
+ for _ in range(self.current_num_digits):
109
+ image, target = self.base_dataset[np.random.randint(self.__len__())]
110
+ images.append(image)
111
+ targets.append(target)
112
+
113
+ observations = torch.repeat_interleave(torch.stack(images, 0), repeats=self.num_repeats_per_input, dim=0)
114
+ target, question, question_readable = self._get_target_and_question(targets)
115
+ return observations, question, question_readable, target
116
+
117
+ class ImageNet(Dataset):
118
+ def __init__(self, which_split, transform):
119
+ """
120
+ Most simple form of the custom dataset structure.
121
+ Args:
122
+ base_dataset (Dataset): The base dataset to sample from.
123
+ N (int): The number of images to construct into an observable sequence.
124
+ R (int): number of repeats
125
+ operators (list): list of operators from which to sample
126
+ action to take on observations (str): can be 'global' to compute operator over full observations, or 'select_K', where K=integer.
127
+ """
128
+ dataset = load_dataset('imagenet-1k', split=which_split, trust_remote_code=True)
129
+
130
+ self.transform = transform
131
+ self.base_dataset = dataset
132
+
133
+ def __len__(self):
134
+ return len(self.base_dataset)
135
+
136
+ def __getitem__(self, idx):
137
+ data_item = self.base_dataset[idx]
138
+ image = self.transform(data_item['image'].convert('RGB'))
139
+ target = data_item['label']
140
+ return image, target
141
+
142
+ class MazeImageFolder(ImageFolder):
143
+ """
144
+ A custom dataset class that extends the ImageFolder class.
145
+
146
+ Args:
147
+ root (string): Root directory path.
148
+ transform (callable, optional): A function/transform that takes in
149
+ a sample and returns a transformed version.
150
+ E.g, ``transforms.RandomCrop`` for images.
151
+ target_transform (callable, optional): A function/transform that takes
152
+ in the target and transforms it.
153
+ loader (callable, optional): A function to load an image given its path.
154
+ is_valid_file (callable, optional): A function that takes path of an Image file
155
+ and check if the file is a valid file (used to check of corrupt files)
156
+
157
+ Attributes:
158
+ classes (list): List of the class names.
159
+ class_to_idx (dict): Dict with items (class_name, class_index).
160
+ imgs (list): List of (image path, class_index) tuples
161
+ """
162
+
163
+ def __init__(self, root, transform=None, target_transform=None,
164
+ loader=Image.open,
165
+ is_valid_file=None,
166
+ which_set='train',
167
+ augment_p=0.5,
168
+ maze_route_length=10,
169
+ trunc=False,
170
+ expand_range=True):
171
+ super(MazeImageFolder, self).__init__(root, transform, target_transform, loader, is_valid_file)
172
+ self.which_set = which_set
173
+ self.augment_p = augment_p
174
+ self.maze_route_length = maze_route_length
175
+ self.all_paths = {}
176
+ self.trunc = trunc
177
+ self.expand_range = expand_range
178
+
179
+ self._preload()
180
+ print('Solving all mazes...')
181
+ for index in range(len(self.preloaded_samples)):
182
+ path = self.get_solution(self.preloaded_samples[index])
183
+ self.all_paths[index] = path
184
+
185
+ def _preload(self):
186
+ preloaded_samples = []
187
+ with tqdm(total=self.__len__(), initial=0, leave=True, position=0, dynamic_ncols=True) as pbar:
188
+
189
+ for index in range(self.__len__()):
190
+ pbar.set_description('Loading mazes')
191
+ path, target = self.samples[index]
192
+ sample = self.loader(path)
193
+ sample = np.array(sample).astype(np.float32)/255
194
+ preloaded_samples.append(sample)
195
+ pbar.update(1)
196
+ if self.trunc and index == 999: break
197
+ self.preloaded_samples = preloaded_samples
198
+
199
+ def __len__(self):
200
+ if hasattr(self, 'preloaded_samples') and self.preloaded_samples is not None:
201
+ return len(self.preloaded_samples)
202
+ else:
203
+ return super().__len__()
204
+
205
+ def get_solution(self, x):
206
+ x = np.copy(x)
207
+ # Find start (red) and end (green) pixel coordinates
208
+ start_coords = np.argwhere((x == [1, 0, 0]).all(axis=2))
209
+ end_coords = np.argwhere((x == [0, 1, 0]).all(axis=2))
210
+
211
+ if len(start_coords) == 0 or len(end_coords) == 0:
212
+ print("Start or end point not found.")
213
+ return None
214
+
215
+ start_y, start_x = start_coords[0]
216
+ end_y, end_x = end_coords[0]
217
+
218
+ current_y, current_x = start_y, start_x
219
+ path = [4] * self.maze_route_length
220
+
221
+ pi = 0
222
+ while (current_y, current_x) != (end_y, end_x):
223
+ next_y, next_x = -1, -1 # Initialize to invalid coordinates
224
+ direction = -1 # Initialize to an invalid direction
225
+
226
+
227
+ # Check Up
228
+ if current_y > 0 and ((x[current_y - 1, current_x] == [0, 0, 1]).all() or (x[current_y - 1, current_x] == [0, 1, 0]).all()):
229
+ next_y, next_x = current_y - 1, current_x
230
+ direction = 0
231
+
232
+ # Check Down
233
+ elif current_y < x.shape[0] - 1 and ((x[current_y + 1, current_x] == [0, 0, 1]).all() or (x[current_y + 1, current_x] == [0, 1, 0]).all()):
234
+ next_y, next_x = current_y + 1, current_x
235
+ direction = 1
236
+
237
+ # Check Left
238
+ elif current_x > 0 and ((x[current_y, current_x - 1] == [0, 0, 1]).all() or (x[current_y, current_x - 1] == [0, 1, 0]).all()):
239
+ next_y, next_x = current_y, current_x - 1
240
+ direction = 2
241
+
242
+ # Check Right
243
+ elif current_x < x.shape[1] - 1 and ((x[current_y, current_x + 1] == [0, 0, 1]).all() or (x[current_y, current_x + 1] == [0, 1, 0]).all()):
244
+ next_y, next_x = current_y, current_x + 1
245
+ direction = 3
246
+
247
+
248
+ path[pi] = direction
249
+ pi += 1
250
+
251
+ x[current_y, current_x] = [255,255,255] # mark the current as white to avoid going in circles
252
+ current_y, current_x = next_y, next_x
253
+ if pi == len(path):
254
+ break
255
+
256
+ return np.array(path)
257
+
258
+ def __getitem__(self, index):
259
+ """
260
+ Args:
261
+ index (int): Index
262
+
263
+ Returns:
264
+ tuple: (sample, target) where target is class_index of the target class.
265
+ """
266
+
267
+ sample = np.copy(self.preloaded_samples[index])
268
+
269
+ path = np.copy(self.all_paths[index])
270
+
271
+ if self.which_set == 'train':
272
+ # Randomly rotate -90 or +90 degrees
273
+ if random.random() < self.augment_p:
274
+ which_rot = random.choice([-1, 1])
275
+ sample = np.rot90(sample, k=which_rot, axes=(0, 1))
276
+ for pi in range(len(path)):
277
+ if path[pi] == 0: path[pi] = 3 if which_rot == -1 else 2
278
+ elif path[pi] == 1: path[pi] = 2 if which_rot == -1 else 3
279
+ elif path[pi] == 2: path[pi] = 0 if which_rot == -1 else 1
280
+ elif path[pi] == 3: path[pi] = 1 if which_rot == -1 else 0
281
+
282
+
283
+ # Random horizontal flip
284
+ if random.random() < self.augment_p:
285
+ sample = np.fliplr(sample)
286
+ for pi in range(len(path)):
287
+ if path[pi] == 2: path[pi] = 3
288
+ elif path[pi] == 3: path[pi] = 2
289
+
290
+
291
+ # Random vertical flip
292
+ if random.random() < self.augment_p:
293
+ sample = np.flipud(sample)
294
+ for pi in range(len(path)):
295
+ if path[pi] == 0: path[pi] = 1
296
+ elif path[pi] == 1: path[pi] = 0
297
+
298
+ sample = torch.from_numpy(np.copy(sample)).permute(2,0,1)
299
+
300
+ blue_mask = (sample[0] == 0) & (sample[1] == 0) & (sample[2] == 1)
301
+
302
+ sample[:, blue_mask] = 1
303
+ target = path
304
+
305
+
306
+ if not self.expand_range:
307
+ return sample, target
308
+ return (sample*2)-1, (target)
309
+
310
+ class ParityDataset(Dataset):
311
+ def __init__(self, sequence_length=64, length=100000):
312
+ self.sequence_length = sequence_length
313
+ self.length = length
314
+
315
+ def __len__(self):
316
+ return self.length
317
+
318
+ def __getitem__(self, idx):
319
+ vector = 2 * torch.randint(0, 2, (self.sequence_length,)) - 1
320
+ vector = vector.float()
321
+ negatives = (vector == -1).to(torch.long)
322
+ cumsum = torch.cumsum(negatives, dim=0)
323
+ target = (cumsum % 2 != 0).to(torch.long)
324
+ return vector, target
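
For reference, a minimal sketch of how a couple of these datasets might be consumed (illustrative sizes only; the scripts under `tasks/` contain the settings actually used, and this assumes it is run from the repository root):

```python
from torch.utils.data import DataLoader
from data.custom_datasets import ParityDataset, SortDataset

# Illustrative sizes only.
parity_loader = DataLoader(ParityDataset(sequence_length=64, length=1000), batch_size=32)
vectors, targets = next(iter(parity_loader))   # vectors in {-1, +1}; targets are cumulative parities
print(vectors.shape, targets.shape)            # torch.Size([32, 64]) torch.Size([32, 64])

sort_loader = DataLoader(SortDataset(N=15), batch_size=32)
values, orderings = next(iter(sort_loader))    # orderings are the argsort of the sampled values
print(values.shape, orderings.shape)           # torch.Size([32, 15]) torch.Size([32, 15])
```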
examples/01_mnist.ipynb ADDED
The diff for this file is too large to render.
 
models/README.md ADDED
@@ -0,0 +1,7 @@
1
+ # Continuous Thought Machines
2
+ ## Models
3
+
4
+ This folder contains all model-related code.
5
+
6
+ Some notes for clarity:
7
+ 1. The resnet structure we used (see resnet.py) has a few minor changes that enable constraining the receptive field of the yielded features. We do this because we want the CTM (or baseline methods) to learn a process whereby they gather information. Neural networks trained with SGD will find the [path of least resistance](https://era.ed.ac.uk/handle/1842/39606), even if that path doesn't result in genuinely intelligent behaviour. Constraining the receptive field helps to mitigate this.
models/constants.py ADDED
@@ -0,0 +1,10 @@
1
+ VALID_NEURON_SELECT_TYPES = ['first-last', 'random', 'random-pairing']
2
+
3
+ VALID_BACKBONE_TYPES = [
4
+ f'resnet{depth}-{i}' for depth in [18, 34, 50, 101, 152] for i in range(1, 5)
5
+ ] + ['shallow-wide', 'parity_backbone']
6
+
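+ # For example, the comprehension above expands to strings such as 'resnet18-1', 'resnet34-2',
+ # 'resnet50-3' and 'resnet152-4', where the suffix selects which backbone stage (and hence
+ # feature width; see get_d_backbone in models/ctm.py) is used, in addition to the non-resnet
+ # options 'shallow-wide' and 'parity_backbone'.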
7
+ VALID_POSITIONAL_EMBEDDING_TYPES = [
8
+ 'learnable-fourier', 'multi-learnable-fourier',
9
+ 'custom-rotational', 'custom-rotational-1d'
10
+ ]
models/ctm.py ADDED
@@ -0,0 +1,552 @@
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+ import math
5
+
6
+ from models.modules import ParityBackbone, SynapseUNET, Squeeze, SuperLinear, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
7
+ from models.resnet import prepare_resnet_backbone
8
+ from models.utils import compute_normalized_entropy
9
+
10
+ from models.constants import (
11
+ VALID_NEURON_SELECT_TYPES,
12
+ VALID_BACKBONE_TYPES,
13
+ VALID_POSITIONAL_EMBEDDING_TYPES
14
+ )
15
+
16
+ class ContinuousThoughtMachine(nn.Module):
17
+ """
18
+ Continuous Thought Machine (CTM).
19
+
20
+ Technical report: TODO:LINK
21
+
22
+ Technical report (web version): TODO:LINK
23
+
24
+ Blog: TODO:LINK
25
+
26
+ Thought takes time and reasoning is a process.
27
+
28
+ The CTM consists of three main ideas:
29
+ 1. The use of internal recurrence, enabling a dimension over which a concept analogous to thought can occur.
30
+ 2. Neuron-level models, which compute post-activations by applying private (i.e., per-neuron) MLP
31
+ models to a history of incoming pre-activations.
32
+ 3. Synchronisation as representation, where the neural activity over time is tracked and used to compute how
33
+ pairs of neurons synchronise with one another over time. This measure of synchronisation is the representation
34
+ with which the CTM takes action and makes predictions.
35
+
36
+
37
+ Args:
38
+ iterations (int): Number of internal 'thought' ticks (T, in paper).
39
+ d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
40
+ NOTE: Note that this is NOT the representation used for action or prediction, but rather that which
41
+ is fully internal to the model and not directly connected to data.
42
+ d_input (int): Dimensionality of projected attention outputs or direct input features.
43
+ heads (int): Number of attention heads.
44
+ n_synch_out (int): Number of neurons used for output synchronisation (D_out, in paper).
45
+ n_synch_action (int): Number of neurons used for action/attention synchronisation (D_action, in paper).
46
+ synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
47
+ memory_length (int): History length for Neuron-Level Models (M, in paper).
48
+ deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
49
+ NOTE: we almost always use deep NLMs, but a linear NLM is faster.
50
+ memory_hidden_dims (int): Hidden dimension size for deep NLMs.
51
+ do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
52
+ NOTE: we never set this to true in the paper. If you set this to true you will get strange behaviour,
53
+ but you can potentially encourage more periodic behaviour in the dynamics. Untested; be careful.
54
+ backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
55
+ positional_embedding_type (str): Type of positional embedding for backbone features.
56
+ out_dims (int): Output dimension size.
57
+ NOTE: projected from synchronisation!
58
+ prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
59
+ NOTE: this is used to compute certainty and is needed when applying softmax for probabilities
60
+ dropout (float): Dropout rate.
61
+ neuron_select_type (str): Neuron selection strategy ('first-last', 'random', 'random-pairing').
62
+ NOTE: some of this is legacy from our experimentation, but all three strategies are valid and useful.
63
+ We delineate exactly which strategies we use per experiment in the paper.
64
+ - first-last: build a 'dense' sync matrix for output from the first D_out neurons and action from the
65
+ last D_action neurons. Flatten this matrix into the synchronisation representation.
66
+ This approach shares relationships for neurons and bottlenecks the gradients through them.
67
+ NOTE: the synchronisation size will be (D_out/action * (D_out/action + 1))/2
68
+ - random: randomly select D_out neurons for the 'i' side pairings, and also D_out for the 'j' side pairings,
69
+ also pairing those across densely, resulting in a bottleneck roughly 2x as wide.
70
+ NOTE: the synchronisation size will be (D_out/action * (D_out/action + 1))/2
71
+ - random-pairing (DEFAULT!): randomly select D_out neurons and pair these with another D_out neurons.
72
+ This results in much less bottlenecking and is the most up-to-date variant.
73
+ NOTE: the synchronisation size will be D_out in this case; better control.
74
+ n_random_pairing_self (int): Number of neurons to select for self-to-self synch when random-pairing is used.
75
+ NOTE: when using random-pairing, i-to-i (self) synchronisation is rare, meaning that 'recovering a
76
+ snapshot representation' (see paper) is difficult. This alleviates that.
77
+ NOTE: works fine when set to 0.
78
+ """
79
+
80
+ def __init__(self,
81
+ iterations,
82
+ d_model,
83
+ d_input,
84
+ heads,
85
+ n_synch_out,
86
+ n_synch_action,
87
+ synapse_depth,
88
+ memory_length,
89
+ deep_nlms,
90
+ memory_hidden_dims,
91
+ do_layernorm_nlm,
92
+ backbone_type,
93
+ positional_embedding_type,
94
+ out_dims,
95
+ prediction_reshaper=[-1],
96
+ dropout=0,
97
+ dropout_nlm=None,
98
+ neuron_select_type='random-pairing',
99
+ n_random_pairing_self=0,
100
+ ):
101
+ super(ContinuousThoughtMachine, self).__init__()
102
+
103
+ # --- Core Parameters ---
104
+ self.iterations = iterations
105
+ self.d_model = d_model
106
+ self.d_input = d_input
107
+ self.memory_length = memory_length
108
+ self.prediction_reshaper = prediction_reshaper
109
+ self.n_synch_out = n_synch_out
110
+ self.n_synch_action = n_synch_action
111
+ self.backbone_type = backbone_type
112
+ self.out_dims = out_dims
113
+ self.positional_embedding_type = positional_embedding_type
114
+ self.neuron_select_type = neuron_select_type
115
+ self.memory_length = memory_length
116
+ dropout_nlm = dropout if dropout_nlm is None else dropout_nlm
117
+
118
+ # --- Assertions ---
119
+ self.verify_args()
120
+
121
+ # --- Input Processing ---
122
+ d_backbone = self.get_d_backbone()
123
+ self.set_initial_rgb()
124
+ self.set_backbone()
125
+ self.positional_embedding = self.get_positional_embedding(d_backbone)
126
+ self.kv_proj = nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input)) if heads else None
127
+ self.q_proj = nn.LazyLinear(self.d_input) if heads else None
128
+ self.attention = nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True) if heads else None
129
+
130
+ # --- Core CTM Modules ---
131
+ self.synapses = self.get_synapses(synapse_depth, d_model, dropout)
132
+ self.trace_processor = self.get_neuron_level_models(deep_nlms, do_layernorm_nlm, memory_length, memory_hidden_dims, d_model, dropout_nlm)
133
+
134
+ # --- Start States ---
135
+ self.register_parameter('start_activated_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model)))))
136
+ self.register_parameter('start_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length)))))
137
+
138
+ # --- Synchronisation ---
139
+ self.neuron_select_type_out, self.neuron_select_type_action = self.get_neuron_select_type()
140
+ self.synch_representation_size_action = self.calculate_synch_representation_size(self.n_synch_action)
141
+ self.synch_representation_size_out = self.calculate_synch_representation_size(self.n_synch_out)
142
+
143
+ for synch_type, size in (('action', self.synch_representation_size_action), ('out', self.synch_representation_size_out)):
144
+ print(f"Synch representation size {synch_type}: {size}")
145
+ if self.synch_representation_size_action: # if not zero
146
+ self.set_synchronisation_parameters('action', self.n_synch_action, n_random_pairing_self)
147
+ self.set_synchronisation_parameters('out', self.n_synch_out, n_random_pairing_self)
148
+
149
+ # --- Output Processing ---
150
+ self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
151
+
152
+ # --- Core CTM Methods ---
153
+
154
+ def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, r, synch_type):
155
+ """
156
+ Computes synchronisation to be used as a vector representation.
157
+
158
+ A neuron has what we call a 'trace', which is a history (time series) that changes with internal
159
+ recurrence, i.e., it gets longer with every internal tick. There are pre-activation traces
160
+ that are used in the NLMs and post-activation traces that, in theory, are used in this method.
161
+
162
+ We define synchronisation between neurons i and j as the dot product between their respective
163
+ time series. Since there can be many internal ticks, this process can be quite compute heavy as it
164
+ involves many dot products that repeat computation at each step.
165
+
166
+ Therefore, in practice, we update the synchronisation based on the current post-activations,
167
+ which we call the 'activated state' here. This is possible because the inputs to synchronisation
168
+ are only updated recurrently at each step, meaning that there is a linear recurrence we can
169
+ leverage.
170
+
171
+ See Appendix TODO of the Technical Report (TODO:LINK) for the maths that enables this method.
172
+ """
173
+
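+ # Sketch of the recurrence implemented below: with a learned decay rate r in (0, 1],
+ # alpha_t = r * alpha_{t-1} + z_i(t) * z_j(t) and beta_t = r * beta_{t-1} + 1, so that
+ # synchronisation = alpha_t / sqrt(beta_t) is a decay-weighted dot product of the post-activation
+ # histories of neurons i and j, normalised by the square root of the effective history length.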
174
+ if synch_type == 'action': # Get action parameters
175
+ n_synch = self.n_synch_action
176
+ neuron_indices_left = self.action_neuron_indices_left
177
+ neuron_indices_right = self.action_neuron_indices_right
178
+ elif synch_type == 'out': # Get input parameters
179
+ n_synch = self.n_synch_out
180
+ neuron_indices_left = self.out_neuron_indices_left
181
+ neuron_indices_right = self.out_neuron_indices_right
182
+
183
+ if self.neuron_select_type in ('first-last', 'random'):
184
+ # For first-last and random, we compute the pairwise sync between all selected neurons
185
+ if self.neuron_select_type == 'first-last':
186
+ if synch_type == 'action': # Use last n_synch neurons for action
187
+ selected_left = selected_right = activated_state[:, -n_synch:]
188
+ elif synch_type == 'out': # Use first n_synch neurons for out
189
+ selected_left = selected_right = activated_state[:, :n_synch]
190
+ else: # Use the randomly selected neurons
191
+ selected_left = activated_state[:, neuron_indices_left]
192
+ selected_right = activated_state[:, neuron_indices_right]
193
+
194
+ # Compute outer product of selected neurons
195
+ outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1)
196
+ # Resulting matrix is symmetric, so we only need the upper triangle
197
+ i, j = torch.triu_indices(n_synch, n_synch)
198
+ pairwise_product = outer[:, i, j]
199
+
200
+ elif self.neuron_select_type == 'random-pairing':
201
+ # For random-pairing, we compute the sync between specific pairs of neurons
202
+ left = activated_state[:, neuron_indices_left]
203
+ right = activated_state[:, neuron_indices_right]
204
+ pairwise_product = left * right
205
+ else:
206
+ raise ValueError("Invalid neuron selection type")
207
+
208
+
209
+
210
+ # Compute synchronisation recurrently
211
+ if decay_alpha is None or decay_beta is None:
212
+ decay_alpha = pairwise_product
213
+ decay_beta = torch.ones_like(pairwise_product)
214
+ else:
215
+ decay_alpha = r * decay_alpha + pairwise_product
216
+ decay_beta = r * decay_beta + 1
217
+
218
+ synchronisation = decay_alpha / (torch.sqrt(decay_beta))
219
+ return synchronisation, decay_alpha, decay_beta
220
+
221
+ def compute_features(self, x):
222
+ """
223
+ Compute the key-value features from the input data using the backbone.
224
+ """
225
+ initial_rgb = self.initial_rgb(x)
226
+ self.kv_features = self.backbone(initial_rgb)
227
+ pos_emb = self.positional_embedding(self.kv_features)
228
+ combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
229
+ kv = self.kv_proj(combined_features)
230
+ return kv
231
+
232
+ def compute_certainty(self, current_prediction):
233
+ """
234
+ Compute the certainty of the current prediction.
235
+
236
+ We define certainty as 1 minus the normalised entropy.
237
+
238
+ For legacy reasons we stack that in a 2D vector as this can be used for optimisation later.
239
+ """
240
+ B = current_prediction.size(0)
241
+ reshaped_pred = current_prediction.reshape([B] + self.prediction_reshaper)
242
+ ne = compute_normalized_entropy(reshaped_pred)
243
+ current_certainty = torch.stack((ne, 1-ne), -1)
244
+ return current_certainty
245
+
246
+ # --- Setup Methods ---
247
+
248
+ def set_initial_rgb(self):
249
+ """
250
+ This is largely to accommodate training on greyscale images and is legacy, but it
251
+ doesn't hurt the model in any way that we can tell.
252
+ """
253
+ if 'resnet' in self.backbone_type:
254
+ self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
255
+ else:
256
+ self.initial_rgb = nn.Identity()
257
+
258
+ def get_d_backbone(self):
259
+ """
260
+ Get the dimensionality of the backbone output, to be used for positional embedding setup.
261
+
262
+ This is a little bit complicated for resnets, but the logic should be easy enough to read below.
263
+ """
264
+ if self.backbone_type == 'shallow-wide':
265
+ return 2048
266
+ elif self.backbone_type == 'parity_backbone':
267
+ return self.d_input
268
+ elif 'resnet' in self.backbone_type:
269
+ if '18' in self.backbone_type or '34' in self.backbone_type:
270
+ if self.backbone_type.split('-')[1]=='1': return 64
271
+ elif self.backbone_type.split('-')[1]=='2': return 128
272
+ elif self.backbone_type.split('-')[1]=='3': return 256
273
+ elif self.backbone_type.split('-')[1]=='4': return 512
274
+ else:
275
+ raise NotImplementedError
276
+ else:
277
+ if self.backbone_type.split('-')[1]=='1': return 256
278
+ elif self.backbone_type.split('-')[1]=='2': return 512
279
+ elif self.backbone_type.split('-')[1]=='3': return 1024
280
+ elif self.backbone_type.split('-')[1]=='4': return 2048
281
+ else:
282
+ raise NotImplementedError
283
+ elif self.backbone_type == 'none':
284
+ return None
285
+ else:
286
+ raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
287
+
288
+ def set_backbone(self):
289
+ """
290
+ Set the backbone module based on the specified type.
291
+ """
292
+ if self.backbone_type == 'shallow-wide':
293
+ self.backbone = ShallowWide()
294
+ elif self.backbone_type == 'parity_backbone':
295
+ d_backbone = self.get_d_backbone()
296
+ self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
297
+ elif 'resnet' in self.backbone_type:
298
+ self.backbone = prepare_resnet_backbone(self.backbone_type)
299
+ elif self.backbone_type == 'none':
300
+ self.backbone = nn.Identity()
301
+ else:
302
+ raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
303
+
304
+ def get_positional_embedding(self, d_backbone):
305
+ """
306
+ Get the positional embedding module.
307
+
308
+ For ImageNet and mazes we used NO positional embedding, and largely don't think
309
+ that it is necessary as the CTM can build up its own internal world model when
310
+ observing.
311
+
312
+ LearnableFourierPositionalEncoding:
313
+ Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
314
+ Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
315
+ Provides positional information for 2D feature maps.
316
+
317
+ (MultiLearnableFourierPositionalEncoding uses multiple feature scales)
318
+
319
+ CustomRotationalEmbedding:
320
+ Simple sinusoidal embedding to encourage interpretability
321
+ """
322
+ if self.positional_embedding_type == 'learnable-fourier':
323
+ return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
324
+ elif self.positional_embedding_type == 'multi-learnable-fourier':
325
+ return MultiLearnableFourierPositionalEncoding(d_backbone)
326
+ elif self.positional_embedding_type == 'custom-rotational':
327
+ return CustomRotationalEmbedding(d_backbone)
328
+ elif self.positional_embedding_type == 'custom-rotational-1d':
329
+ return CustomRotationalEmbedding1D(d_backbone)
330
+ elif self.positional_embedding_type == 'none':
331
+ return lambda x: 0 # Default no-op
332
+ else:
333
+ raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
334
+
335
+ def get_neuron_level_models(self, deep_nlms, do_layernorm_nlm, memory_length, memory_hidden_dims, d_model, dropout):
336
+ """
337
+ Neuron level models are one of the core innovations of the CTM. They apply separate MLPs/linears to
338
+ each neuron.
339
+ NOTE: the name 'SuperLinear' is largely legacy, but its purpose is to apply separate linear layers
340
+ per neuron. It is sort of a 'grouped linear' function, where the group size is equal to 1.
341
+ One could make the group size bigger and use fewer parameters, but that is future work.
342
+
343
+ NOTE: We used GLU() nonlinearities because they worked well in practice.
344
+ """
345
+ if deep_nlms:
346
+ return nn.Sequential(
347
+ nn.Sequential(
348
+ SuperLinear(in_dims=memory_length, out_dims=2 * memory_hidden_dims, N=d_model,
349
+ do_norm=do_layernorm_nlm, dropout=dropout),
350
+ nn.GLU(),
351
+ SuperLinear(in_dims=memory_hidden_dims, out_dims=2, N=d_model,
352
+ do_norm=do_layernorm_nlm, dropout=dropout),
353
+ nn.GLU(),
354
+ Squeeze(-1)
355
+ )
356
+ )
357
+ else:
358
+ return nn.Sequential(
359
+ nn.Sequential(
360
+ SuperLinear(in_dims=memory_length, out_dims=2, N=d_model,
361
+ do_norm=do_layernorm_nlm, dropout=dropout),
362
+ nn.GLU(),
363
+ Squeeze(-1)
364
+ )
365
+ )
366
+
367
+ def get_synapses(self, synapse_depth, d_model, dropout):
368
+ """
369
+ The synapse model is the recurrent model in the CTM. Its purpose is to share information
370
+ across neurons. If using a depth of 1, this is just a simple single layer with a nonlinearity and layernorm.
371
+ For deeper synapse models we use a U-NET structure with many skip connections. In practice this performs
372
+ better as it enables multi-level information mixing.
373
+
374
+ The intuition with having a deep UNET model for synapses is that the action of synaptic connections is
375
+ not necessarily a linear one, and that approximating a synapse 'update' step in the brain is non-trivial.
376
+ Hence, we set it up so that the CTM can learn some complex internal rule instead of trying to approximate
377
+ it ourselves.
378
+ """
379
+ if synapse_depth == 1:
380
+ return nn.Sequential(
381
+ nn.Dropout(dropout),
382
+ nn.LazyLinear(d_model * 2),
383
+ nn.GLU(),
384
+ nn.LayerNorm(d_model)
385
+ )
386
+ else:
387
+ return SynapseUNET(d_model, synapse_depth, 16, dropout) # hard-coded minimum width of 16; future work TODO.
388
+
389
+ def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
390
+ """
391
+ 1. Set the buffers for selecting neurons so that these indices are saved into the model state_dict.
392
+ 2. Set the parameters for learnable exponential decay when computing synchronisation between all
393
+ neurons.
394
+ """
395
+ assert synch_type in ('out', 'action'), f"Invalid synch_type: {synch_type}"
396
+ left, right = self.initialize_left_right_neurons(synch_type, self.d_model, n_synch, n_random_pairing_self)
397
+ synch_representation_size = self.synch_representation_size_action if synch_type == 'action' else self.synch_representation_size_out
398
+ self.register_buffer(f'{synch_type}_neuron_indices_left', left)
399
+ self.register_buffer(f'{synch_type}_neuron_indices_right', right)
400
+ self.register_parameter(f'decay_params_{synch_type}', nn.Parameter(torch.zeros(synch_representation_size), requires_grad=True))
401
+
402
+ def initialize_left_right_neurons(self, synch_type, d_model, n_synch, n_random_pairing_self=0):
403
+ """
404
+ Initialize the left and right neuron indices based on the neuron selection type.
405
+ This complexity is owing to legacy experiments, but we maintain that these types of
406
+ neuron selection are interesting to experiment with.
407
+ """
408
+ if self.neuron_select_type=='first-last':
409
+ if synch_type == 'out':
410
+ neuron_indices_left = neuron_indices_right = torch.arange(0, n_synch)
411
+ elif synch_type == 'action':
412
+ neuron_indices_left = neuron_indices_right = torch.arange(d_model-n_synch, d_model)
413
+
414
+ elif self.neuron_select_type=='random':
415
+ neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
416
+ neuron_indices_right = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
417
+
418
+ elif self.neuron_select_type=='random-pairing':
419
+ assert n_synch > n_random_pairing_self, f"Need at least {n_random_pairing_self} pairs for {self.neuron_select_type}"
420
+ neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
421
+ neuron_indices_right = torch.concatenate((neuron_indices_left[:n_random_pairing_self], torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch-n_random_pairing_self))))
422
+
423
+ device = self.start_activated_state.device
424
+ return neuron_indices_left.to(device), neuron_indices_right.to(device)
425
+
426
+ def get_neuron_select_type(self):
427
+ """
428
+ Another helper method to accommodate our legacy neuron selection types.
429
+ TODO: additional experimentation and possible removal of 'first-last' and 'random'
430
+ """
431
+ print(f"Using neuron select type: {self.neuron_select_type}")
432
+ if self.neuron_select_type == 'first-last':
433
+ neuron_select_type_out, neuron_select_type_action = 'first', 'last'
434
+ elif self.neuron_select_type in ('random', 'random-pairing'):
435
+ neuron_select_type_out = neuron_select_type_action = self.neuron_select_type
436
+ else:
437
+ raise ValueError(f"Invalid neuron selection type: {self.neuron_select_type}")
438
+ return neuron_select_type_out, neuron_select_type_action
439
+
440
+ # --- Utility Methods ---
441
+
442
+ def verify_args(self):
443
+ """
444
+ Verify the validity of the input arguments to ensure consistent behaviour.
445
+ Specifically, when selecting neurons for synchronisation using 'first-last' or 'random',
446
+ one needs the right number of neurons.
447
+ """
448
+ assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
449
+ f"Invalid neuron selection type: {self.neuron_select_type}"
450
+
451
+ assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
452
+ f"Invalid backbone_type: {self.backbone_type}"
453
+
454
+ assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
455
+ f"Invalid positional_embedding_type: {self.positional_embedding_type}"
456
+
457
+ if self.neuron_select_type == 'first-last':
458
+ assert self.d_model >= (self.n_synch_out + self.n_synch_action), \
459
+ "d_model must be >= n_synch_out + n_synch_action for neuron subsets"
460
+
461
+ if self.backbone_type=='none' and self.positional_embedding_type!='none':
462
+ raise AssertionError("There should be no positional embedding if there is no backbone.")
463
+
464
+ def calculate_synch_representation_size(self, n_synch):
465
+ """
466
+ Calculate the size of the synchronisation representation based on neuron selection type.
467
+ """
468
+ if self.neuron_select_type == 'random-pairing':
469
+ synch_representation_size = n_synch
470
+ elif self.neuron_select_type in ('first-last', 'random'):
471
+ synch_representation_size = (n_synch * (n_synch + 1)) // 2
472
+ else:
473
+ raise ValueError(f"Invalid neuron selection type: {self.neuron_select_type}")
474
+ return synch_representation_size
475
+
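+ # Example: with neuron_select_type='random-pairing' and n_synch=512, the synchronisation
+ # representation has 512 dimensions, whereas 'first-last' or 'random' would give
+ # 512 * 513 / 2 = 131,328 dimensions.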
476
+
477
+
478
+
479
+ def forward(self, x, track=False):
480
+ B = x.size(0)
481
+ device = x.device
482
+
483
+ # --- Tracking Initialization ---
484
+ pre_activations_tracking = []
485
+ post_activations_tracking = []
486
+ synch_out_tracking = []
487
+ synch_action_tracking = []
488
+ attention_tracking = []
489
+
490
+ # --- Featurise Input Data ---
491
+ kv = self.compute_features(x)
492
+
493
+ # --- Initialise Recurrent State ---
494
+ state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
495
+ activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
496
+
497
+ # --- Prepare Storage for Outputs per Iteration ---
498
+ predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=torch.float32)
499
+ certainties = torch.empty(B, 2, self.iterations, device=device, dtype=torch.float32)
500
+
501
+ # --- Initialise Recurrent Synch Values ---
502
+ decay_alpha_action, decay_beta_action = None, None
503
+ r_action, r_out = torch.exp(-torch.clamp(self.decay_params_action, 0, 15)).unsqueeze(0).repeat(B, 1), torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
504
+ _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
505
+ # Compute learned weighting for synchronisation
506
+
507
+
508
+ # --- Recurrent Loop ---
509
+ for stepi in range(self.iterations):
510
+
511
+ # --- Calculate Synchronisation for Input Data Interaction ---
512
+ synchronisation_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
513
+
514
+ # --- Interact with Data via Attention ---
515
+ q = self.q_proj(synchronisation_action).unsqueeze(1)
516
+ attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
517
+ attn_out = attn_out.squeeze(1)
518
+ pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
519
+
520
+ # --- Apply Synapses ---
521
+ state = self.synapses(pre_synapse_input)
522
+ # The 'state_trace' is the history of incoming pre-activations
523
+ state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
524
+
525
+ # --- Apply Neuron-Level Models ---
526
+ activated_state = self.trace_processor(state_trace)
527
+ # One would also keep an 'activated_state_trace' as the history of outgoing post-activations
528
+ # BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
529
+ # done using only the current activated state (see compute_synchronisation method for explanation)
530
+
531
+ # --- Calculate Synchronisation for Output Predictions ---
532
+ synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
533
+
534
+ # --- Get Predictions and Certainties ---
535
+ current_prediction = self.output_projector(synchronisation_out)
536
+ current_certainty = self.compute_certainty(current_prediction)
537
+
538
+ predictions[..., stepi] = current_prediction
539
+ certainties[..., stepi] = current_certainty
540
+
541
+ # --- Tracking ---
542
+ if track:
543
+ pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
544
+ post_activations_tracking.append(activated_state.detach().cpu().numpy())
545
+ attention_tracking.append(attn_weights.detach().cpu().numpy())
546
+ synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
547
+ synch_action_tracking.append(synchronisation_action.detach().cpu().numpy())
548
+
549
+ # --- Return Values ---
550
+ if track:
551
+ return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
552
+ return predictions, certainties, synchronisation_out
models/ctm_qamnist.py ADDED
@@ -0,0 +1,205 @@
1
+ import torch
2
+ import numpy as np
3
+ from models.ctm import ContinuousThoughtMachine
4
+ from models.modules import MNISTBackbone, QAMNISTIndexEmbeddings, QAMNISTOperatorEmbeddings
5
+
6
+ class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
7
+ def __init__(self,
8
+ iterations,
9
+ d_model,
10
+ d_input,
11
+ heads,
12
+ n_synch_out,
13
+ n_synch_action,
14
+ synapse_depth,
15
+ memory_length,
16
+ deep_nlms,
17
+ memory_hidden_dims,
18
+ do_layernorm_nlm,
19
+ out_dims,
20
+ iterations_per_digit,
21
+ iterations_per_question_part,
22
+ iterations_for_answering,
23
+ prediction_reshaper=[-1],
24
+ dropout=0,
25
+ neuron_select_type='first-last',
26
+ n_random_pairing_self=256
27
+ ):
28
+ super().__init__(
29
+ iterations=iterations,
30
+ d_model=d_model,
31
+ d_input=d_input,
32
+ heads=heads,
33
+ n_synch_out=n_synch_out,
34
+ n_synch_action=n_synch_action,
35
+ synapse_depth=synapse_depth,
36
+ memory_length=memory_length,
37
+ deep_nlms=deep_nlms,
38
+ memory_hidden_dims=memory_hidden_dims,
39
+ do_layernorm_nlm=do_layernorm_nlm,
40
+ out_dims=out_dims,
41
+ prediction_reshaper=prediction_reshaper,
42
+ dropout=dropout,
43
+ neuron_select_type=neuron_select_type,
44
+ n_random_pairing_self=n_random_pairing_self,
45
+ backbone_type='none',
46
+ positional_embedding_type='none',
47
+ )
48
+
49
+ # --- Core Parameters ---
50
+ self.iterations_per_digit = iterations_per_digit
51
+ self.iterations_per_question_part = iterations_per_question_part
52
+ self.iterations_for_answering = iterations_for_answering
53
+
54
+ # --- Setup Methods ---
55
+
56
+ def set_initial_rgb(self):
57
+ """Set the initial RGB values for the backbone."""
58
+ return None
59
+
60
+ def get_d_backbone(self):
61
+ """Get the dimensionality of the backbone output."""
62
+ return self.d_input
63
+
64
+ def set_backbone(self):
65
+ """Set the backbone module based on the specified type."""
66
+ self.backbone_digit = MNISTBackbone(self.d_input)
67
+ self.index_backbone = QAMNISTIndexEmbeddings(50, self.d_input)
68
+ self.operator_backbone = QAMNISTOperatorEmbeddings(2, self.d_input)
69
+ pass
70
+
71
+ # --- Utility Methods ---
72
+
73
+ def determine_step_type(self, total_iterations_for_digits, total_iterations_for_question, stepi: int):
74
+ """Determine whether the current step is for digits, questions, or answers."""
75
+ is_digit_step = stepi < total_iterations_for_digits
76
+ is_question_step = total_iterations_for_digits <= stepi < total_iterations_for_digits + total_iterations_for_question
77
+ is_answer_step = stepi >= total_iterations_for_digits + total_iterations_for_question
78
+ return is_digit_step, is_question_step, is_answer_step
79
+
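+ # Example (illustrative numbers): with 30 digit ticks and 20 question ticks, steps 0-29 are
+ # digit steps, steps 30-49 are question steps (alternating index and operator parts in blocks
+ # of iterations_per_question_part), and all remaining steps are answer steps.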
80
+ def determine_index_operator_step_type(self, total_iterations_for_digits, stepi: int):
81
+ """Determine whether the current step is for index or operator."""
82
+ step_within_questions = stepi - total_iterations_for_digits
83
+ if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
84
+ is_index_step = True
85
+ is_operator_step = False
86
+ else:
87
+ is_index_step = False
88
+ is_operator_step = True
89
+ return is_index_step, is_operator_step
90
+
91
+ def get_kv_for_step(self, total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input=None, prev_kv=None):
92
+ """Get the key-value for the current step."""
93
+ is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
94
+
95
+ if is_digit_step:
96
+ current_input = x[:, stepi]
97
+ if prev_input is not None and torch.equal(current_input, prev_input):
98
+ return prev_kv, prev_input
99
+ kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
100
+
101
+ elif is_question_step:
102
+ offset = stepi - total_iterations_for_digits
103
+ current_input = z[:, offset]
104
+ if prev_input is not None and torch.equal(current_input, prev_input):
105
+ return prev_kv, prev_input
106
+ is_index_step, is_operator_step = self.determine_index_operator_step_type(total_iterations_for_digits, stepi)
107
+ if is_index_step:
108
+ kv = self.index_backbone(current_input)
109
+ elif is_operator_step:
110
+ kv = self.operator_backbone(current_input)
111
+ else:
112
+ raise ValueError("Invalid step type for question processing.")
113
+
114
+ elif is_answer_step:
115
+ current_input = None
116
+ kv = torch.zeros((x.size(0), self.d_input), device=x.device)
117
+
118
+ else:
119
+ raise ValueError("Invalid step type.")
120
+
121
+ return kv, current_input
122
+
123
+
124
+
125
+
126
+ def forward(self, x, z, track=False):
127
+ B = x.size(0)
128
+ device = x.device
129
+
130
+ # --- Tracking Initialization ---
131
+ pre_activations_tracking = []
132
+ post_activations_tracking = []
133
+ attention_tracking = []
134
+ embedding_tracking = []
135
+
136
+ total_iterations_for_digits = x.size(1)
137
+ total_iterations_for_question = z.size(1)
138
+ total_iterations = total_iterations_for_digits + total_iterations_for_question + self.iterations_for_answering
139
+
140
+ # --- Initialise Recurrent State ---
141
+ state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
142
+ activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
143
+
144
+ # --- Storage for outputs per iteration ---
145
+ predictions = torch.empty(B, self.out_dims, total_iterations, device=device, dtype=x.dtype)
146
+ certainties = torch.empty(B, 2, total_iterations, device=device, dtype=x.dtype)
147
+
148
+ # --- Initialise Recurrent Synch Values ---
149
+ decay_alpha_action, decay_beta_action = None, None
150
+ r_action, r_out = torch.exp(-torch.clamp(self.decay_params_action, 0, 15)).unsqueeze(0).repeat(B, 1), torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
151
+ _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
152
+
153
+ prev_input = None
154
+ prev_kv = None
155
+
156
+ # --- Recurrent Loop ---
157
+ for stepi in range(total_iterations):
158
+ is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
159
+
160
+ kv, prev_input = self.get_kv_for_step(total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input, prev_kv)
161
+ prev_kv = kv
162
+
163
+ synchronization_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
164
+
165
+ # --- Interact with Data via Attention ---
166
+ attn_weights = None
167
+ if is_digit_step:
168
+ q = self.q_proj(synchronization_action).unsqueeze(1)
169
+ attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
170
+ attn_out = attn_out.squeeze(1)
171
+ pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
172
+ else:
173
+ kv = kv.squeeze(1)
174
+ pre_synapse_input = torch.concatenate((kv, activated_state), dim=-1)
175
+
176
+ # --- Apply Synapses ---
177
+ state = self.synapses(pre_synapse_input)
178
+ state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
179
+
180
+ # --- Apply NLMs ---
181
+ activated_state = self.trace_processor(state_trace)
182
+
183
+ # --- Calculate Synchronisation for Output Predictions ---
184
+ synchronization_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
185
+
186
+ # --- Get Predictions and Certainties ---
187
+ current_prediction = self.output_projector(synchronization_out)
188
+ current_certainty = self.compute_certainty(current_prediction)
189
+
190
+ predictions[..., stepi] = current_prediction
191
+ certainties[..., stepi] = current_certainty
192
+
193
+ # --- Tracking ---
194
+ if track:
195
+ pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
196
+ post_activations_tracking.append(activated_state.detach().cpu().numpy())
197
+ if attn_weights is not None:
198
+ attention_tracking.append(attn_weights.detach().cpu().numpy())
199
+ if is_question_step:
200
+ embedding_tracking.append(kv.detach().cpu().numpy())
201
+
202
+ # --- Return Values ---
203
+ if track:
204
+ return predictions, certainties, synchronization_out, np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
205
+ return predictions, certainties, synchronization_out
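The digit, question and answer phases above are selected purely from the step index. As a standalone illustration of that scheduling (the totals below are made-up numbers; in the model they come from `x.size(1)`, `z.size(1)` and `self.iterations_for_answering`):

```python
# Sketch of the step scheduling mirrored from determine_step_type above.
total_iterations_for_digits = 30      # in the model: x.size(1)
total_iterations_for_question = 20    # in the model: z.size(1)
iterations_for_answering = 10         # in the model: self.iterations_for_answering

total_iterations = (total_iterations_for_digits
                    + total_iterations_for_question
                    + iterations_for_answering)

for stepi in range(total_iterations):
    is_digit_step = stepi < total_iterations_for_digits
    is_question_step = (total_iterations_for_digits <= stepi
                        < total_iterations_for_digits + total_iterations_for_question)
    is_answer_step = stepi >= total_iterations_for_digits + total_iterations_for_question
    # Digit steps attend over MNIST features, question steps consume index/operator
    # embeddings, and answer steps receive a zero key-value input.
```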
models/ctm_rl.py ADDED
@@ -0,0 +1,192 @@
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import math
5
+ from models.ctm import ContinuousThoughtMachine
6
+ from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET
7
+ from models.utils import compute_decay
8
+ from models.constants import VALID_NEURON_SELECT_TYPES
9
+
10
+ class ContinuousThoughtMachineRL(ContinuousThoughtMachine):
11
+ def __init__(self,
12
+ iterations,
13
+ d_model,
14
+ d_input,
15
+ n_synch_out,
16
+ synapse_depth,
17
+ memory_length,
18
+ deep_nlms,
19
+ memory_hidden_dims,
20
+ do_layernorm_nlm,
21
+ backbone_type,
22
+ prediction_reshaper=[-1],
23
+ dropout=0,
24
+ neuron_select_type='first-last',
25
+ ):
26
+ super().__init__(
27
+ iterations=iterations,
28
+ d_model=d_model,
29
+ d_input=d_input,
30
+ heads=0, # Setting heads to 0 makes the parent skip attention (attention/q_proj/kv_proj are set to None)
31
+ n_synch_out=n_synch_out,
32
+ n_synch_action=0,
33
+ synapse_depth=synapse_depth,
34
+ memory_length=memory_length,
35
+ deep_nlms=deep_nlms,
36
+ memory_hidden_dims=memory_hidden_dims,
37
+ do_layernorm_nlm=do_layernorm_nlm,
38
+ out_dims=0,
39
+ prediction_reshaper=prediction_reshaper,
40
+ dropout=dropout,
41
+ neuron_select_type=neuron_select_type,
42
+ backbone_type=backbone_type,
43
+ n_random_pairing_self=0,
44
+ positional_embedding_type='none',
45
+ )
46
+
47
+ # --- Use a minimal CTM w/out input (action) synch ---
48
+ self.neuron_select_type_action = None
49
+ self.synch_representation_size_action = None
50
+
51
+ # --- Start dynamics with a learned activated state trace ---
52
+ self.register_parameter('start_activated_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length))), requires_grad=True))
53
+ self.start_activated_state = None
54
+
55
+ self.register_buffer('diagonal_mask_out', torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool)))
56
+
57
+ self.attention = None # Should already be None because super(... heads=0... )
58
+ self.q_proj = None # Should already be None because super(... heads=0... )
59
+ self.kv_proj = None # Should already be None because super(... heads=0... )
60
+ self.output_projector = None
61
+
62
+ # --- Core CTM Methods ---
63
+
64
+ def compute_synchronisation(self, activated_state_trace):
65
+ """Compute the synchronisation between neurons."""
66
+ assert self.neuron_select_type == "first-last", "only first-last neuron selection is supported here"
67
+ # For RL tasks we track a sliding window of activations from which we compute synchronisation
68
+ S = activated_state_trace.permute(0, 2, 1)
69
+ diagonal_mask = self.diagonal_mask_out.to(S.device)
70
+ decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4))
71
+ synchronisation = ((decay.unsqueeze(0) *(S[:,:,-self.n_synch_out:].unsqueeze(-1) * S[:,:,-self.n_synch_out:].unsqueeze(-2))[:,:,diagonal_mask]).sum(1))/torch.sqrt(decay.unsqueeze(0).sum(1,))
72
+ return synchronisation
73
+
74
+ # --- Setup Methods ---
75
+
76
+ def set_initial_rgb(self):
77
+ """Set the initial RGB values for the backbone."""
78
+ return None
79
+
80
+ def get_d_backbone(self):
81
+ """Get the dimensionality of the backbone output."""
82
+ return self.d_input
83
+
84
+ def set_backbone(self):
85
+ """Set the backbone module based on the specified type."""
86
+ if self.backbone_type == 'navigation-backbone':
87
+ self.backbone = MiniGridBackbone(self.d_input)
88
+ elif self.backbone_type == 'classic-control-backbone':
89
+ self.backbone = ClassicControlBackbone(self.d_input)
90
+ else:
91
+ raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
92
+ pass
93
+
94
+ def get_positional_embedding(self, d_backbone):
95
+ """Get the positional embedding module."""
96
+ return None
97
+
98
+
99
+ def get_synapses(self, synapse_depth, d_model, dropout):
100
+ """
101
+ Get the synapse module.
102
+
103
+ We found in our early experimentation that a single Linear, GLU and LayerNorm block performed worse than two blocks.
104
+ For that reason we set the default synapse depth to two blocks.
105
+
106
+ TODO: This is legacy and needs further experimentation to iron out.
107
+ """
108
+ if synapse_depth == 1:
109
+ return nn.Sequential(
110
+ nn.Dropout(dropout),
111
+ nn.LazyLinear(d_model*2),
112
+ nn.GLU(),
113
+ nn.LayerNorm(d_model),
114
+ nn.LazyLinear(d_model*2),
115
+ nn.GLU(),
116
+ nn.LayerNorm(d_model)
117
+ )
118
+ else:
119
+ return SynapseUNET(d_model, synapse_depth, 16, dropout)
120
+
121
+ def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
122
+ """Set the parameters for the synchronisation of neurons."""
123
+ if synch_type == 'action':
124
+ pass
125
+ elif synch_type == 'out':
126
+ left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self)
127
+ self.register_buffer(f'out_neuron_indices_left', left)
128
+ self.register_buffer(f'out_neuron_indices_right', right)
129
+ self.register_parameter(f'decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True))
130
+ pass
131
+ else:
132
+ raise ValueError(f"Invalid synch_type: {synch_type}")
133
+
134
+ # --- Utility Methods ---
135
+
136
+ def verify_args(self):
137
+ """Verify the validity of the input arguments."""
138
+ assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
139
+ f"Invalid neuron selection type: {self.neuron_select_type}"
140
+ assert self.neuron_select_type != 'random-pairing', \
141
+ f"Random pairing is not supported for RL."
142
+ assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
143
+ f"Invalid backbone_type: {self.backbone_type}"
144
+ assert self.d_model >= (self.n_synch_out), \
145
+ "d_model must be >= n_synch_out for neuron subsets"
146
+ pass
147
+
148
+
149
+
150
+
151
+ def forward(self, x, hidden_states, track=False):
152
+
153
+ # --- Tracking Initialization ---
154
+ pre_activations_tracking = []
155
+ post_activations_tracking = []
156
+
157
+ # --- Featurise Input Data ---
158
+ features = self.backbone(x)
159
+
160
+ # --- Get Recurrent State ---
161
+ state_trace, activated_state_trace = hidden_states
162
+
163
+ # --- Recurrent Loop ---
164
+ for stepi in range(self.iterations):
165
+
166
+ pre_synapse_input = torch.concatenate((features.reshape(x.size(0), -1), activated_state_trace[:,:,-1]), -1)
167
+
168
+ # --- Apply Synapses ---
169
+ state = self.synapses(pre_synapse_input)
170
+ state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
171
+
172
+ # --- Apply NLMs ---
173
+ activated_state = self.trace_processor(state_trace)
174
+ activated_state_trace = torch.concatenate((activated_state_trace[:,:,1:], activated_state.unsqueeze(-1)), -1)
175
+
176
+ # --- Tracking ---
177
+ if track:
178
+ pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
179
+ post_activations_tracking.append(activated_state.detach().cpu().numpy())
180
+
181
+ hidden_states = (
182
+ state_trace,
183
+ activated_state_trace,
184
+ )
185
+
186
+ # --- Calculate Output Synchronisation ---
187
+ synchronisation_out = self.compute_synchronisation(activated_state_trace)
188
+
189
+ # --- Return Values ---
190
+ if track:
191
+ return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking)
192
+ return synchronisation_out, hidden_states
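Unlike the attention-based variants, this RL model is stateful across calls: the caller owns the `(state_trace, activated_state_trace)` pair returned from `forward`. A hedged sketch of that contract (sizes are illustrative, and zeros stand in for the learned start traces that the training code expands over the batch):

```python
import torch

B, d_model, memory_length = 4, 256, 25                            # illustrative sizes only
state_trace = torch.zeros(B, d_model, memory_length)              # pre-activation history
activated_state_trace = torch.zeros(B, d_model, memory_length)    # post-activation history
hidden_states = (state_trace, activated_state_trace)

# for obs in rollout:                                 # pseudo-loop over environment steps
#     synchronisation_out, hidden_states = model(obs, hidden_states)
#     # the agent's policy/value heads then read from synchronisation_out
```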
models/ctm_sort.py ADDED
@@ -0,0 +1,126 @@
+ import torch
2
+ import numpy as np
3
+ from models.ctm import ContinuousThoughtMachine
4
+
5
+ class ContinuousThoughtMachineSORT(ContinuousThoughtMachine):
6
+ """
7
+ Slight adaptation of the CTM to work with the sort task.
8
+ """
9
+
10
+ def __init__(self,
11
+ iterations,
12
+ d_model,
13
+ d_input,
14
+ heads,
15
+ n_synch_out,
16
+ n_synch_action,
17
+ synapse_depth,
18
+ memory_length,
19
+ deep_nlms,
20
+ memory_hidden_dims,
21
+ do_layernorm_nlm,
22
+ backbone_type,
23
+ positional_embedding_type,
24
+ out_dims,
25
+ prediction_reshaper=[-1],
26
+ dropout=0,
27
+ dropout_nlm=None,
28
+ neuron_select_type='random-pairing',
29
+ n_random_pairing_self=0,
30
+ ):
31
+ super().__init__(
32
+ iterations=iterations,
33
+ d_model=d_model,
34
+ d_input=d_input,
35
+ heads=0,
36
+ n_synch_out=n_synch_out,
37
+ n_synch_action=0,
38
+ synapse_depth=synapse_depth,
39
+ memory_length=memory_length,
40
+ deep_nlms=deep_nlms,
41
+ memory_hidden_dims=memory_hidden_dims,
42
+ do_layernorm_nlm=do_layernorm_nlm,
43
+ backbone_type='none',
44
+ positional_embedding_type='none',
45
+ out_dims=out_dims,
46
+ prediction_reshaper=prediction_reshaper,
47
+ dropout=dropout,
48
+ dropout_nlm=dropout_nlm,
49
+ neuron_select_type=neuron_select_type,
50
+ n_random_pairing_self=n_random_pairing_self,
51
+ )
52
+
53
+ # --- Use a minimal CTM w/out input (action) synch ---
54
+ self.neuron_select_type_action = None
55
+ self.synch_representation_size_action = None
56
+
57
+ self.attention = None # Should already be None because super(... heads=0... )
58
+ self.q_proj = None # Should already be None because super(... heads=0... )
59
+ self.kv_proj = None # Should already be None because super(... heads=0... )
60
+
61
+
62
+
63
+
64
+ def forward(self, x, track=False):
65
+ B = x.size(0)
66
+ device = x.device
67
+
68
+ # --- Tracking Initialization ---
69
+ pre_activations_tracking = []
70
+ post_activations_tracking = []
71
+ synch_out_tracking = []
72
+ attention_tracking = []
73
+
74
+ # --- For SORT: no need to featurise data ---
75
+
76
+
77
+ # --- Initialise Recurrent State ---
78
+ state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
79
+ activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
80
+
81
+ # --- Prepare Storage for Outputs per Iteration ---
82
+ predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
83
+ certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
84
+
85
+ # --- Initialise Recurrent Synch Values ---
86
+ r_out = torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
87
+ _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
88
+ # Compute learned weighting for synchronisation
89
+
90
+
91
+ # --- Recurrent Loop ---
92
+ for stepi in range(self.iterations):
93
+
94
+ pre_synapse_input = torch.concatenate((x, activated_state), dim=-1)
95
+
96
+ # --- Apply Synapses ---
97
+ state = self.synapses(pre_synapse_input)
98
+ # The 'state_trace' is the history of incoming pre-activations
99
+ state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
100
+
101
+ # --- Apply Neuron-Level Models ---
102
+ activated_state = self.trace_processor(state_trace)
103
+ # One would also keep an 'activated_state_trace' as the history of outgoing post-activations
104
+ # BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
105
+ # done using only the current activated state (see compute_synchronisation method for explanation)
106
+
107
+ # --- Calculate Synchronisation for Output Predictions ---
108
+ synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
109
+
110
+ # --- Get Predictions and Certainties ---
111
+ current_prediction = self.output_projector(synchronisation_out)
112
+ current_certainty = self.compute_certainty(current_prediction)
113
+
114
+ predictions[..., stepi] = current_prediction
115
+ certainties[..., stepi] = current_certainty
116
+
117
+ # --- Tracking ---
118
+ if track:
119
+ pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
120
+ post_activations_tracking.append(activated_state.detach().cpu().numpy())
121
+ synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
122
+
123
+ # --- Return Values ---
124
+ if track:
125
+ return predictions, certainties, np.array(synch_out_tracking), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
126
+ return predictions, certainties, synchronisation_out
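Because the SORT variant uses no backbone or attention, the raw sequence is concatenated with the activated state at every tick. An illustrative call pattern (shapes are assumptions, not taken from the training scripts):

```python
import torch

B, length = 8, 15                      # assumed batch size and number of values to sort
x = torch.randn(B, length)             # raw values, fed straight into the synapse model
# predictions, certainties, synchronisation_out = model(x)
# predictions: (B, out_dims, iterations), one prediction per internal tick
# certainties: (B, 2, iterations), the (entropy, 1 - entropy) pair per internal tick
```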
models/ff.py ADDED
@@ -0,0 +1,75 @@
+ import torch.nn as nn
2
+
3
+ # Local imports (Assuming these contain necessary custom modules)
4
+ from models.modules import *
5
+ from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
6
+
7
+
8
+ class FFBaseline(nn.Module):
9
+ """
10
+ Feedforward baseline.
+
+ Wrapper that lets us use the same backbone as the CTM and LSTM baselines, with a
+ final projection to d_model so that parameter counts can be matched against those models.
13
+
14
+
15
+ Args:
16
+ d_model (int): workaround that projects final layer to this space so that parameter-matching is plausible.
17
+ backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
18
+ out_dims (int): Dimensionality of the final output projection.
19
+ dropout (float): dropout in last layer
20
+ """
21
+
22
+ def __init__(self,
23
+ d_model,
24
+ backbone_type,
25
+ out_dims,
26
+ dropout=0,
27
+ ):
28
+ super(FFBaseline, self).__init__()
29
+
30
+ # --- Core Parameters ---
31
+ self.d_model = d_model
32
+ self.backbone_type = backbone_type
33
+ self.out_dims = out_dims
34
+
35
+ # --- Input Assertions ---
36
+ assert backbone_type in ['resnet18-1', 'resnet18-2', 'resnet18-3', 'resnet18-4',
37
+ 'resnet34-1', 'resnet34-2', 'resnet34-3', 'resnet34-4',
38
+ 'resnet50-1', 'resnet50-2', 'resnet50-3', 'resnet50-4',
39
+ 'resnet101-1', 'resnet101-2', 'resnet101-3', 'resnet101-4',
40
+ 'resnet152-1', 'resnet152-2', 'resnet152-3', 'resnet152-4',
41
+ 'none', 'shallow-wide', 'parity_backbone'], f"Invalid backbone_type: {backbone_type}"
42
+
43
+ # --- Backbone / Feature Extraction ---
44
+ self.initial_rgb = Identity() # Placeholder, potentially replaced if using ResNet
45
+
46
+
47
+ self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
48
+ resnet_family = resnet18 # Default
49
+ if '34' in self.backbone_type: resnet_family = resnet34
50
+ if '50' in self.backbone_type: resnet_family = resnet50
51
+ if '101' in self.backbone_type: resnet_family = resnet101
52
+ if '152' in self.backbone_type: resnet_family = resnet152
53
+
54
+ # Determine which ResNet blocks to keep
55
+ block_num_str = self.backbone_type.split('-')[-1]
56
+ hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]
57
+
58
+ self.backbone = resnet_family(
59
+ 3, # initial_rgb handles input channels now
60
+ hyper_blocks_to_keep,
61
+ stride=2,
62
+ pretrained=False,
63
+ progress=True,
64
+ device="cpu", # Initialise on CPU, move later via .to(device)
65
+ do_initial_max_pool=True,
66
+ )
67
+
68
+
69
+ # At this point we will have a 4D tensor of features: [B, C, H, W]
70
+ # The following lets us scale up the resnet with d_model until it matches the CTM
71
+ self.output_projector = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), Squeeze(-1), Squeeze(-1), nn.LazyLinear(d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, out_dims))
72
+
73
+
74
+ def forward(self, x):
75
+ return self.output_projector((self.backbone(self.initial_rgb(x))))
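For reference, a hedged usage sketch of this feedforward baseline; the hyperparameter values are illustrative rather than taken from the repository's training scripts:

```python
import torch
# from models.ff import FFBaseline

# model = FFBaseline(d_model=512, backbone_type='resnet18-2', out_dims=10, dropout=0.1)
x = torch.randn(2, 3, 32, 32)          # e.g. a CIFAR-10-sized batch
# logits = model(x)                    # (2, 10): a single feedforward pass, no internal ticks
```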
models/lstm.py ADDED
@@ -0,0 +1,244 @@
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+ import math
5
+
6
+ from models.modules import ParityBackbone, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
7
+ from models.resnet import prepare_resnet_backbone
8
+ from models.utils import compute_normalized_entropy
9
+
10
+ from models.constants import (
11
+ VALID_BACKBONE_TYPES,
12
+ VALID_POSITIONAL_EMBEDDING_TYPES
13
+ )
14
+
15
+ class LSTMBaseline(nn.Module):
16
+ """
17
+ LSTM Baseline
18
+
19
+ Args:
20
+ iterations (int): Number of internal 'thought' steps (T, in paper).
21
+ d_model (int): Core dimensionality of the latent space.
22
+ d_input (int): Dimensionality of projected attention outputs or direct input features.
23
+ heads (int): Number of attention heads.
24
+ backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
25
+ positional_embedding_type (str): Type of positional embedding for backbone features.
26
+ out_dims (int): Dimensionality of the final output projection.
27
+ prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
28
+ dropout (float): Dropout rate.
29
+ """
30
+
31
+ def __init__(self,
32
+ iterations,
33
+ d_model,
34
+ d_input,
35
+ heads,
36
+ backbone_type,
37
+ num_layers,
38
+ positional_embedding_type,
39
+ out_dims,
40
+ prediction_reshaper=[-1],
41
+ dropout=0,
42
+ ):
43
+ super(LSTMBaseline, self).__init__()
44
+
45
+ # --- Core Parameters ---
46
+ self.iterations = iterations
47
+ self.d_model = d_model
48
+ self.d_input = d_input
49
+ self.prediction_reshaper = prediction_reshaper
50
+ self.backbone_type = backbone_type
51
+ self.positional_embedding_type = positional_embedding_type
52
+ self.out_dims = out_dims
53
+
54
+ # --- Assertions ---
55
+ self.verify_args()
56
+
57
+ # --- Input Processing ---
58
+ d_backbone = self.get_d_backbone()
59
+
60
+ self.set_initial_rgb()
61
+ self.set_backbone()
62
+ self.positional_embedding = self.get_positional_embedding(d_backbone)
63
+ self.kv_proj = self.get_kv_proj()
64
+ self.lstm = nn.LSTM(d_input, d_model, num_layers, batch_first=True, dropout=dropout)
65
+ self.q_proj = self.get_q_proj()
66
+ self.attention = self.get_attention(heads, dropout)
67
+ self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
68
+
69
+ # --- Start States ---
70
+ self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
71
+ self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
72
+
73
+
74
+
75
+ # --- Core LSTM Methods ---
76
+
77
+ def compute_features(self, x):
78
+ """Applies backbone and positional embedding to input."""
79
+ x = self.initial_rgb(x)
80
+ self.kv_features = self.backbone(x)
81
+ pos_emb = self.positional_embedding(self.kv_features)
82
+ combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
83
+ kv = self.kv_proj(combined_features)
84
+ return kv
85
+
86
+ def compute_certainty(self, current_prediction):
87
+ """Compute the certainty of the current prediction."""
88
+ B = current_prediction.size(0)
89
+ reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper)
90
+ ne = compute_normalized_entropy(reshaped_pred)
91
+ current_certainty = torch.stack((ne, 1-ne), -1)
92
+ return current_certainty
93
+
94
+ # --- Setup Methods ---
95
+
96
+ def set_initial_rgb(self):
97
+ """Set the initial RGB processing module based on the backbone type."""
98
+ if 'resnet' in self.backbone_type:
99
+ self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
100
+ else:
101
+ self.initial_rgb = nn.Identity()
102
+
103
+ def get_d_backbone(self):
104
+ """
105
+ Get the dimensionality of the backbone output, to be used for positional embedding setup.
106
+
107
+ This is a little bit complicated for resnets, but the logic should be easy enough to read below.
108
+ """
109
+ if self.backbone_type == 'shallow-wide':
110
+ return 2048
111
+ elif self.backbone_type == 'parity_backbone':
112
+ return self.d_input
113
+ elif 'resnet' in self.backbone_type:
114
+ if '18' in self.backbone_type or '34' in self.backbone_type:
115
+ if self.backbone_type.split('-')[1]=='1': return 64
116
+ elif self.backbone_type.split('-')[1]=='2': return 128
117
+ elif self.backbone_type.split('-')[1]=='3': return 256
118
+ elif self.backbone_type.split('-')[1]=='4': return 512
119
+ else:
120
+ raise NotImplementedError
121
+ else:
122
+ if self.backbone_type.split('-')[1]=='1': return 256
123
+ elif self.backbone_type.split('-')[1]=='2': return 512
124
+ elif self.backbone_type.split('-')[1]=='3': return 1024
125
+ elif self.backbone_type.split('-')[1]=='4': return 2048
126
+ else:
127
+ raise NotImplementedError
128
+ elif self.backbone_type == 'none':
129
+ return None
130
+ else:
131
+ raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
132
+
133
+ def set_backbone(self):
134
+ """Set the backbone module based on the specified type."""
135
+ if self.backbone_type == 'shallow-wide':
136
+ self.backbone = ShallowWide()
137
+ elif self.backbone_type == 'parity_backbone':
138
+ d_backbone = self.get_d_backbone()
139
+ self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
140
+ elif 'resnet' in self.backbone_type:
141
+ self.backbone = prepare_resnet_backbone(self.backbone_type)
142
+ elif self.backbone_type == 'none':
143
+ self.backbone = nn.Identity()
144
+ else:
145
+ raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
146
+
147
+ def get_positional_embedding(self, d_backbone):
148
+ """Get the positional embedding module."""
149
+ if self.positional_embedding_type == 'learnable-fourier':
150
+ return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
151
+ elif self.positional_embedding_type == 'multi-learnable-fourier':
152
+ return MultiLearnableFourierPositionalEncoding(d_backbone)
153
+ elif self.positional_embedding_type == 'custom-rotational':
154
+ return CustomRotationalEmbedding(d_backbone)
155
+ elif self.positional_embedding_type == 'custom-rotational-1d':
156
+ return CustomRotationalEmbedding1D(d_backbone)
157
+ elif self.positional_embedding_type == 'none':
158
+ return lambda x: 0 # Default no-op
159
+ else:
160
+ raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
161
+
162
+ def get_attention(self, heads, dropout):
163
+ """Get the attention module."""
164
+ return nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True)
165
+
166
+ def get_kv_proj(self):
167
+ """Get the key-value projection module."""
168
+ return nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input))
169
+
170
+ def get_q_proj(self):
171
+ """Get the query projection module."""
172
+ return nn.LazyLinear(self.d_input)
173
+
174
+
175
+ def verify_args(self):
176
+ """Verify the validity of the input arguments."""
177
+
178
+ assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
179
+ f"Invalid backbone_type: {self.backbone_type}"
180
+
181
+ assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
182
+ f"Invalid positional_embedding_type: {self.positional_embedding_type}"
183
+
184
+ if self.backbone_type=='none' and self.positional_embedding_type!='none':
185
+ raise AssertionError("There should be no positional embedding if there is no backbone.")
186
+
187
+ pass
188
+
189
+
190
+
191
+
192
+ def forward(self, x, track=False):
193
+ """
194
+ Forward pass.
195
+ Executes T=iterations steps.
196
+ """
197
+ B = x.size(0)
198
+ device = x.device
199
+
200
+ # --- Tracking Initialization ---
201
+ activations_tracking = []
202
+ attention_tracking = []
203
+
204
+ # --- Featurise Input Data ---
205
+ kv = self.compute_features(x)
206
+
207
+ # --- Initialise Recurrent State ---
208
+ hn = torch.repeat_interleave(self.start_hidden_state.unsqueeze(1), x.size(0), 1)
209
+ cn = torch.repeat_interleave(self.start_cell_state.unsqueeze(1), x.size(0), 1)
210
+ state_trace = [hn[-1]]
211
+
212
+ # --- Prepare Storage for Outputs per Iteration ---
213
+ predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
214
+ certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
215
+
216
+ # --- Recurrent Loop ---
217
+ for stepi in range(self.iterations):
218
+
219
+ # --- Interact with Data via Attention ---
220
+ q = self.q_proj(hn[-1].unsqueeze(1))
221
+ attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
222
+ lstm_input = attn_out
223
+
224
+ # --- Apply LSTM ---
225
+ hidden_state, (hn,cn) = self.lstm(lstm_input, (hn, cn))
226
+ hidden_state = hidden_state.squeeze(1)
227
+ state_trace.append(hidden_state)
228
+
229
+ # --- Get Predictions and Certainties ---
230
+ current_prediction = self.output_projector(hidden_state)
231
+ current_certainty = self.compute_certainty(current_prediction)
232
+
233
+ predictions[..., stepi] = current_prediction
234
+ certainties[..., stepi] = current_certainty
235
+
236
+ # --- Tracking ---
237
+ if track:
238
+ activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
239
+ attention_tracking.append(attn_weights.detach().cpu().numpy())
240
+
241
+ # --- Return Values ---
242
+ if track:
243
+ return predictions, certainties, None, np.zeros_like(activations_tracking), np.array(activations_tracking), np.array(attention_tracking)
244
+ return predictions, certainties, None
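The certainty returned next to each prediction is a normalised-entropy pair. The actual helper lives in `models/utils.py` and is not shown here, so the following standalone sketch is an assumption about its behaviour rather than a copy of it:

```python
import math
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                         # (B, num_classes) prediction at one tick
p = F.softmax(logits, dim=-1)
entropy = -(p * torch.log(p + 1e-12)).sum(-1)       # (B,)
ne = entropy / math.log(logits.size(-1))            # normalise to [0, 1]
certainty = torch.stack((ne, 1 - ne), dim=-1)       # matches the (ne, 1-ne) stacking above
```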
models/lstm_qamnist.py ADDED
@@ -0,0 +1,184 @@
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F # Used for GLU if not in modules
4
+ import numpy as np
5
+ import math
6
+
7
+ # Local imports (Assuming these contain necessary custom modules)
8
+ from models.modules import *
9
+ from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
10
+
11
+ class LSTMBaseline(nn.Module):
12
+ """
13
+ LSTM Baseline
14
+
15
+ Args:
16
+ iterations (int): Number of internal 'thought' steps (T, in paper).
17
+ d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
18
+ d_input (int): Dimensionality of projected attention outputs or direct input features.
19
+ heads (int): Number of attention heads.
20
+ iterations_per_digit (int): Number of internal ticks spent on each digit observation.
+ iterations_per_question_part (int): Number of internal ticks spent on each question part (index or operator).
+ iterations_for_answering (int): Number of internal ticks spent producing the answer.
29
+ out_dims (int): Dimensionality of the final output projection.
30
+ prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
31
+ dropout (float): Dropout rate.
32
+ """
33
+
34
+ def __init__(self,
35
+ iterations,
36
+ d_model,
37
+ d_input,
38
+ heads,
39
+ out_dims,
40
+ iterations_per_digit,
41
+ iterations_per_question_part,
42
+ iterations_for_answering,
43
+ prediction_reshaper=[-1],
44
+ dropout=0,
45
+ ):
46
+ super(LSTMBaseline, self).__init__()
47
+
48
+ # --- Core Parameters ---
49
+ self.iterations = iterations
50
+ self.d_model = d_model
51
+ self.prediction_reshaper = prediction_reshaper
52
+ self.out_dims = out_dims
53
+ self.d_input = d_input
54
+ self.backbone_type = 'qamnist_backbone'
55
+ self.iterations_per_digit = iterations_per_digit
56
+ self.iterations_per_question_part = iterations_per_question_part
57
+ self.total_iterations_for_answering = iterations_for_answering
58
+
59
+ # --- Backbone / Feature Extraction ---
60
+ self.backbone_digit = MNISTBackbone(d_input)
61
+ self.index_backbone = QAMNISTIndexEmbeddings(50, d_input)
62
+ self.operator_backbone = QAMNISTOperatorEmbeddings(2, d_input)
63
+
64
+ # --- Core LSTM Modules ---
65
+ self.lstm_cell = nn.LSTMCell(d_input, d_model)
66
+ self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
67
+ self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
68
+
69
+ # Attention
70
+ self.q_proj = nn.LazyLinear(d_input)
71
+ self.kv_proj = nn.Sequential(nn.LazyLinear(d_input), nn.LayerNorm(d_input))
72
+ self.attention = nn.MultiheadAttention(d_input, heads, dropout, batch_first=True)
73
+
74
+ # Output Projection
75
+ self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
76
+
77
+ def compute_certainty(self, current_prediction):
78
+ """Compute the certainty of the current prediction."""
79
+ B = current_prediction.size(0)
80
+ reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper)
81
+ ne = compute_normalized_entropy(reshaped_pred)
82
+ current_certainty = torch.stack((ne, 1-ne), -1)
83
+ return current_certainty
84
+
85
+ def get_kv_for_step(self, stepi, x, z, thought_steps, prev_input=None, prev_kv=None):
86
+ is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
87
+
88
+ if is_digit_step:
89
+ current_input = x[:, stepi]
90
+ if prev_input is not None and torch.equal(current_input, prev_input):
91
+ return prev_kv, prev_input
92
+ kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
93
+
94
+ elif is_question_step:
95
+ offset = stepi - thought_steps.total_iterations_for_digits
96
+ current_input = z[:, offset].squeeze(0)
97
+ if prev_input is not None and torch.equal(current_input, prev_input):
98
+ return prev_kv, prev_input
99
+ is_index_step, is_operator_step = thought_steps.determine_answer_step_type(stepi)
100
+ if is_index_step:
101
+ kv = self.kv_proj(self.index_backbone(current_input))
102
+ elif is_operator_step:
103
+ kv = self.kv_proj(self.operator_backbone(current_input))
104
+ else:
105
+ raise ValueError("Invalid step type for question processing.")
106
+
107
+ elif is_answer_step:
108
+ current_input = None
109
+ kv = torch.zeros((x.size(0), self.d_input), device=x.device)
110
+
111
+ else:
112
+ raise ValueError("Invalid step type.")
113
+
114
+ return kv, current_input
115
+
116
+ def forward(self, x, z, track=False):
117
+ """
118
+ Forward pass.
119
+ Executes T=iterations steps.
120
+ """
121
+ B = x.size(0) # Batch size
122
+
123
+ # --- Tracking Initialization ---
124
+ activations_tracking = []
125
+ attention_tracking = [] # Note: reshaping this correctly requires knowing num_heads
126
+ embedding_tracking = []
127
+
128
+ thought_steps = ThoughtSteps(self.iterations_per_digit, self.iterations_per_question_part, self.total_iterations_for_answering, x.size(1), z.size(1))
129
+
130
+ # --- Initialise Recurrent State ---
131
+ hidden_state = torch.repeat_interleave(self.start_hidden_state.unsqueeze(0), x.size(0), 0)
132
+ cell_state = torch.repeat_interleave(self.start_cell_state.unsqueeze(0), x.size(0), 0)
133
+
134
+ state_trace = [hidden_state]
135
+
136
+ device = hidden_state.device
137
+
138
+ # Storage for outputs per iteration
139
+ predictions = torch.empty(B, self.out_dims, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
140
+ certainties = torch.empty(B, 2, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
141
+
142
+ prev_input = None
143
+ prev_kv = None
144
+
145
+ # --- Recurrent Loop (T=iterations steps) ---
146
+ for stepi in range(thought_steps.total_iterations):
147
+
148
+ is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
149
+ kv, prev_input = self.get_kv_for_step(stepi, x, z, thought_steps, prev_input, prev_kv)
150
+ prev_kv = kv
151
+
152
+ # --- Interact with Data via Attention ---
153
+ attn_weights = None
154
+ if is_digit_step:
155
+ q = self.q_proj(hidden_state).unsqueeze(1)
156
+ attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
157
+ lstm_input = attn_out.squeeze(1)
158
+ else:
159
+ lstm_input = kv
160
+
161
+
162
+
163
+ hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
164
+ state_trace.append(hidden_state)
165
+
166
+ # --- Get Predictions and Certainties ---
167
+ current_prediction = self.output_projector(hidden_state)
168
+ current_certainty = self.compute_certainty(current_prediction)
169
+
170
+ predictions[..., stepi] = current_prediction
171
+ certainties[..., stepi] = current_certainty
172
+
173
+ # --- Tracking ---
174
+ if track:
175
+ activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
176
+ if attn_weights is not None:
177
+ attention_tracking.append(attn_weights.detach().cpu().numpy())
178
+ if is_question_step:
179
+ embedding_tracking.append(kv.detach().cpu().numpy())
180
+
181
+ # --- Return Values ---
182
+ if track:
183
+ return predictions, certainties, None, np.array(activations_tracking), np.array(activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
184
+ return predictions, certainties, None
models/lstm_rl.py ADDED
@@ -0,0 +1,96 @@
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F # Used for GLU if not in modules
4
+ import numpy as np
5
+ import math
6
+
7
+ # Local imports (Assuming these contain necessary custom modules)
8
+ from models.modules import *
9
+ from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
10
+
11
+
12
+ class LSTMBaseline(nn.Module):
13
+ """
14
+
15
+ LSTM Baseline
16
+
17
+ Args:
18
+ iterations (int): Number of internal 'thought' steps (T, in paper).
19
+ d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
20
+ d_input (int): Dimensionality of projected attention outputs or direct input features.
21
+ backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
22
+ """
23
+
24
+ def __init__(self,
25
+ iterations,
26
+ d_model,
27
+ d_input,
28
+ backbone_type,
29
+ ):
30
+ super(LSTMBaseline, self).__init__()
31
+
32
+ # --- Core Parameters ---
33
+ self.iterations = iterations
34
+ self.d_model = d_model
35
+ self.backbone_type = backbone_type
36
+
37
+ # --- Input Assertions ---
38
+ assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), f"Invalid backbone_type: {backbone_type}"
39
+
40
+ # --- Backbone / Feature Extraction ---
41
+ if self.backbone_type == 'navigation-backbone':
42
+ grid_size = 7
43
+ self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size)
44
+ lstm_cell_input_dim = grid_size * grid_size * d_input
45
+
46
+ elif self.backbone_type == 'classic-control-backbone':
47
+ self.backbone = ClassicControlBackbone(d_input=d_input)
48
+ lstm_cell_input_dim = d_input
49
+
50
+ else:
51
+ raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
52
+
53
+ # --- Core LSTM Modules ---
54
+ self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model)
55
+ self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
56
+ self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
57
+
58
+ def compute_features(self, x):
59
+ """Applies backbone and positional embedding to input."""
60
+ return self.backbone(x)
61
+
62
+
63
+ def forward(self, x, hidden_states, track=False):
64
+ """
65
+ Forward pass.
66
+ Executes T=iterations steps.
67
+ """
68
+
69
+ # --- Tracking Initialization ---
70
+ activations_tracking = []
71
+
72
+ # --- Featurise Input Data ---
73
+ features = self.compute_features(x)
74
+
75
+ hidden_state = hidden_states[0]
76
+ cell_state = hidden_states[1]
77
+
78
+ # --- Recurrent Loop ---
79
+ for stepi in range(self.iterations):
80
+
81
+ lstm_input = features.reshape(x.size(0), -1)
82
+ hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
83
+
84
+ # --- Tracking ---
85
+ if track:
86
+ activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
87
+
88
+ hidden_states = (
89
+ hidden_state,
90
+ cell_state
91
+ )
92
+
93
+ # --- Return Values ---
94
+ if track:
95
+ return hidden_state, hidden_states, np.array(activations_tracking), np.array(activations_tracking)
96
+ return hidden_state, hidden_states
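As with the CTM RL variant, the recurrent state is owned by the caller, here as a plain `(h, c)` pair for the LSTM cell. A hedged sketch (sizes illustrative; the learned start states would normally be expanded over the batch at episode reset):

```python
import torch

B, d_model = 4, 256                    # illustrative sizes only
h0 = torch.zeros(B, d_model)           # stand-in for start_hidden_state expanded over the batch
c0 = torch.zeros(B, d_model)           # stand-in for start_cell_state expanded over the batch
hidden_states = (h0, c0)
# output, hidden_states = lstm_baseline(obs, hidden_states)
```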
models/modules.py ADDED
@@ -0,0 +1,692 @@
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F # Used for GLU
4
+ import math
5
+ import numpy as np
6
+
7
+ # Assuming 'add_coord_dim' is defined in models.utils
8
+ from models.utils import add_coord_dim
9
+
10
+ # --- Basic Utility Modules ---
11
+
12
+ class Identity(nn.Module):
13
+ """
14
+ Identity Module.
15
+
16
+ Returns the input tensor unchanged. Useful as a placeholder or a no-op layer
17
+ in nn.Sequential containers or conditional network parts.
18
+ """
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def forward(self, x):
23
+ return x
24
+
25
+
26
+ class Squeeze(nn.Module):
27
+ """
28
+ Squeeze Module.
29
+
30
+ Removes a specified dimension of size 1 from the input tensor.
31
+ Useful for incorporating tensor dimension squeezing within nn.Sequential.
32
+
33
+ Args:
34
+ dim (int): The dimension to squeeze.
35
+ """
36
+ def __init__(self, dim):
37
+ super().__init__()
38
+ self.dim = dim
39
+
40
+ def forward(self, x):
41
+ return x.squeeze(self.dim)
42
+
43
+ # --- Core CTM Component Modules ---
44
+
45
+ class SynapseUNET(nn.Module):
46
+ """
47
+ UNET-style architecture for the Synapse Model (f_theta1 in the paper).
48
+
49
+ This module implements the connections between neurons in the CTM's latent
50
+ space. It processes the combined input (previous post-activation state z^t
51
+ and attention output o^t) to produce the pre-activations (a^t) for the
52
+ next internal tick (Eq. 1 in the paper).
53
+
54
+ While a simpler Linear or MLP layer can be used, the paper notes
55
+ that this U-Net structure empirically performed better, suggesting benefit
56
+ from more flexible synaptic connections. This implementation
57
+ uses `depth` points in linspace and creates `depth-1` down/up blocks.
58
+
59
+ Args:
60
+ in_dims (int): Number of input dimensions (d_model + d_input).
61
+ out_dims (int): Number of output dimensions (d_model).
62
+ depth (int): Determines structure size; creates `depth-1` down/up blocks.
63
+ minimum_width (int): Smallest channel width at the U-Net bottleneck.
64
+ dropout (float): Dropout rate applied within down/up projections.
65
+ """
66
+ def __init__(self,
67
+ out_dims,
68
+ depth,
69
+ minimum_width=16,
70
+ dropout=0.0):
71
+ super().__init__()
72
+ self.width_out = out_dims
73
+ self.n_deep = depth # Store depth just for reference if needed
74
+
75
+ # Define UNET structure based on depth
76
+ # Creates `depth` width values, leading to `depth-1` blocks
77
+ widths = np.linspace(out_dims, minimum_width, depth)
78
+
79
+ # Initial projection layer
80
+ self.first_projection = nn.Sequential(
81
+ nn.LazyLinear(int(widths[0])), # Project to the first width
82
+ nn.LayerNorm(int(widths[0])),
83
+ nn.SiLU()
84
+ )
85
+
86
+ # Downward path (encoding layers)
87
+ self.down_projections = nn.ModuleList()
88
+ self.up_projections = nn.ModuleList()
89
+ self.skip_lns = nn.ModuleList()
90
+ num_blocks = len(widths) - 1 # Number of down/up blocks created
91
+
92
+ for i in range(num_blocks):
93
+ # Down block: widths[i] -> widths[i+1]
94
+ self.down_projections.append(nn.Sequential(
95
+ nn.Dropout(dropout),
96
+ nn.Linear(int(widths[i]), int(widths[i+1])),
97
+ nn.LayerNorm(int(widths[i+1])),
98
+ nn.SiLU()
99
+ ))
100
+ # Up block: widths[i+1] -> widths[i]
101
+ # Note: Up blocks are added in order matching down blocks conceptually,
102
+ # but applied in reverse order in the forward pass.
103
+ self.up_projections.append(nn.Sequential(
104
+ nn.Dropout(dropout),
105
+ nn.Linear(int(widths[i+1]), int(widths[i])),
106
+ nn.LayerNorm(int(widths[i])),
107
+ nn.SiLU()
108
+ ))
109
+ # Skip connection LayerNorm operates on width[i]
110
+ self.skip_lns.append(nn.LayerNorm(int(widths[i])))
111
+
112
+ def forward(self, x):
113
+ # Initial projection
114
+ out_first = self.first_projection(x)
115
+
116
+ # Downward path, storing outputs for skip connections
117
+ outs_down = [out_first]
118
+ for layer in self.down_projections:
119
+ outs_down.append(layer(outs_down[-1]))
120
+ # outs_down contains [level_0, level_1, ..., level_depth-1=bottleneck] outputs
121
+
122
+ # Upward path, starting from the bottleneck output
123
+ outs_up = outs_down[-1] # Bottleneck activation
124
+ num_blocks = len(self.up_projections) # Should be depth - 1
125
+
126
+ for i in range(num_blocks):
127
+ # Apply up projection in reverse order relative to down blocks
128
+ # up_projection[num_blocks - 1 - i] processes deeper features first
129
+ up_layer_idx = num_blocks - 1 - i
130
+ out_up = self.up_projections[up_layer_idx](outs_up)
131
+
132
+ # Get corresponding skip connection from downward path
133
+ # skip_connection index = num_blocks - 1 - i (same as up_layer_idx)
134
+ # This matches the output width of the up_projection[up_layer_idx]
135
+ skip_idx = up_layer_idx
136
+ skip_connection = outs_down[skip_idx]
137
+
138
+ # Add skip connection and apply LayerNorm corresponding to this level
139
+ # skip_lns index also corresponds to the level = skip_idx
140
+ outs_up = self.skip_lns[skip_idx](out_up + skip_connection)
141
+
142
+ # The final output after all up-projections
143
+ return outs_up
144
+
145
+
146
+ class SuperLinear(nn.Module):
147
+ """
148
+ SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.
149
+
150
+ This layer is the core component enabling Neuron-Level Models (NLMs),
151
+ referred to as g_theta_d in the paper (Eq. 3). It applies N independent
152
+ linear transformations (or small MLPs when used sequentially) to corresponding
153
+ slices of the input tensor along a specified dimension (typically the neuron
154
+ or feature dimension).
155
+
156
+ How it works for NLMs:
157
+ - The input `x` is expected to be the pre-activation history for each neuron,
158
+ shaped (batch_size, n_neurons=N, history_length=in_dims).
159
+ - This layer holds unique weights (`w1`) and biases (`b1`) for *each* of the `N` neurons.
160
+ `w1` has shape (in_dims, out_dims, N), `b1` has shape (1, N, out_dims).
161
+ - `torch.einsum('BDM,MHD->BDH', x, self.w1)` performs N independent matrix
+ multiplications in parallel (mapping from dim `M` to `H` for each neuron `D`):
163
+ - For each neuron `n` (from 0 to N-1):
164
+ - It takes the neuron's history `x[:, n, :]` (shape B, in_dims).
165
+ - Multiplies it by the neuron's unique weight matrix `self.w1[:, :, n]` (shape in_dims, out_dims).
166
+ - Resulting in `out[:, n, :]` (shape B, out_dims).
167
+ - The unique bias `self.b1[:, n, :]` is added.
168
+ - The result is squeezed on the last dim (if out_dims=1) and scaled by `T`.
169
+
170
+ This allows each neuron `d` to process its temporal history `A_d^t` using
171
+ its private parameters `theta_d` to produce the post-activation `z_d^{t+1}`,
172
+ enabling the fine-grained temporal dynamics central to the CTM.
173
+ It's typically used within the `trace_processor` module of the main CTM class.
174
+
175
+ Args:
176
+ in_dims (int): Input dimension (typically `memory_length`).
177
+ out_dims (int): Output dimension per neuron.
178
+ N (int): Number of independent linear models (typically `d_model`).
179
+ T (float): Initial value for learnable temperature/scaling factor applied to output.
180
+ do_norm (bool): Apply Layer Normalization to the input history before linear transform.
181
+ dropout (float): Dropout rate applied to the input.
182
+ """
183
+ def __init__(self,
184
+ in_dims,
185
+ out_dims,
186
+ N,
187
+ T=1.0,
188
+ do_norm=False,
189
+ dropout=0):
190
+ super().__init__()
191
+ # N is the number of neurons (d_model), in_dims is the history length (memory_length)
192
+ self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
193
+ self.in_dims = in_dims # Corresponds to memory_length
194
+ # LayerNorm applied across the history dimension for each neuron independently
195
+ self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity()
196
+ self.do_norm = do_norm
197
+
198
+ # Initialize weights and biases
199
+ # w1 shape: (memory_length, out_dims, d_model)
200
+ self.register_parameter('w1', nn.Parameter(
201
+ torch.empty((in_dims, out_dims, N)).uniform_(
202
+ -1/math.sqrt(in_dims + out_dims),
203
+ 1/math.sqrt(in_dims + out_dims)
204
+ ), requires_grad=True)
205
+ )
206
+ # b1 shape: (1, d_model, out_dims)
207
+ self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
208
+ # Learnable temperature/scaler T
209
+ self.register_parameter('T', nn.Parameter(torch.Tensor([T])))
210
+
211
+ def forward(self, x):
212
+ """
213
+ Args:
214
+ x (torch.Tensor): Input tensor, expected shape (B, N, in_dims)
215
+ where B=batch, N=d_model, in_dims=memory_length.
216
+ Returns:
217
+ torch.Tensor: Output tensor, shape (B, N) after squeeze(-1).
218
+ """
219
+ # Input shape: (B, D, M) where D=d_model=N neurons in CTM, M=history/memory length
220
+ out = self.dropout(x)
221
+ # LayerNorm across the memory_length dimension (dim=-1)
222
+ out = self.layernorm(out) # Shape remains (B, N, M)
223
+
224
+ # Apply N independent linear models using einsum
225
+ # einsum('BDM,MHD->BDH', ...)
226
+ # x: (B=batch size, D=N neurons, one NLM per each of these, M=history/memory length)
227
+ # w1: (M, H=hidden dims if using MLP, otherwise output, D=N neurons, parallel)
228
+ # b1: (1, D=N neurons, H)
229
+ # einsum result: (B, D, H)
230
+ # Applying bias requires matching shapes, b1 is broadcasted.
231
+ out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1
232
+
233
+ # Squeeze the output dimension (assumed to be 1 usually) and scale by T
234
+ # This matches the original code's structure exactly.
235
+ out = out.squeeze(-1) / self.T
236
+ return out
237
+
238
+
239
+ # --- Backbone Modules ---
240
+
241
+ class ParityBackbone(nn.Module):
242
+ def __init__(self, n_embeddings, d_embedding):
243
+ super(ParityBackbone, self).__init__()
244
+ self.embedding = nn.Embedding(n_embeddings, d_embedding)
245
+
246
+ def forward(self, x):
247
+ """
248
+ Maps -1 (negative parity) to 0 and 1 (positive) to 1
249
+ """
250
+ x = (x == 1).long()
251
+ return self.embedding(x.long()).transpose(1, 2) # Transpose for compatibility with other backbones
252
+
253
+ class QAMNISTOperatorEmbeddings(nn.Module):
254
+ def __init__(self, num_operator_types, d_projection):
255
+ super(QAMNISTOperatorEmbeddings, self).__init__()
256
+ self.embedding = nn.Embedding(num_operator_types, d_projection)
257
+
258
+ def forward(self, x):
259
+ # -1 for plus and -2 for minus
260
+ return self.embedding(-x - 1)
261
+
262
+ class QAMNISTIndexEmbeddings(torch.nn.Module):
263
+ def __init__(self, max_seq_length, embedding_dim):
264
+ super().__init__()
265
+ self.max_seq_length = max_seq_length
266
+ self.embedding_dim = embedding_dim
267
+
268
+ embedding = torch.zeros(max_seq_length, embedding_dim)
269
+ position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
270
+ div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
271
+
272
+ embedding[:, 0::2] = torch.sin(position * div_term)
273
+ embedding[:, 1::2] = torch.cos(position * div_term)
274
+
275
+ self.register_buffer('embedding', embedding)
276
+
277
+ def forward(self, x):
278
+ return self.embedding[x]
279
+
280
+ class ThoughtSteps:
281
+ """
282
+ Helper class for managing "thought steps" in the ctm_qamnist pipeline.
283
+
284
+ Args:
285
+ iterations_per_digit (int): Number of iterations for each digit.
286
+ iterations_per_question_part (int): Number of iterations for each question part.
287
+ total_iterations_for_answering (int): Total number of iterations for answering.
288
+ total_iterations_for_digits (int): Total number of iterations for digits.
289
+ total_iterations_for_question (int): Total number of iterations for question.
290
+ """
291
+ def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
292
+ self.iterations_per_digit = iterations_per_digit
293
+ self.iterations_per_question_part = iterations_per_question_part
294
+ self.total_iterations_for_digits = total_iterations_for_digits
295
+ self.total_iterations_for_question = total_iterations_for_question
296
+ self.total_iterations_for_answering = total_iterations_for_answering
297
+ self.total_iterations = self.total_iterations_for_digits + self.total_iterations_for_question + self.total_iterations_for_answering
298
+
299
+ def determine_step_type(self, stepi: int):
300
+ is_digit_step = stepi < self.total_iterations_for_digits
301
+ is_question_step = self.total_iterations_for_digits <= stepi < self.total_iterations_for_digits + self.total_iterations_for_question
302
+ is_answer_step = stepi >= self.total_iterations_for_digits + self.total_iterations_for_question
303
+ return is_digit_step, is_question_step, is_answer_step
304
+
305
+ def determine_answer_step_type(self, stepi: int):
306
+ step_within_questions = stepi - self.total_iterations_for_digits
307
+ if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
308
+ is_index_step = True
309
+ is_operator_step = False
310
+ else:
311
+ is_index_step = False
312
+ is_operator_step = True
313
+ return is_index_step, is_operator_step
314
+
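+ # Worked example for ThoughtSteps (hypothetical values, for illustration only): with
+ # total_iterations_for_digits=6, iterations_per_question_part=2 and total_iterations_for_question=4,
+ # steps 0-5 are digit steps, steps 6-7 are question index steps, steps 8-9 are question
+ # operator steps, and any later step is an answering step.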
315
+ class MNISTBackbone(nn.Module):
316
+ """
317
+ Simple backbone for MNIST feature extraction.
318
+ """
319
+ def __init__(self, d_input):
320
+ super(MNISTBackbone, self).__init__()
321
+ self.layers = nn.Sequential(
322
+ nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
323
+ nn.BatchNorm2d(d_input),
324
+ nn.ReLU(),
325
+ nn.MaxPool2d(2, 2),
326
+ nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
327
+ nn.BatchNorm2d(d_input),
328
+ nn.ReLU(),
329
+ nn.MaxPool2d(2, 2),
330
+ )
331
+
332
+ def forward(self, x):
333
+ return self.layers(x)
334
+
335
+
336
+ class MiniGridBackbone(nn.Module):
337
+ def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
338
+ super().__init__()
339
+ self.object_embedding = nn.Embedding(num_objects, embedding_dim)
340
+ self.color_embedding = nn.Embedding(num_colors, embedding_dim)
341
+ self.state_embedding = nn.Embedding(num_states, embedding_dim)
342
+
343
+ self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)
344
+
345
+ self.project_to_d_projection = nn.Sequential(
346
+ nn.Linear(embedding_dim * 4, d_input * 2),
347
+ nn.GLU(),
348
+ nn.LayerNorm(d_input),
349
+ nn.Linear(d_input, d_input * 2),
350
+ nn.GLU(),
351
+ nn.LayerNorm(d_input)
352
+ )
353
+
354
+ def forward(self, x):
355
+ x = x.long()
356
+ B, H, W, C = x.size()
357
+
358
+ object_idx = x[:,:,:, 0]
359
+ color_idx = x[:,:,:, 1]
360
+ state_idx = x[:,:,:, 2]
361
+
362
+ obj_embed = self.object_embedding(object_idx)
363
+ color_embed = self.color_embedding(color_idx)
364
+ state_embed = self.state_embedding(state_idx)
365
+
366
+ pos_idx = torch.arange(H * W, device=x.device).view(1, H, W).expand(B, -1, -1)
367
+ pos_embed = self.position_embedding(pos_idx)
368
+
369
+ out = self.project_to_d_projection(torch.cat([obj_embed, color_embed, state_embed, pos_embed], dim=-1))
370
+ return out
371
+
372
+ class ClassicControlBackbone(nn.Module):
373
+ def __init__(self, d_input):
374
+ super().__init__()
375
+ self.input_projector = nn.Sequential(
376
+ nn.Flatten(),
377
+ nn.LazyLinear(d_input * 2),
378
+ nn.GLU(),
379
+ nn.LayerNorm(d_input),
380
+ nn.LazyLinear(d_input * 2),
381
+ nn.GLU(),
382
+ nn.LayerNorm(d_input)
383
+ )
384
+
385
+ def forward(self, x):
386
+ return self.input_projector(x)
387
+
388
+
389
+ class ShallowWide(nn.Module):
390
+ """
391
+ Simple, wide, shallow convolutional backbone for image feature extraction.
392
+
393
+ Alternative to ResNet, uses grouped convolutions and GLU activations.
394
+ Fixed structure, useful for specific experiments.
395
+ """
396
+ def __init__(self):
397
+ super(ShallowWide, self).__init__()
398
+ # LazyConv2d infers input channels
399
+ self.layers = nn.Sequential(
400
+ nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1), # Output channels = 4096
401
+ nn.GLU(dim=1), # Halves channels to 2048
402
+ nn.BatchNorm2d(2048),
403
+ # Grouped convolution maintains width but processes groups independently
404
+ nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32),
405
+ nn.GLU(dim=1), # Halves channels to 2048
406
+ nn.BatchNorm2d(2048)
407
+ )
408
+ def forward(self, x):
409
+ return self.layers(x)
410
+
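+ # Shape sketch (hypothetical input, for illustration only): a (B, 3, 32, 32) image gives
+ # (B, 2048, 16, 16): the stride-2 LazyConv2d halves the spatial size and GLU(dim=1) halves
+ # 4096 channels to 2048; the grouped convolution keeps both unchanged thereafter.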
411
+
412
+ class PretrainedResNetWrapper(nn.Module):
413
+ """
414
+ Wrapper to use standard pre-trained ResNet models from torchvision.
415
+
416
+ Loads a specified ResNet architecture pre-trained on ImageNet, removes the
417
+ final classification layer (fc), average pooling, and optionally later layers
418
+ (e.g., layer4), allowing it to be used as a feature extractor backbone.
419
+
420
+ Args:
421
+ resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50').
422
+ fine_tune (bool): If False, freezes the weights of the pre-trained backbone.
423
+ """
424
+ def __init__(self, resnet_type, fine_tune=True):
425
+ super(PretrainedResNetWrapper, self).__init__()
426
+ self.resnet_type = resnet_type
427
+ self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)
428
+
429
+ if not fine_tune:
430
+ for param in self.backbone.parameters():
431
+ param.requires_grad = False
432
+
433
+ # Remove final layers to use as feature extractor
434
+ self.backbone.avgpool = Identity()
435
+ self.backbone.fc = Identity()
436
+ # Keep layer4 by default, user can modify instance if needed
437
+ # self.backbone.layer4 = Identity()
438
+
439
+ def forward(self, x):
440
+ # Get features from the modified ResNet
441
+ out = self.backbone(x)
442
+
443
+ # Reshape output to (B, C, H, W) - This is heuristic based on original comment.
444
+ # User might need to adjust this based on which layers are kept/removed.
445
+ # Infer C based on ResNet type (example values)
446
+ nc = 256 if ('18' in self.resnet_type or '34' in self.resnet_type) else 512 if '50' in self.resnet_type else 1024 if '101' in self.resnet_type else 2048 # Approx for layer3/4 output channel numbers
447
+ # Infer H, W assuming output is flattened C * H * W
448
+ num_features = out.shape[-1]
449
+ # This calculation assumes nc is correct and feature map is square
450
+ wh_squared = num_features / nc
451
+ if wh_squared <= 0 or not float(wh_squared).is_integer():
452
+ print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
453
+ # Return potentially flattened features if reshape fails
454
+ return out
455
+ wh = int(np.sqrt(wh_squared))
456
+
457
+ return out.reshape(x.size(0), nc, wh, wh)
458
+
459
+ # --- Positional Encoding Modules ---
460
+
461
+ class LearnableFourierPositionalEncoding(nn.Module):
462
+ """
463
+ Learnable Fourier Feature Positional Encoding.
464
+
465
+ Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
466
+ Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
467
+ Provides positional information for 2D feature maps.
468
+
469
+ Args:
470
+ d_model (int): The output dimension of the positional encoding (D).
471
+ G (int): Positional groups (default 1).
472
+ M (int): Dimensionality of input coordinates (default 2 for H, W).
473
+ F_dim (int): Dimension of the Fourier features.
474
+ H_dim (int): Hidden dimension of the MLP.
475
+ gamma (float): Initialization scale for the Fourier projection weights (Wr).
476
+ """
477
+ def __init__(self, d_model,
478
+ G=1, M=2,
479
+ F_dim=256,
480
+ H_dim=128,
481
+ gamma=1/2.5,
482
+ ):
483
+ super().__init__()
484
+ self.G = G
485
+ self.M = M
486
+ self.F_dim = F_dim
487
+ self.H_dim = H_dim
488
+ self.D = d_model
489
+ self.gamma = gamma
490
+
491
+ self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
492
+ self.mlp = nn.Sequential(
493
+ nn.Linear(self.F_dim, self.H_dim, bias=True),
494
+ nn.GLU(), # Halves H_dim
495
+ nn.Linear(self.H_dim // 2, self.D // self.G),
496
+ nn.LayerNorm(self.D // self.G)
497
+ )
498
+
499
+ self.init_weights()
500
+
501
+ def init_weights(self):
502
+ nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
503
+
504
+ def forward(self, x):
505
+ """
506
+ Computes positional encodings for the input feature map x.
507
+
508
+ Args:
509
+ x (torch.Tensor): Input feature map, shape (B, C, H, W).
510
+
511
+ Returns:
512
+ torch.Tensor: Positional encoding tensor, shape (B, D, H, W).
513
+ """
514
+ B, C, H, W = x.shape
515
+ # Creates coordinates based on (H, W) and repeats for batch B.
516
+ # Takes x[:,0] assuming channel dim isn't needed for coords.
517
+ x_coord = add_coord_dim(x[:,0]) # Expects (B, H, W) -> (B, H, W, 2)
518
+
519
+ # Compute Fourier features
520
+ projected = self.Wr(x_coord) # (B, H, W, F_dim // 2)
521
+ cosines = torch.cos(projected)
522
+ sines = torch.sin(projected)
523
+ F = (1.0 / math.sqrt(self.F_dim)) * torch.cat([cosines, sines], dim=-1) # (B, H, W, F_dim)
524
+
525
+ # Project features through MLP
526
+ Y = self.mlp(F) # (B, H, W, D // G)
527
+
528
+ # Reshape to (B, D, H, W)
529
+ PEx = Y.permute(0, 3, 1, 2) # Assuming G=1
530
+ return PEx
531
+
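+ # Usage sketch (hypothetical sizes, for illustration only):
+ # pe = LearnableFourierPositionalEncoding(d_model=128)
+ # pe(torch.randn(2, 64, 16, 16)).shape == torch.Size([2, 128, 16, 16])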
532
+
533
+ class MultiLearnableFourierPositionalEncoding(nn.Module):
534
+ """
535
+ Combines multiple LearnableFourierPositionalEncoding modules with different
536
+ initialization scales (gamma) via a learnable weighted sum.
537
+
538
+ Allows the model to learn an optimal combination of positional frequencies.
539
+
540
+ Args:
541
+ d_model (int): Output dimension of the encoding.
542
+ G, M, F_dim, H_dim: Parameters passed to underlying LearnableFourierPositionalEncoding.
543
+ gamma_range (list[float]): Min and max gamma values for the linspace.
544
+ N (int): Number of parallel embedding modules to create.
545
+ """
546
+ def __init__(self, d_model,
547
+ G=1, M=2,
548
+ F_dim=256,
549
+ H_dim=128,
550
+ gamma_range=[1.0, 0.1], # Default range
551
+ N=10,
552
+ ):
553
+ super().__init__()
554
+ self.embedders = nn.ModuleList()
555
+ for gamma in np.linspace(gamma_range[0], gamma_range[1], N):
556
+ self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma))
557
+
558
+ # Learnable combination weights over the N embedders (registered under the name 'combination')
560
+ self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
561
+ self.N = N
562
+
563
+
564
+ def forward(self, x):
565
+ """
566
+ Computes combined positional encoding.
567
+
568
+ Args:
569
+ x (torch.Tensor): Input feature map, shape (B, C, H, W).
570
+
571
+ Returns:
572
+ torch.Tensor: Combined positional encoding tensor, shape (B, D, H, W).
573
+ """
574
+ # Compute embeddings from all modules and stack: (N, B, D, H, W)
575
+ pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)
576
+
577
+ # Compute combination weights using softmax
578
+ # Use registered parameter name 'combination'
579
+ # Reshape weights for broadcasting: (N,) -> (N, 1, 1, 1, 1)
580
+ weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)
581
+
582
+ # Compute weighted sum over the N dimension
583
+ combined_emb = (pos_embs * weights).sum(0) # (B, D, H, W)
584
+ return combined_emb
585
+
586
+
587
+ class CustomRotationalEmbedding(nn.Module):
588
+ """
589
+ Custom Rotational Positional Embedding.
590
+
591
+ Generates 2D positional embeddings based on rotating a fixed start vector.
592
+ The rotation angle for each grid position is determined primarily by its
593
+ horizontal position (width dimension). The resulting rotated vectors are
594
+ concatenated and projected.
595
+
596
+ Note: The current implementation derives angles only from the width dimension (`x.size(-1)`).
597
+
598
+ Args:
599
+ d_model (int): Dimensionality of the output embeddings.
600
+ """
601
+ def __init__(self, d_model):
602
+ super(CustomRotationalEmbedding, self).__init__()
603
+ # Learnable 2D start vector
604
+ self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True))
605
+ # Projects the 4D concatenated rotated vectors to d_model
606
+ # Input size 4 comes from concatenating two 2D rotated vectors
607
+ self.projection = nn.Sequential(nn.Linear(4, d_model))
608
+
609
+ def forward(self, x):
610
+ """
611
+ Computes rotational positional embeddings based on input width.
612
+
613
+ Args:
614
+ x (torch.Tensor): Input tensor (used for shape and device),
615
+ shape (batch_size, channels, height, width).
616
+ Returns:
617
+ Output tensor containing positional embeddings,
618
+ shape (1, d_model, height, width) - Batch dim is 1 as PE is same for all.
619
+ """
620
+ B, C, H, W = x.shape
621
+ device = x.device
622
+
623
+ # --- Generate rotations based only on Width ---
624
+ # Angles derived from width dimension
625
+ theta_rad = torch.deg2rad(torch.linspace(0, 180, W, device=device)) # Angle per column
626
+ cos_theta = torch.cos(theta_rad)
627
+ sin_theta = torch.sin(theta_rad)
628
+
629
+ # Create rotation matrices: Shape (W, 2, 2)
630
+ # Use unsqueeze(1) to allow stacking along dim 1
631
+ rotation_matrices = torch.stack([
632
+ torch.stack([cos_theta, -sin_theta], dim=-1), # Shape (W, 2)
633
+ torch.stack([sin_theta, cos_theta], dim=-1) # Shape (W, 2)
634
+ ], dim=1) # Stacks along dim 1 -> Shape (W, 2, 2)
635
+
636
+ # Rotate the start vector by column angle: Shape (W, 2)
637
+ rotated_vectors = torch.einsum('wij,j->wi', rotation_matrices, self.start_vector)
638
+
639
+ # --- Create Grid Key ---
640
+ # Original code uses repeats based on rotated_vectors.shape[0] (which is W) for both dimensions.
641
+ # This creates a (W, W, 4) key tensor.
642
+ key = torch.cat((
643
+ torch.repeat_interleave(rotated_vectors.unsqueeze(1), W, dim=1), # (W, 1, 2) -> (W, W, 2)
644
+ torch.repeat_interleave(rotated_vectors.unsqueeze(0), W, dim=0) # (1, W, 2) -> (W, W, 2)
645
+ ), dim=-1) # Shape (W, W, 4)
646
+
647
+ # Project the 4D key vector to d_model: Shape (W, W, d_model)
648
+ pe_grid = self.projection(key)
649
+
650
+ # Reshape to (1, d_model, W, W) and then select/resize to target H, W?
651
+ # Original code permutes to (d_model, W, W) and unsqueezes to (1, d_model, W, W)
652
+ pe = pe_grid.permute(2, 0, 1).unsqueeze(0)
653
+
654
+ # If H != W, this needs adjustment. Assuming H=W or cropping/padding happens later.
655
+ # Let's return the (1, d_model, W, W) tensor as generated by the original logic.
656
+ # If H != W, downstream code must handle the mismatch or this PE needs modification.
657
+ if H != W:
658
+ # Simple interpolation/cropping could be added, but sticking to original logic:
659
+ # Option 1: Interpolate
660
+ # pe = F.interpolate(pe, size=(H, W), mode='bilinear', align_corners=False)
661
+ # Option 2: Crop/Pad (e.g., crop if W > W_target, pad if W < W_target)
662
+ # Sticking to original: return shape (1, d_model, W, W)
663
+ pass
664
+
665
+ return pe
666
+
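+ # Shape sketch (hypothetical input, for illustration only): for x of shape (B, C, H, W) the
+ # module returns a (1, d_model, W, W) tensor; as noted above, angles come from the width
+ # dimension only, so non-square inputs need handling downstream.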
667
+ class CustomRotationalEmbedding1D(nn.Module):
668
+ def __init__(self, d_model):
669
+ super(CustomRotationalEmbedding1D, self).__init__()
670
+ self.projection = nn.Linear(2, d_model)
671
+
672
+ def forward(self, x):
673
+ start_vector = torch.tensor([0., 1.], device=x.device, dtype=torch.float)
674
+ theta_rad = torch.deg2rad(torch.linspace(0, 180, x.size(2), device=x.device))
675
+ cos_theta = torch.cos(theta_rad)
676
+ sin_theta = torch.sin(theta_rad)
677
+ cos_theta = cos_theta.unsqueeze(1) # Shape: (height, 1)
678
+ sin_theta = sin_theta.unsqueeze(1) # Shape: (height, 1)
679
+
680
+ # Create rotation matrices
681
+ rotation_matrices = torch.stack([
682
+ torch.cat([cos_theta, -sin_theta], dim=1),
683
+ torch.cat([sin_theta, cos_theta], dim=1)
684
+ ], dim=1) # Shape: (height, 2, 2)
685
+
686
+ # Rotate the start vector
687
+ rotated_vectors = torch.einsum('bij,j->bi', rotation_matrices, start_vector)
688
+
689
+ pe = self.projection(rotated_vectors)
690
+ pe = torch.repeat_interleave(pe.unsqueeze(0), x.size(0), 0)
691
+ return pe.transpose(1, 2) # Transpose for compatibility with other backbones
692
+
models/resnet.py ADDED
@@ -0,0 +1,374 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ from models.modules import Identity
5
+
6
+ __all__ = [
7
+ "ResNet",
8
+ "resnet18",
9
+ "resnet34",
10
+ "resnet50",
11
+ "resnet101",
12
+ "resnet152",
13
+ ]
14
+
15
+
16
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
17
+ """3x3 convolution with padding"""
18
+ return nn.Conv2d(
19
+ in_planes,
20
+ out_planes,
21
+ kernel_size=3,
22
+ stride=stride,
23
+ padding=dilation,
24
+ groups=groups,
25
+ bias=False,
26
+ dilation=dilation,
27
+ )
28
+
29
+
30
+ def conv1x1(in_planes, out_planes, stride=1):
31
+ """1x1 convolution"""
32
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
33
+
34
+
35
+ class BasicBlock(nn.Module):
36
+ expansion = 1
37
+
38
+ def __init__(
39
+ self,
40
+ inplanes,
41
+ planes,
42
+ stride=1,
43
+ downsample=None,
44
+ groups=1,
45
+ base_width=64,
46
+ dilation=1,
47
+ norm_layer=None,
48
+ ):
49
+ super(BasicBlock, self).__init__()
50
+ if norm_layer is None:
51
+ norm_layer = nn.BatchNorm2d
52
+ if groups != 1 or base_width != 64:
53
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
54
+ if dilation > 1:
55
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
56
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
57
+ self.conv1 = conv3x3(inplanes, planes, stride)
58
+ self.bn1 = norm_layer(planes)
59
+ self.relu = nn.ReLU(inplace=True)
60
+ self.conv2 = conv3x3(planes, planes)
61
+ self.bn2 = norm_layer(planes)
62
+ self.downsample = downsample
63
+ self.stride = stride
64
+
65
+ def forward(self, x):
66
+ identity = x
67
+
68
+ out = self.conv1(x)
69
+ out = self.bn1(out)
70
+ out = self.relu(out)
71
+
72
+ out = self.conv2(out)
73
+ out = self.bn2(out)
74
+
75
+ if self.downsample is not None:
76
+ identity = self.downsample(x)
77
+
78
+ out += identity
79
+
80
+ out = self.relu(out)
81
+ return out
82
+
83
+
84
+ class Bottleneck(nn.Module):
85
+ expansion = 4
86
+
87
+ def __init__(
88
+ self,
89
+ inplanes,
90
+ planes,
91
+ stride=1,
92
+ downsample=None,
93
+ groups=1,
94
+ base_width=64,
95
+ dilation=1,
96
+ norm_layer=None,
97
+ ):
98
+ super(Bottleneck, self).__init__()
99
+ if norm_layer is None:
100
+ norm_layer = nn.BatchNorm2d
101
+ width = int(planes * (base_width / 64.0)) * groups
102
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
103
+ self.conv1 = conv1x1(inplanes, width)
104
+ self.bn1 = norm_layer(width)
105
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
106
+ self.bn2 = norm_layer(width)
107
+ self.conv3 = conv1x1(width, planes * self.expansion)
108
+ self.bn3 = norm_layer(planes * self.expansion)
109
+ self.relu = nn.ReLU(inplace=True)
110
+ self.downsample = downsample
111
+ self.stride = stride
112
+
113
+ def forward(self, x):
114
+ identity = x
115
+
116
+ out = self.conv1(x)
117
+ out = self.bn1(out)
118
+ out = self.relu(out)
119
+
120
+ out = self.conv2(out)
121
+ out = self.bn2(out)
122
+ out = self.relu(out)
123
+
124
+ out = self.conv3(out)
125
+ out = self.bn3(out)
126
+
127
+ if self.downsample is not None:
128
+ identity = self.downsample(x)
129
+
130
+ out += identity
131
+
132
+
133
+ # activation = None
134
+ # activation = out.detach().cpu().numpy()
135
+ out = self.relu(out)
136
+ # return out, activation
137
+
138
+ return out
139
+
140
+
141
+ class ResNet(nn.Module):
142
+ def __init__(
143
+ self,
144
+ in_channels,
145
+ feature_scales,
146
+ stride,
147
+ block,
148
+ layers,
149
+ num_classes=10,
150
+ zero_init_residual=False,
151
+ groups=1,
152
+ width_per_group=64,
153
+ replace_stride_with_dilation=None,
154
+ norm_layer=None,
155
+ do_initial_max_pool=True,
156
+ ):
157
+ super(ResNet, self).__init__()
158
+ if norm_layer is None:
159
+ norm_layer = nn.BatchNorm2d
160
+ self._norm_layer = norm_layer
161
+
162
+ self.inplanes = 64
163
+ self.dilation = 1
164
+ if replace_stride_with_dilation is None:
165
+ # each element in the tuple indicates if we should replace
166
+ # the 2x2 stride with a dilated convolution instead
167
+ replace_stride_with_dilation = [False, False, False]
168
+ if len(replace_stride_with_dilation) != 3:
169
+ raise ValueError(
170
+ "replace_stride_with_dilation should be None "
171
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
172
+ )
173
+ self.groups = groups
174
+ self.base_width = width_per_group
175
+
176
+ # NOTE: Important!
177
+ # This has changed from a kernel size of 7 (padding=3) to a kernel of 3 (padding=1)
178
+ # The reason for this was to limit the receptive field to constrain models to
179
+ # "Looking around" to gather information.
180
+
181
+ self.conv1 = nn.Conv2d(
182
+ in_channels, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
183
+ ) if in_channels in [1, 3] else nn.LazyConv2d(
184
+ self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
185
+ )
186
+ # END
187
+
188
+ self.bn1 = norm_layer(self.inplanes)
189
+ self.relu = nn.ReLU(inplace=True)
190
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if do_initial_max_pool else Identity()
191
+ self.layer1 = self._make_layer(block, 64, layers[0])
192
+ self.feature_scales = feature_scales
193
+ if 2 in feature_scales:
194
+ self.layer2 = self._make_layer(
195
+ block, 128, layers[1], stride=stride, dilate=replace_stride_with_dilation[0]
196
+ )
197
+ if 3 in feature_scales:
198
+ self.layer3 = self._make_layer(
199
+ block, 256, layers[2], stride=stride, dilate=replace_stride_with_dilation[1]
200
+ )
201
+ if 4 in feature_scales:
202
+ self.layer4 = self._make_layer(
203
+ block, 512, layers[3], stride=stride, dilate=replace_stride_with_dilation[2]
204
+ )
205
+
206
+ # NOTE: Commented this out as it is not used anymore for this work, kept it for reference
207
+ # self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
208
+ # self.fc = nn.Linear(512 * block.expansion, num_classes)
209
+
210
+ # for m in self.modules():
211
+ # if isinstance(m, nn.Conv2d):
212
+ # nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
213
+ # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
214
+ # nn.init.constant_(m.weight, 1)
215
+ # nn.init.constant_(m.bias, 0)
216
+
217
+ # Zero-initialize the last BN in each residual branch,
218
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
219
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
220
+ if zero_init_residual:
221
+ for m in self.modules():
222
+ if isinstance(m, Bottleneck):
223
+ nn.init.constant_(m.bn3.weight, 0)
224
+ elif isinstance(m, BasicBlock):
225
+ nn.init.constant_(m.bn2.weight, 0)
226
+
227
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
228
+ norm_layer = self._norm_layer
229
+ downsample = None
230
+ previous_dilation = self.dilation
231
+ if dilate:
232
+ self.dilation *= stride
233
+ stride = 1
234
+ if stride != 1 or self.inplanes != planes * block.expansion:
235
+ downsample = nn.Sequential(
236
+ conv1x1(self.inplanes, planes * block.expansion, stride),
237
+ norm_layer(planes * block.expansion),
238
+ )
239
+
240
+ layers = []
241
+ layers.append(
242
+ block(
243
+ self.inplanes,
244
+ planes,
245
+ stride,
246
+ downsample,
247
+ self.groups,
248
+ self.base_width,
249
+ previous_dilation,
250
+ norm_layer,
251
+ )
252
+ )
253
+ self.inplanes = planes * block.expansion
254
+ for _ in range(1, blocks):
255
+ layers.append(
256
+ block(
257
+ self.inplanes,
258
+ planes,
259
+ groups=self.groups,
260
+ base_width=self.base_width,
261
+ dilation=self.dilation,
262
+ norm_layer=norm_layer,
263
+ )
264
+ )
265
+
266
+ return nn.Sequential(*layers)
267
+
268
+ def forward(self, x):
269
+ activations = []
270
+ x = self.conv1(x)
271
+ x = self.bn1(x)
272
+ x = self.relu(x)
273
+ x = self.maxpool(x)
274
+ # if return_activations: activations.append(torch.clone(x))
275
+ x = self.layer1(x)
276
+
277
+ if 2 in self.feature_scales:
278
+ x = self.layer2(x)
279
+ if 3 in self.feature_scales:
280
+ x = self.layer3(x)
281
+ if 4 in self.feature_scales:
282
+ x = self.layer4(x)
283
+ return x
284
+
285
+
286
+ def _resnet(in_channels, feature_scales, stride, arch, block, layers, pretrained, progress, device, do_initial_max_pool, **kwargs):
287
+ model = ResNet(in_channels, feature_scales, stride, block, layers, do_initial_max_pool=do_initial_max_pool, **kwargs)
288
+ if pretrained:
289
+ assert in_channels==3
290
+ script_dir = os.path.dirname(__file__)
291
+ state_dict = torch.load(
292
+ script_dir + '/state_dicts/' + arch + ".pt", map_location=device
293
+ )
294
+ model.load_state_dict(state_dict, strict=False)
295
+ return model
296
+
297
+
298
+ def resnet18(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
299
+ """Constructs a ResNet-18 model.
300
+ Args:
301
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
302
+ progress (bool): If True, displays a progress bar of the download to stderr
303
+ """
304
+ return _resnet(in_channels,
305
+ feature_scales, stride, "resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, do_initial_max_pool, **kwargs
306
+ )
307
+
308
+
309
+ def resnet34(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
310
+ """Constructs a ResNet-34 model.
311
+ Args:
312
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
313
+ progress (bool): If True, displays a progress bar of the download to stderr
314
+ """
315
+ return _resnet(in_channels,
316
+ feature_scales, stride, "resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
317
+ )
318
+
319
+
320
+ def resnet50(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
321
+ """Constructs a ResNet-50 model.
322
+ Args:
323
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
324
+ progress (bool): If True, displays a progress bar of the download to stderr
325
+ """
326
+ return _resnet(in_channels,
327
+ feature_scales, stride, "resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
328
+ )
329
+
330
+
331
+ def resnet101(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
332
+ """Constructs a ResNet-50 model.
333
+ Args:
334
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
335
+ progress (bool): If True, displays a progress bar of the download to stderr
336
+ """
337
+ return _resnet(in_channels,
338
+ feature_scales, stride, "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
339
+ )
340
+
341
+
342
+ def resnet152(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
343
+ """Constructs a ResNet-50 model.
344
+ Args:
345
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
346
+ progress (bool): If True, displays a progress bar of the download to stderr
347
+ """
348
+ return _resnet(in_channels,
349
+ feature_scales, stride, "resnet152", Bottleneck, [3, 4, 36, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
350
+ )
351
+
352
+ def prepare_resnet_backbone(backbone_type):
353
+
354
+ resnet_family = resnet18 # Default
355
+ if '34' in backbone_type: resnet_family = resnet34
356
+ if '50' in backbone_type: resnet_family = resnet50
357
+ if '101' in backbone_type: resnet_family = resnet101
358
+ if '152' in backbone_type: resnet_family = resnet152
359
+
360
+ # Determine which ResNet blocks to keep
361
+ block_num_str = backbone_type.split('-')[-1]
362
+ hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]
363
+
364
+ backbone = resnet_family(
365
+ 3,
366
+ hyper_blocks_to_keep,
367
+ stride=2,
368
+ pretrained=False,
369
+ progress=True,
370
+ device="cpu",
371
+ do_initial_max_pool=True,
372
+ )
373
+
374
+ return backbone
models/utils.py ADDED
@@ -0,0 +1,122 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import re
4
+ import os
5
+
6
+ def compute_decay(T, params, clamp_lims=(0, 15)):
7
+ """
8
+ This function computes exponential decays for learnable synchronisation
9
+ interactions between pairs of neurons.
10
+ """
11
+ assert len(clamp_lims) == 2, 'Clamp lims should be length 2'
12
+ assert type(clamp_lims) == tuple, 'Clamp lims should be tuple'
13
+
14
+ indices = torch.arange(T-1, -1, -1, device=params.device).reshape(T, 1).expand(T, params.shape[0])
15
+ out = torch.exp(-indices * torch.clamp(params, clamp_lims[0], clamp_lims[1]).unsqueeze(0))
16
+ return out
17
+
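+ # Worked example (hypothetical values, for illustration only): with T=3 and params=torch.zeros(2),
+ # every decay weight is exp(0)=1; for a clamped rate r>0 each column becomes
+ # [exp(-2r), exp(-r), exp(0)], so the most recent step (last row) always has weight 1 and
+ # older steps are exponentially down-weighted.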
18
+ def add_coord_dim(x, scaled=True):
19
+ """
20
+ Adds a final dimension to the tensor representing 2D coordinates.
21
+
22
+ Args:
23
+ x: A PyTorch tensor of shape (B, H, W).
24
+
25
+ Returns:
26
+ A PyTorch tensor of shape (B, H, W, 2) with the last dimension
27
+ representing the 2D coordinates within the HW dimensions.
28
+ """
29
+ B, H, W = x.shape
30
+ # Create coordinate grids
31
+ x_coords = torch.arange(W, device=x.device, dtype=x.dtype).repeat(H, 1) # Shape (H, W)
32
+ y_coords = torch.arange(H, device=x.device, dtype=x.dtype).unsqueeze(-1).repeat(1, W) # Shape (H, W)
33
+ if scaled:
34
+ x_coords /= (W-1)
35
+ y_coords /= (H-1)
36
+ # Stack coordinates and expand dimensions
37
+ coords = torch.stack((x_coords, y_coords), dim=-1) # Shape (H, W, 2)
38
+ coords = coords.unsqueeze(0) # Shape (1, 1, H, W, 2)
39
+ coords = coords.repeat(B, 1, 1, 1) # Shape (B, D, H, W, 2)
40
+ return coords
41
+
42
+ def compute_normalized_entropy(logits, reduction='mean'):
43
+ """
44
+ Calculates the normalized entropy of a PyTorch tensor of logits along the
45
+ final dimension.
46
+
47
+ Args:
48
+ logits: A PyTorch tensor of logits.
49
+
50
+ Returns:
51
+ A PyTorch tensor containing the normalized entropy values.
52
+ """
53
+
54
+ # Apply softmax to get probabilities
55
+ preds = F.softmax(logits, dim=-1)
56
+
57
+ # Calculate the log probabilities
58
+ log_preds = torch.log_softmax(logits, dim=-1)
59
+
60
+ # Calculate the entropy
61
+ entropy = -torch.sum(preds * log_preds, dim=-1)
62
+
63
+ # Calculate the maximum possible entropy
64
+ num_classes = preds.shape[-1]
65
+ max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))
66
+
67
+ # Normalize the entropy
68
+ normalized_entropy = entropy / max_entropy
69
+ if len(logits.shape)>2 and reduction == 'mean':
70
+ normalized_entropy = normalized_entropy.flatten(1).mean(-1)
71
+
72
+ return normalized_entropy
73
+
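+ # Worked example (hypothetical values, for illustration only): for logits over C=4 classes,
+ # a uniform prediction has entropy log(4) and normalized entropy 1.0, while a highly
+ # confident prediction (one logit much larger than the rest) has normalized entropy near 0.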
74
+ def reshape_predictions(predictions, prediction_reshaper):
75
+ B, T = predictions.size(0), predictions.size(-1)
76
+ new_shape = [B] + prediction_reshaper + [T]
77
+ reshaped_predictions = predictions.reshape(new_shape)
78
+ return reshaped_predictions
79
+
80
+ def get_all_log_dirs(root_dir):
81
+ folders = []
82
+ for dirpath, dirnames, filenames in os.walk(root_dir):
83
+ if any(f.endswith(".pt") for f in filenames):
84
+ folders.append(dirpath)
85
+ return folders
86
+
87
+ def get_latest_checkpoint(log_dir):
88
+ files = [f for f in os.listdir(log_dir) if re.match(r'checkpoint_\d+\.pt', f)]
89
+ return os.path.join(log_dir, max(files, key=lambda f: int(re.search(r'\d+', f).group()))) if files else None
90
+
91
+ def get_latest_checkpoint_file(filepath, limit=300000):
92
+ checkpoint_files = get_checkpoint_files(filepath)
93
+ checkpoint_files = [
94
+ f for f in checkpoint_files if int(re.search(r'checkpoint_(\d+)\.pt', f).group(1)) <= limit
95
+ ]
96
+ if not checkpoint_files:
97
+ return None
98
+ return checkpoint_files[-1]
99
+
100
+ def get_checkpoint_files(filepath):
101
+ regex = r'checkpoint_(\d+)\.pt'
102
+ files = [f for f in os.listdir(filepath) if re.match(regex, f)]
103
+ files = sorted(files, key=lambda f: int(re.search(regex, f).group(1)))
104
+ return [os.path.join(filepath, f) for f in files]
105
+
106
+ def load_checkpoint(checkpoint_path, device):
107
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
108
+ return checkpoint
109
+
110
+ def get_model_args_from_checkpoint(checkpoint):
111
+ if "args" in checkpoint:
112
+ return(checkpoint["args"])
113
+ else:
114
+ raise ValueError("Checkpoint does not contain saved args.")
115
+
116
+ def get_accuracy_and_loss_from_checkpoint(checkpoint, device="cpu"):
117
+ training_iteration = checkpoint.get('training_iteration', 0)
118
+ train_losses = checkpoint.get('train_losses', [])
119
+ test_losses = checkpoint.get('test_losses', [])
120
+ train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
121
+ test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
122
+ return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
requirements.txt ADDED
@@ -0,0 +1,15 @@
+ numpy
+ torch
+ torchvision
+ matplotlib
+ seaborn
+ tqdm
+ opencv-python
+ imageio
+ scikit-learn
+ umap-learn
+ python-dotenv
+ gymnasium
+ minigrid
+ datasets
+ autoclip
tasks/image_classification/README.md ADDED
@@ -0,0 +1,29 @@
+ # Image classification
+ 
+ This folder contains code for training and analysing the ImageNet and CIFAR experiments.
+ 
+ ## Accessing and loading ImageNet
+ 
+ We use the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset in our paper.
+ 
+ To get this to work for you, you will need to do the following:
+ 1. Log in to Hugging Face (create an account if needed) and agree to the terms and conditions of this dataset.
+ 2. Create a new access token.
+ 3. Install huggingface_hub on the target machine with ```pip install huggingface_hub```
+ 4. Run ```huggingface-cli login``` and use your token. This will authenticate you on the backend and allow the code to run.
+ 5. Run an ImageNet experiment. The dataset will be downloaded automatically on first use; a quick sanity-check sketch is shown below.
+ 
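+ The following is a minimal, optional sanity check (not part of the training code, which uses the wrapper in `data/custom_datasets.py`); it assumes the `huggingface_hub` and `datasets` packages from `requirements.txt` and only confirms that your token grants access to the gated dataset:
+ 
+ ```python
+ from huggingface_hub import whoami
+ from datasets import load_dataset
+ 
+ # Should print your username if `huggingface-cli login` succeeded.
+ print(whoami()["name"])
+ 
+ # Stream a single validation example to confirm access without downloading everything.
+ ds = load_dataset("ILSVRC/imagenet-1k", split="validation", streaming=True)
+ print(next(iter(ds))["label"])
+ ```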
+ 
+ ## Training
+ There are two training files: `train.py` and `train_distributed.py`. The training code uses mixed precision. For the settings in the paper, the following command was used for distributed training:
+ 
+ ```
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp
+ ```
+ 
+ You can run the same setup on a single GPU with:
+ ```
+ python -m tasks.image_classification.train --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp --device 0
+ ```
+ 
tasks/image_classification/analysis/README.md ADDED
@@ -0,0 +1,12 @@
+ # Analysis
+ 
+ This folder contains analysis code for image classification experiments. To build GIFs for ImageNet run (from the base directory):
+ 
+ ```
+ python -m tasks.image_classification.analysis.build_imagenet_viz
+ ```
+ 
+ To build the plots in the paper run:
+ ```
+ python -m tasks.image_classification.analysis.imagenet_evaluate_and_plot
+ ```
tasks/image_classification/analysis/run_imagenet_analysis.py ADDED
@@ -0,0 +1,972 @@
1
+ # --- Core Libraries ---
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ import argparse
6
+ from tqdm.auto import tqdm
7
+ import torch.nn.functional as F # Used for interpolate
8
+
9
+ # --- Plotting & Visualization ---
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib as mpl
12
+ mpl.use('Agg')
13
+ import seaborn as sns
14
+ sns.set_style('darkgrid')
15
+ from matplotlib import patheffects
16
+ import seaborn as sns
17
+ import imageio
18
+ import cv2
19
+ from scipy.special import softmax
20
+ from tasks.image_classification.plotting import save_frames_to_mp4
21
+
22
+ # --- Data Handling & Model ---
23
+ from torchvision import transforms
24
+ from torchvision import datasets # Only used for CIFAR100 in debug mode
25
+ from scipy import ndimage # Used in find_island_centers
26
+ from data.custom_datasets import ImageNet
27
+ from models.ctm import ContinuousThoughtMachine
28
+ from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
29
+ from tasks.image_classification.plotting import plot_neural_dynamics
30
+
31
+ # --- Global Settings ---
32
+ np.seterr(divide='ignore')
33
+ mpl.use('Agg')
34
+ sns.set_style('darkgrid')
35
+
36
+ # --- Helper Functions ---
37
+
38
+ def find_island_centers(array_2d, threshold):
39
+ """
40
+ Finds the center of mass of each island (connected component > threshold)
41
+ in a 2D array, weighted by the array's values.
42
+ Returns list of (y, x) centers and list of areas.
43
+ """
44
+ binary_image = array_2d > threshold
45
+ labeled_image, num_labels = ndimage.label(binary_image)
46
+ centers = []
47
+ areas = []
48
+ # Calculate center of mass for each labeled island (label 0 is background)
49
+ for i in range(1, num_labels + 1):
50
+ island_mask = (labeled_image == i)
51
+ total_mass = np.sum(array_2d[island_mask])
52
+ if total_mass > 0:
53
+ # Get coordinates for this island
54
+ y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
55
+ # Calculate weighted average for center
56
+ x_center = np.average(x_coords[island_mask], weights=array_2d[island_mask])
57
+ y_center = np.average(y_coords[island_mask], weights=array_2d[island_mask])
58
+ centers.append((round(y_center, 4), round(x_center, 4)))
59
+ areas.append(np.sum(island_mask)) # Area is the count of pixels in the island
60
+ return centers, areas
61
+
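+ # Worked example (hypothetical input, for illustration only): a 4x4 array of zeros with a
+ # 2x2 patch of ones in its top-left corner and threshold=0.5 yields centers=[(0.5, 0.5)]
+ # and areas=[4].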
62
+ def parse_args():
63
+ """Parses command-line arguments."""
64
+ # Note: Original had two ArgumentParser instances, using the second one.
65
+ parser = argparse.ArgumentParser(description="Visualize Continuous Thought Machine Attention")
66
+ parser.add_argument('--actions', type=str, nargs='+', default=['videos'], choices=['plots', 'videos', 'demo'], help="Actions to take. Plots=results plots; videos=gifs/mp4s to watch attention; demo: last frame of internal ticks")
67
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
68
+
69
+ parser.add_argument('--checkpoint', type=str, default='checkpoints/imagenet/ctm_clean.pt', help="Path to CTM checkpoint")
70
+ parser.add_argument('--output_dir', type=str, default='tasks/image_classification/analysis/outputs/imagenet_viz', help="Directory for visualization outputs")
71
+ parser.add_argument('--debug', action=argparse.BooleanOptionalAction, default=True, help='Debug mode: use CIFAR100 instead of ImageNet for debugging.')
72
+ parser.add_argument('--plot_every', type=int, default=10, help="How often to plot.")
73
+
74
+ parser.add_argument('--inference_iterations', type=int, default=50, help="Iterations to use during inference.")
75
+ parser.add_argument('--data_indices', type=int, nargs='+', default=[], help="Use specific indices in validation data for demos, otherwise random.")
76
+ parser.add_argument('--N_to_viz', type=int, default=5, help="When not supplying data_indices.")
77
+
78
+ return parser.parse_args()
79
+
80
+
81
+ # --- Main Execution Block ---
82
+ if __name__=='__main__':
83
+
84
+ # --- Setup ---
85
+ args = parse_args()
86
+ if args.device[0] != -1 and torch.cuda.is_available():
87
+ device = f'cuda:{args.device[0]}'
88
+ else:
89
+ device = 'cpu'
90
+ print(f"Using device: {device}")
91
+
92
+ # --- Load Checkpoint & Model ---
93
+ print(f"Loading checkpoint: {args.checkpoint}")
94
+ checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)  # weights_only=False is needed to load the saved args
95
+ model_args = checkpoint['args']
96
+
97
+ # Handle legacy arguments from checkpoint if necessary
98
+ if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
99
+ model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
100
+ if not hasattr(model_args, 'neuron_select_type'):
101
+ model_args.neuron_select_type = 'first-last'
102
+
103
+
104
+ # Instantiate Model based on checkpoint args
105
+ print("Instantiating CTM model...")
106
+ model = ContinuousThoughtMachine(
107
+ iterations=model_args.iterations,
108
+ d_model=model_args.d_model,
109
+ d_input=model_args.d_input,
110
+ heads=model_args.heads,
111
+ n_synch_out=model_args.n_synch_out,
112
+ n_synch_action=model_args.n_synch_action,
113
+ synapse_depth=model_args.synapse_depth,
114
+ memory_length=model_args.memory_length,
115
+ deep_nlms=model_args.deep_memory,
116
+ memory_hidden_dims=model_args.memory_hidden_dims,
117
+ do_layernorm_nlm=model_args.do_normalisation,
118
+ backbone_type=model_args.backbone_type,
119
+ positional_embedding_type=model_args.positional_embedding_type,
120
+ out_dims=model_args.out_dims,
121
+ prediction_reshaper=[-1], # Kept fixed value from original code
122
+ dropout=0, # No dropout for eval
123
+ neuron_select_type=model_args.neuron_select_type,
124
+ n_random_pairing_self=model_args.n_random_pairing_self,
125
+ ).to(device)
126
+
127
+ # Load weights into model
128
+ load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
129
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
130
+ model.eval() # Set model to evaluation mode
131
+
132
+ # --- Prepare Dataset ---
133
+ if args.debug:
134
+ print("Debug mode: Using CIFAR100")
135
+ # CIFAR100 specific normalization constants
136
+ dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
137
+ dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
138
+ img_size = 256 # Resize CIFAR images for consistency
139
+ transform = transforms.Compose([
140
+ transforms.Resize(img_size),
141
+ transforms.ToTensor(),
142
+ transforms.Normalize(mean=dataset_mean, std=dataset_std), # Normalize
143
+ ])
144
+ validation_dataset = datasets.CIFAR100('data/', train=False, transform=transform, download=True)
145
+ validation_dataset_centercrop = datasets.CIFAR100('data/', train=True, transform=transform, download=True)
146
+ else:
147
+ print("Using ImageNet")
148
+ # ImageNet specific normalization constants
149
+ dataset_mean = [0.485, 0.456, 0.406]
150
+ dataset_std = [0.229, 0.224, 0.225]
151
+ img_size = 256 # Resize ImageNet images
152
+ # Note: Original comment mentioned no CenterCrop, this transform reflects that.
153
+ transform = transforms.Compose([
154
+ transforms.Resize(img_size),
155
+ transforms.ToTensor(),
156
+ transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
157
+ ])
158
+ validation_dataset = ImageNet(which_split='validation', transform=transform)
159
+ validation_dataset_centercrop = ImageNet(which_split='train', transform=transforms.Compose([
160
+ transforms.Resize(img_size),
161
+ transforms.RandomCrop(img_size),
162
+ transforms.ToTensor(),
163
+ transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
164
+ ]))
165
+ class_labels = list(IMAGENET2012_CLASSES.values()) # Load actual class names
166
+
167
+ os.makedirs(f'{args.output_dir}', exist_ok=True)
168
+
169
+ interp_mode = 'nearest'
170
+ cmap_calib = sns.color_palette('viridis', as_cmap=True)
171
+ loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False)
172
+ loader_crop = torch.utils.data.DataLoader(validation_dataset_centercrop, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
173
+
174
+ model.eval()
175
+
176
+ figscale = 0.85
177
+ topk = 5
178
+ mean_certainties_correct, mean_certainties_incorrect = [],[]
179
+ tracked_certainties = []
180
+ tracked_targets = []
181
+ tracked_predictions = []
182
+
183
+ if model.iterations != args.inference_iterations:
184
+ print('WARNING: you are setting inference iterations to a value not used during training!')
185
+
186
+ model.iterations = args.inference_iterations
187
+
188
+ if 'plots' in args.actions:
189
+
190
+ with torch.inference_mode(): # Disable gradient calculations
191
+ with tqdm(total=len(loader), initial=0, leave=False, position=0, dynamic_ncols=True) as pbar:
192
+ imgi = 0
193
+ for bi, (inputs, targets) in enumerate(loader):
194
+ inputs = inputs.to(device)
195
+ targets = targets.to(device)
196
+ if bi==0:
197
+ dynamics_inputs, _ = next(iter(loader_crop)) # Use this because of batching
198
+ _, _, _, _, post_activations_viz, _ = model(inputs, track=True)
199
+ plot_neural_dynamics(post_activations_viz, 15*10, args.output_dir, axis_snap=True, N_per_row=15)
200
+ predictions, certainties, synchronisation = model(inputs)
201
+
202
+ tracked_predictions.append(predictions.detach().cpu().numpy())
203
+ tracked_targets.append(targets.detach().cpu().numpy())
204
+ tracked_certainties.append(certainties.detach().cpu().numpy())
205
+
206
+
207
+
208
+
209
+ pbar.set_description(f'Processing base image of size {inputs.shape}')
210
+ pbar.update(1)
211
+ if ((bi % args.plot_every == 0) or bi == len(loader)-1) and bi!=0: #
212
+
213
+ concatenated_certainties = np.concatenate(tracked_certainties, axis=0)
214
+ concatenated_targets = np.concatenate(tracked_targets, axis=0)
215
+ concatenated_predictions = np.concatenate(tracked_predictions, axis=0)
216
+ concatenated_predictions_argsorted = np.argsort(concatenated_predictions, 1)[:,::-1]
217
+
218
+
219
+
220
+ for topk in [1, 5]:
221
+ concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
222
+
223
+ accs_instant, accs_avg, accs_certain = [], [], []
224
+ accs_avg_logits, accs_weighted_logits = [],[]
225
+ with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
226
+ pbarinner.set_description('Acc types')
227
+ for stepi in np.arange(concatenated_predictions.shape[-1]):
228
+ pred_avg = softmax(concatenated_predictions, 1)[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
229
+ pred_instant = concatenated_predictions_argsorted_topk[:,:,stepi]
230
+ pred_certain = concatenated_predictions_argsorted_topk[np.arange(concatenated_predictions.shape[0]),:, concatenated_certainties[:,1,:stepi+1].argmax(1)]
231
+ pred_avg_logits = concatenated_predictions[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
232
+ pred_weighted_logits = (concatenated_predictions[:,:,:stepi+1] * concatenated_certainties[:,1:,:stepi+1]).sum(-1).argsort(1)[:, -topk:]
233
+ pbarinner.update(1)
234
+ accs_instant.append(np.any(pred_instant==concatenated_targets[...,np.newaxis], -1).mean())
235
+ accs_avg.append(np.any(pred_avg==concatenated_targets[...,np.newaxis], -1).mean())
236
+ accs_avg_logits.append(np.any(pred_avg_logits==concatenated_targets[...,np.newaxis], -1).mean())
237
+ accs_weighted_logits.append(np.any(pred_weighted_logits==concatenated_targets[...,np.newaxis], -1).mean())
238
+ accs_certain.append(np.any(pred_certain==concatenated_targets[...,np.newaxis], -1).mean())
239
+ fig = plt.figure(figsize=(10*figscale, 4*figscale))
240
+ ax = fig.add_subplot(111)
241
+ cp = sns.color_palette("bright")
242
+ ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_instant), linestyle='-', color=cp[0], label='Instant')
243
+ # ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg), linestyle='--', color=cp[1], label='Based on average probability up to this step')
244
+ ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_certain), linestyle=':', color=cp[2], label='Most certain')
245
+ ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg_logits), linestyle='-.', color=cp[3], label='Average logits')
246
+ ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_weighted_logits), linestyle='--', color=cp[4], label='Logits weighted by certainty')
247
+ ax.set_xlim([0, concatenated_predictions.shape[-1]+1])
248
+ ax.set_ylim([75, 92])
249
+ ax.set_xlabel('Internal ticks')
250
+ ax.set_ylabel(f'Top-k={topk} accuracy')
251
+ ax.legend(loc='lower right')
252
+ fig.tight_layout(pad=0.1)
253
+ fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.png', dpi=200)
254
+ fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.pdf', dpi=200)
255
+ plt.close(fig)
256
+ print(f'k={topk}. Accuracy most certain at last internal tick={100*np.array(accs_certain)[-1]:0.4f}') # Using certainty based approach
257
+
258
+
259
+ indices_over_80 = []
260
+ classes_80 = {}
261
+ corrects_80 = {}
262
+
263
+ topk = 5
264
+ concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
265
+ for certainty_threshold in [0.5, 0.8, 0.9]:
266
+ # certainty_threshold = 0.6
267
+ percentage_corrects = []
268
+ percentage_incorrects = []
269
+ with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
270
+ pbarinner.set_description(f'Certainty threshold={certainty_threshold}')
271
+ for stepi in np.arange(concatenated_predictions.shape[-1]):
272
+ certainty_here = concatenated_certainties[:,1,stepi]
273
+ certainty_mask = certainty_here>=certainty_threshold
274
+ predictions_here = concatenated_predictions_argsorted_topk[:,:,stepi]
275
+ is_correct_here = np.any(predictions_here==concatenated_targets[...,np.newaxis], axis=-1)
276
+ percentage_corrects.append(is_correct_here[certainty_mask].sum()/predictions_here.shape[0])
277
+ percentage_incorrects.append((~is_correct_here)[certainty_mask].sum()/predictions_here.shape[0])
278
+
279
+ if certainty_threshold==0.8:
280
+ indices_certain = np.where(certainty_mask)[0]
281
+ for index in indices_certain:
282
+ if index not in indices_over_80:
283
+ indices_over_80.append(index)
284
+ if concatenated_targets[index] not in classes_80:
285
+ classes_80[concatenated_targets[index]] = [stepi]
286
+ corrects_80[concatenated_targets[index]] = [is_correct_here[index]]
287
+ else:
288
+ classes_80[concatenated_targets[index]] = classes_80[concatenated_targets[index]]+[stepi]
289
+ corrects_80[concatenated_targets[index]] = corrects_80[concatenated_targets[index]]+[is_correct_here[index]]
290
+
291
+
292
+ pbarinner.update(1)
293
+ fig = plt.figure(figsize=(6.5*figscale, 4*figscale))
294
+ ax = fig.add_subplot(111)
295
+ ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
296
+ percentage_corrects,
297
+ color='forestgreen',
298
+ hatch='OO',
299
+ width=0.9,
300
+ label='Positive',
301
+ alpha=0.9,
302
+ linewidth=1.0*figscale)
303
+
304
+ ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
305
+ percentage_incorrects,
306
+ bottom=percentage_corrects,
307
+ color='crimson',
308
+ hatch='xx',
309
+ width=0.9,
310
+ label='Negative',
311
+ alpha=0.9,
312
+ linewidth=1.0*figscale)
313
+ ax.set_xlim(-1, concatenated_predictions.shape[-1]+1)
314
+ ax.set_xlabel('Internal tick')
315
+ ax.set_ylabel('% of data')
316
+ ax.legend(loc='lower right')
317
+
318
+
319
+ fig.tight_layout(pad=0.1)
320
+ fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.png', dpi=200)
321
+ fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.pdf', dpi=200)
322
+ plt.close(fig)
323
+
324
+
325
+ class_list = list(classes_80.keys())
326
+ mean_steps = [np.mean(classes_80[cls]) for cls in class_list]
327
+ std_steps = [np.std(classes_80[cls]) for cls in class_list]
328
+
329
+
330
+ # Following code plots the class distribution over internal ticks
331
+ indices_to_show = np.arange(1000)
332
+
333
+ colours = cmap_diverse = plt.get_cmap('rainbow')(np.linspace(0, 1, 1000))
334
+ # np.random.shuffle(colours)
335
+ bottom = np.zeros(concatenated_predictions.shape[-1])
336
+
337
+ fig = plt.figure(figsize=(7*figscale, 4*figscale))
338
+ ax = fig.add_subplot(111)
339
+ for iii, idx in enumerate(indices_to_show):
340
+ if idx in classes_80:
341
+ steps = classes_80[idx]
342
+ colour = colours[iii]
343
+ vs, cts = np.unique(steps, return_counts=True)
344
+
345
+ bar = np.zeros(concatenated_predictions.shape[-1])
346
+ bar[vs] = cts
347
+ ax.bar(np.arange(concatenated_predictions.shape[-1])+1, bar, bottom=bottom, color=colour, width=1, edgecolor='none')
348
+ bottom += bar
349
+ ax.set_xlabel('Internal ticks')
350
+ ax.set_ylabel('Counts over 0.8 certainty')
351
+ fig.tight_layout(pad=0.1)
352
+ fig.savefig(f'{args.output_dir}/class_counts.png', dpi=200)
353
+ fig.savefig(f'{args.output_dir}/class_counts.pdf', dpi=200)
354
+ plt.close(fig)
355
+
356
+
357
+
358
+
359
+
360
+ # The following code plots calibration
361
+ probability_space = np.linspace(0, 1, 10)
362
+ fig = plt.figure(figsize=(6*figscale, 4*figscale))
363
+ ax = fig.add_subplot(111)
364
+
365
+
366
+ color_linspace = np.linspace(0, 1, concatenated_predictions.shape[-1])
367
+ with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
368
+ pbarinner.set_description(f'Calibration')
369
+ for stepi in np.arange(concatenated_predictions.shape[-1]):
370
+ color = cmap_calib(color_linspace[stepi])
371
+ pred = concatenated_predictions[:,:,stepi].argmax(1)
372
+ is_correct = pred == concatenated_targets # BxT
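+ # Confidence is taken as the softmax probability of the class predicted at this tick, averaged over all ticks up to and including stepi.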
373
+ probabilities = softmax(concatenated_predictions[:,:,:stepi+1], axis=1)[np.arange(concatenated_predictions.shape[0]), pred].mean(-1)
374
+ probability_space = np.linspace(0, 1, 10)
375
+ accuracies_per_bin = []
376
+ bin_centers = []
377
+ for pi in range(len(probability_space)-1):
378
+ bin_low = probability_space[pi]
379
+ bin_high = probability_space[pi+1]
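+ # The final bin is closed on the right so that probabilities of exactly 1.0 are counted.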
380
+ mask = ((probabilities >= bin_low) & (probabilities < bin_high)) if pi != len(probability_space)-2 else ((probabilities >= bin_low) & (probabilities <= bin_high))
381
+ accuracies_per_bin.append(is_correct[mask].mean())
382
+ bin_centers.append(probabilities[mask].mean())
383
+
384
+
385
+ if stepi==concatenated_predictions.shape[-1]-1:
386
+ ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color='#4050f7', alpha=1, label='After all ticks')
387
+ else: ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color=color, alpha=0.65)
388
+ pbarinner.update(1)
389
+ ax.plot(probability_space, np.linspace(0, 1, len(probability_space)), 'k--')
390
+
391
+ ax.legend(loc='upper left')
392
+ ax.set_xlim([-0.01, 1.01])
393
+ ax.set_ylim([-0.01, 1.01])
394
+
395
+ sm = plt.cm.ScalarMappable(cmap=cmap_calib, norm=plt.Normalize(vmin=0, vmax=concatenated_predictions.shape[-1] - 1))
396
+ sm.set_array([]) # Empty array for colormap
397
+ cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.02)
398
+ cbar.set_label('Internal ticks')
399
+
400
+ ax.set_xlabel('Mean predicted probabilities')
401
+ ax.set_ylabel('Ratio of positives')
402
+ fig.tight_layout(pad=0.1)
403
+ fig.savefig(f'{args.output_dir}/imagenet_calibration.png', dpi=200)
404
+ fig.savefig(f'{args.output_dir}/imagenet_calibration.pdf', dpi=200)
405
+ plt.close(fig)
406
+ if 'videos' in args.actions:
407
+ if not args.data_indices: # If list is empty
408
+ n_samples = len(validation_dataset)
409
+ num_to_sample = min(args.N_to_viz, n_samples)
410
+ replace = n_samples < num_to_sample
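+ # num_to_sample is capped at n_samples above, so this always samples without replacement.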
411
+ data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
412
+ print(f"Selected random indices: {data_indices}")
413
+ else:
414
+ data_indices = args.data_indices
415
+ print(f"Using specified indices: {data_indices}")
416
+
417
+
418
+ for di in data_indices:
419
+ print(f'\nBuilding viz for dataset index {di}.')
420
+
421
+ # --- Get Data & Run Inference ---
422
+ # inputs_norm is already normalized by the transform
423
+ inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
424
+
425
+ # Add batch dimension and send to device
426
+ inputs = inputs.to(device).unsqueeze(0)
427
+
428
+ # Run model inference
429
+ predictions, certainties, synchronisation, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
430
+ # predictions: (B, Classes, Steps); attention_tracking is reshaped below to (Steps, Heads, H_feat, W_feat), assuming B=1
431
+ n_steps = predictions.size(-1)
432
+
433
+ # --- Reshape Attention ---
434
+ # Infer feature map size from model internals (assuming B=1)
435
+ h_feat, w_feat = model.kv_features.shape[-2:]
436
+
437
+ n_heads = attention_tracking.shape[2]
438
+ # Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
439
+ attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
440
+
441
+ # --- Setup for Plotting ---
442
+ step_linspace = np.linspace(0, 1, n_steps) # For step colors
443
+ # Define color maps
444
+ cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
445
+ cmap_attention = sns.color_palette('viridis', as_cmap=True)
446
+
447
+ # Create output directory for this index
448
+ index_output_dir = os.path.join(args.output_dir, str(di))
449
+ os.makedirs(index_output_dir, exist_ok=True)
450
+
451
+ frames = [] # Store frames for GIF
452
+ head_routes = {h: [] for h in range(n_heads)} # Store (y,x) path points per head
453
+ head_routes[-1] = []
454
+ route_colours_step = [] # Store colors for each step's path segments
455
+
456
+ # --- Loop Through Each Step ---
457
+ for step_i in range(n_steps):
458
+
459
+ # --- Prepare Image for Display ---
460
+ # Denormalize the input tensor for visualization
461
+ data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
462
+ mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
463
+ std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
464
+ data_img_denorm = data_img_tensor * std_tensor + mean_tensor
465
+ # Permute to (H, W, C) and convert to numpy, clip to [0, 1]
466
+ data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
467
+ data_img_np = np.clip(data_img_np, 0, 1)
468
+ img_h, img_w = data_img_np.shape[:2]
469
+
470
+ # --- Process Attention & Certainty ---
471
+ # Average attention over last few steps (from original code)
472
+ start_step = max(0, step_i - 5)
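+ # i.e. a sliding window over the current tick and (up to) the 5 preceding ticks.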
473
+ attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
474
+ # Get certainties up to current step
475
+ certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
476
+
477
+ # --- Calculate Attention Paths (using bilinear interp) ---
478
+ # Interpolate attention to image size using bilinear for center finding
479
+ attention_interp_bilinear = F.interpolate(
480
+ torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
481
+ size=(img_h, img_w),
482
+ mode=interp_mode,
483
+ # align_corners=False
484
+ ).squeeze(0) # Remove batch dim -> (Heads, H, W)
485
+
486
+ # Normalize the mean (across heads) attention map to [0, 1]
488
+ attn_mean = attention_interp_bilinear.mean(0)
489
+ attn_mean_min = attn_mean.min()
490
+ attn_mean_max = attn_mean.max()
491
+ attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
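+ # find_island_centers thresholds the normalised map (here at 0.7) and returns the centroids and areas of the
+ # resulting connected regions; the centroid of the largest region becomes this tick's attention-path point.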
492
+ centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
493
+
494
+ if centers: # If islands found
495
+ largest_island_idx = np.argmax(areas)
496
+ current_center = centers[largest_island_idx] # (y, x)
497
+ head_routes[-1].append(current_center)
498
+ elif head_routes[-1]: # If no center now, repeat last known center if history exists
499
+ head_routes[-1].append(head_routes[-1][-1])
500
+
501
+
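+ # Normalize each head's attention map to [0, 1]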
502
+ attn_min = attention_interp_bilinear.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
503
+ attn_max = attention_interp_bilinear.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
504
+ attention_interp_bilinear = (attention_interp_bilinear - attn_min) / (attn_max - attn_min + 1e-6)
505
+
506
+ # Store step color
507
+ current_colour = list(cmap_spectral(step_linspace[step_i]))
508
+ route_colours_step.append(current_colour)
509
+
510
+ # Find island center for each head
511
+ for head_i in range(n_heads):
512
+ attn_head_np = attention_interp_bilinear[head_i].detach().cpu().numpy()
513
+ # Keep threshold=0.7 based on original call
514
+ centers, areas = find_island_centers(attn_head_np, threshold=0.7)
515
+
516
+ if centers: # If islands found
517
+ largest_island_idx = np.argmax(areas)
518
+ current_center = centers[largest_island_idx] # (y, x)
519
+ head_routes[head_i].append(current_center)
520
+ elif head_routes[head_i]: # If no center now, repeat last known center if history exists
521
+ head_routes[head_i].append(head_routes[head_i][-1])
522
+
523
+
524
+
525
+ # --- Plotting Setup ---
526
+ mosaic = [['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
527
+ ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
528
+ ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
529
+ ['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
530
+ ['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay', 'head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
531
+ ['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay','head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
532
+ ['head_8', 'head_8_overlay', 'head_9', 'head_9_overlay','head_10', 'head_10_overlay', 'head_11', 'head_11_overlay'],
533
+ ['head_12', 'head_12_overlay', 'head_13', 'head_13_overlay','head_14', 'head_14_overlay', 'head_15', 'head_15_overlay'],
534
+ ['probabilities', 'probabilities','probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty'],
535
+ ]
536
+
537
+ img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
538
+ aspect_ratio = (8 * figscale, 9 * figscale * img_aspect) # W, H
539
+ fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
540
+
541
+ for ax in axes.values():
542
+ ax.axis('off')
543
+
544
+ # --- Plot Certainty ---
545
+ ax_cert = axes['certainty']
546
+ ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
547
+ # Add background color based on prediction correctness at each step
548
+ for ii in range(len(certainties_now)):
549
+ is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
550
+ facecolor = 'limegreen' if is_correct else 'orchid'
551
+ ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
552
+ # Mark the last point
553
+ ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
554
+ ax_cert.axis('off')
555
+ ax_cert.set_ylim([0.05, 1.05])
556
+ ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
557
+
558
+ # --- Plot Probabilities ---
559
+ ax_prob = axes['probabilities']
560
+ # Get probabilities for the current step
561
+ ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
562
+ k = 15 # Top k predictions
563
+ topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
564
+ topk_indices = topk_indices.numpy()
565
+ topk_probs = topk_probs.numpy()
566
+
567
+ top_classes = np.array(class_labels)[topk_indices]
568
+ true_class_idx = ground_truth_target # Ground truth index
569
+
570
+ # Determine bar colors (green if correct, blue otherwise - consistent with original)
571
+ colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
572
+
573
+ # Plot horizontal bars (inverted range for top-down display)
574
+ ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
575
+ ax_prob.set_xlim([0, 1])
576
+ ax_prob.axis('off')
577
+
578
+ # Add text labels for top classes
579
+ for i, name_idx in enumerate(topk_indices):
580
+ name = class_labels[name_idx] # Get name from index
581
+ is_correct = name_idx == true_class_idx
582
+ fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
583
+ text_str = f'{name[:40]}' # Truncate long names
584
+ # Position text on the left side of the horizontal bars
585
+ ax_prob.text(
586
+ 0.01, # Small offset from left edge
587
+ k - 1 - i, # Y-position corresponding to the bar
588
+ text_str,
589
+ #transform=ax_prob.transAxes, # Use data coordinates for Y
590
+ verticalalignment='center',
591
+ horizontalalignment='left',
592
+ fontsize=8,
593
+ color=fg_color,
594
+ alpha=0.9, # Slightly more visible than 0.5
595
+ path_effects=[
596
+ patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
597
+ patheffects.Normal()
598
+ ])
599
+
600
+
601
+ # --- Plot Attention Heads & Overlays ---
602
+ # Re-interpolate attention at image resolution for plotting (same interp_mode as the path-finding step)
603
+ attention_interp_plot = F.interpolate(
604
+ torch.from_numpy(attention_now).unsqueeze(0).float(),
605
+ size=(img_h, img_w),
606
+ mode=interp_mode,
607
+ ).squeeze(0)
608
+
609
+ attn_mean = attention_interp_plot.mean(0)
610
+ attn_mean_min = attn_mean.min()
611
+ attn_mean_max = attn_mean.max()
612
+ attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
613
+
614
+
615
+ # Normalize each head's map to [0, 1]
616
+ attn_min_plot = attention_interp_plot.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
617
+ attn_max_plot = attention_interp_plot.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
618
+ attention_interp_plot = (attention_interp_plot - attn_min_plot) / (attn_max_plot - attn_min_plot + 1e-6)
619
+ attention_interp_plot_np = attention_interp_plot.detach().cpu().numpy()
620
+
621
+
622
+
623
+
624
+
625
+
626
+ for head_i in list(range(n_heads)) + [-1]:
627
+ axname = f'head_{head_i}' if head_i != -1 else 'head_mean'
628
+ if axname not in axes: continue # Skip if mosaic doesn't have this head
629
+
630
+ ax = axes[axname]
631
+ ax_overlay = axes[f'{axname}_overlay']
632
+
633
+ # Plot attention heatmap
634
+ this_attn = attention_interp_plot_np[head_i] if head_i != -1 else attn_mean
635
+ img_to_plot = cmap_attention(this_attn)
636
+ ax.imshow(img_to_plot)
637
+ ax.axis('off')
638
+
639
+ # Plot overlay: image + paths
640
+ these_route_steps = head_routes[head_i]
641
+ arrow_scale = 1.5 if head_i != -1 else 3
642
+
643
+ if these_route_steps: # Only plot if path exists
644
+ # Separate y and x coordinates
645
+ y_coords, x_coords = zip(*these_route_steps)
646
+ y_coords = np.array(y_coords)
647
+ x_coords = np.array(x_coords)
648
+
649
+ # Flip y-coordinates for correct plotting (imshow origin is top-left)
650
+ # NOTE: Original flip seemed complex, simplifying to standard flip
651
+ y_coords_flipped = img_h - 1 - y_coords
652
+
653
+ # Show original image flipped vertically to match coordinate system
654
+ ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
655
+
656
+ # Draw arrows for path segments
657
+ # Arrow size scaling from original
658
+ for i in range(len(these_route_steps) - 1):
659
+ dx = x_coords[i+1] - x_coords[i]
660
+ dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
661
+
662
+ # Draw white background arrow (thicker)
663
+ ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
664
+ linewidth=1.6 * arrow_scale * 1.3,
665
+ head_width=1.9 * arrow_scale * 1.3,
666
+ head_length=1.4 * arrow_scale * 1.45,
667
+ fc='white', ec='white', length_includes_head=True, alpha=1)
668
+ # Draw colored foreground arrow
669
+ ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
670
+ linewidth=1.6 * arrow_scale,
671
+ head_width=1.9 * arrow_scale,
672
+ head_length=1.4 * arrow_scale,
673
+ fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
674
+ length_includes_head=True)
675
+
676
+ else: # If no path yet, just show the image
677
+ ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
678
+
679
+
680
+ # Set limits and turn off axes for overlay
681
+ ax_overlay.set_xlim([0, img_w - 1])
682
+ ax_overlay.set_ylim([0, img_h - 1])
683
+ ax_overlay.axis('off')
684
+
685
+
686
+ # --- Finalize and Save Frame ---
687
+ fig.tight_layout(pad=0.1) # Adjust spacing
688
+
689
+ # Render the plot to a numpy array
690
+ canvas = fig.canvas
691
+ canvas.draw()
692
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
693
+ image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
694
+
695
+ frames.append(image_numpy) # Add to list for GIF
696
+
697
+
698
+
699
+ plt.close(fig) # Close figure to free memory
700
+
701
+ # --- Save GIF ---
702
+ gif_path = os.path.join(index_output_dir, f'{str(di)}_viz.gif')
703
+ print(f"Saving GIF to {gif_path}...")
704
+ imageio.mimsave(gif_path, frames, fps=15, loop=0) # loop=0 means infinite loop
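+ # The frames are flipped from RGB to BGR below; this assumes save_frames_to_mp4 expects BGR-ordered (OpenCV-style) frames.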
705
+ save_frames_to_mp4([fm[:,:,::-1] for fm in frames], os.path.join(index_output_dir, f'{str(di)}_viz.mp4'), fps=15, gop_size=1, preset='veryslow')
706
+ if 'demo' in args.actions:
707
+
708
+
709
+
710
+ # --- Select Data Indices ---
711
+ if not args.data_indices: # If list is empty
712
+ n_samples = len(validation_dataset)
713
+ num_to_sample = min(args.N_to_viz, n_samples)
714
+ replace = n_samples < num_to_sample
715
+ data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
716
+ print(f"Selected random indices: {data_indices}")
717
+ else:
718
+ data_indices = args.data_indices
719
+ print(f"Using specified indices: {data_indices}")
720
+
721
+
722
+ for di in data_indices:
723
+
724
+ index_output_dir = os.path.join(args.output_dir, str(di))
725
+ os.makedirs(index_output_dir, exist_ok=True)
726
+
727
+ print(f'\nBuilding viz for dataset index {di}.')
728
+
729
+ inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
730
+
731
+ # Add batch dimension and send to device
732
+ inputs = inputs.to(device).unsqueeze(0)
733
+ predictions, certainties, synchronisations_over_time, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
734
+
735
+ # --- Reshape Attention ---
736
+ # Infer feature map size from model internals (assuming B=1)
737
+ h_feat, w_feat = model.kv_features.shape[-2:]
738
+ n_steps = predictions.size(-1)
739
+ n_heads = attention_tracking.shape[2]
740
+ # Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
741
+ attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
742
+
743
+ # --- Setup for Plotting ---
744
+ step_linspace = np.linspace(0, 1, n_steps) # For step colors
745
+ # Define color maps
746
+ cmap_steps = sns.color_palette("Spectral", as_cmap=True)
747
+ cmap_attention = sns.color_palette('viridis', as_cmap=True)
748
+
749
+ # Create output directory for this index
750
+
751
+
752
+ frames = [] # Store frames for GIF
753
+ head_routes = [] # Store (y,x) path points per head
754
+ route_colours_step = [] # Store colors for each step's path segments
755
+
756
+ # --- Loop Through Each Step ---
757
+ for step_i in range(n_steps):
758
+
759
+ # Store step color
760
+ current_colour = list(cmap_steps(step_linspace[step_i]))
761
+ route_colours_step.append(current_colour)
762
+
763
+ # --- Prepare Image for Display ---
764
+ # Denormalize the input tensor for visualization
765
+ data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
766
+ mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
767
+ std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
768
+ data_img_denorm = data_img_tensor * std_tensor + mean_tensor
769
+ # Permute to (H, W, C) and convert to numpy, clip to [0, 1]
770
+ data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
771
+ data_img_np = np.clip(data_img_np, 0, 1)
772
+ img_h, img_w = data_img_np.shape[:2]
773
+
774
+ # --- Process Attention & Certainty ---
775
+ # Average attention over last few steps (from original code)
776
+ start_step = max(0, step_i - 5)
777
+ attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
778
+ # Get certainties up to current step
779
+ certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
780
+
781
+ # --- Calculate Attention Paths (using bilinear interp) ---
782
+ # Interpolate attention to image size using bilinear for center finding
783
+ attention_interp_bilinear = F.interpolate(
784
+ torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
785
+ size=(img_h, img_w),
786
+ mode=interp_mode,
787
+ ).squeeze(0) # Remove batch dim -> (Heads, H, W)
788
+
789
+ attn_mean = attention_interp_bilinear.mean(0)
790
+ attn_mean_min = attn_mean.min()
791
+ attn_mean_max = attn_mean.max()
792
+ attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
793
+ centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
794
+
795
+ if centers: # If islands found
796
+ largest_island_idx = np.argmax(areas)
797
+ current_center = centers[largest_island_idx] # (y, x)
798
+ head_routes.append(current_center)
799
+ elif head_routes: # If no center now, repeat last known center if history exists
800
+ head_routes.append(head_routes[-1])
801
+
802
+ # --- Plotting Setup ---
803
+ # NOTE: this mosaic layout assumes 16 attention heads (head_0 ... head_15)
804
+ mosaic = [['head_0', 'head_1', 'head_2', 'head_3', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
805
+ ['head_4', 'head_5', 'head_6', 'head_7', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
806
+ ['head_8', 'head_9', 'head_10', 'head_11', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
807
+ ['head_12', 'head_13', 'head_14', 'head_15', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
808
+ ['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
809
+ ['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
810
+ ]
811
+
812
+ img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
813
+ aspect_ratio = (12 * figscale, 6 * figscale * img_aspect) # W, H
814
+ fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
815
+ for ax in axes.values():
816
+ ax.axis('off')
817
+
818
+ # --- Plot Certainty ---
819
+ ax_cert = axes['certainty']
820
+ ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
821
+ # Add background color based on prediction correctness at each step
822
+ for ii in range(len(certainties_now)):
823
+ is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
824
+ facecolor = 'limegreen' if is_correct else 'orchid'
825
+ ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
826
+ # Mark the last point
827
+ ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
828
+ ax_cert.axis('off')
829
+ ax_cert.set_ylim([0.05, 1.05])
830
+ ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
831
+
832
+ # --- Plot Probabilities ---
833
+ ax_prob = axes['probabilities']
834
+ # Get probabilities for the current step
835
+ ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
836
+ k = 15 # Top k predictions
837
+ topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
838
+ topk_indices = topk_indices.numpy()
839
+ topk_probs = topk_probs.numpy()
840
+
841
+ top_classes = np.array(class_labels)[topk_indices]
842
+ true_class_idx = ground_truth_target # Ground truth index
843
+
844
+ # Determine bar colors (green if correct, blue otherwise - consistent with original)
845
+ colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
846
+
847
+ # Plot horizontal bars (inverted range for top-down display)
848
+ ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
849
+ ax_prob.set_xlim([0, 1])
850
+ ax_prob.axis('off')
851
+
852
+ # Add text labels for top classes
853
+ for i, name_idx in enumerate(topk_indices):
854
+ name = class_labels[name_idx] # Get name from index
855
+ is_correct = name_idx == true_class_idx
856
+ fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
857
+ text_str = f'{name[:40]}' # Truncate long names
858
+ # Position text on the left side of the horizontal bars
859
+ ax_prob.text(
860
+ 0.01, # Small offset from left edge
861
+ k - 1 - i, # Y-position corresponding to the bar
862
+ text_str,
863
+ #transform=ax_prob.transAxes, # Use data coordinates for Y
864
+ verticalalignment='center',
865
+ horizontalalignment='left',
866
+ fontsize=8,
867
+ color=fg_color,
868
+ alpha=0.7, # Slightly more visible than 0.5
869
+ path_effects=[
870
+ patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
871
+ patheffects.Normal()
872
+ ])
873
+
874
+
875
+ # --- Plot Attention Heads & Overlays ---
876
+ # Re-interpolate attention at image resolution for plotting (same interp_mode as the path-finding step)
877
+ attention_interp_plot = F.interpolate(
878
+ torch.from_numpy(attention_now).unsqueeze(0).float(),
879
+ size=(img_h, img_w),
880
+ mode=interp_mode
881
+ ).squeeze(0)
882
+
883
+
884
+ attn_mean = attention_interp_plot.mean(0)
885
+ attn_mean_min = attn_mean.min()
886
+ attn_mean_max = attn_mean.max()
887
+ attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
888
+
889
+
890
+ img_to_plot = cmap_attention(attn_mean)
891
+ axes['head_mean'].imshow(img_to_plot)
892
+ axes['head_mean'].axis('off')
893
+
894
+
895
+ these_route_steps = head_routes
896
+ ax_overlay = axes['overlay']
897
+
898
+ if these_route_steps: # Only plot if path exists
899
+ # Separate y and x coordinates
900
+ y_coords, x_coords = zip(*these_route_steps)
901
+ y_coords = np.array(y_coords)
902
+ x_coords = np.array(x_coords)
903
+
904
+ # Flip y-coordinates for correct plotting (imshow origin is top-left)
905
+ # NOTE: Original flip seemed complex, simplifying to standard flip
906
+ y_coords_flipped = img_h - 1 - y_coords
907
+
908
+ # Show original image flipped vertically to match coordinate system
909
+ ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
910
+
911
+ # Draw arrows for path segments
912
+ arrow_scale = 2 # Arrow size scaling from original
913
+ for i in range(len(these_route_steps) - 1):
914
+ dx = x_coords[i+1] - x_coords[i]
915
+ dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
916
+
917
+ # Draw white background arrow (thicker)
918
+ ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
919
+ linewidth=1.6 * arrow_scale * 1.3,
920
+ head_width=1.9 * arrow_scale * 1.3,
921
+ head_length=1.4 * arrow_scale * 1.45,
922
+ fc='white', ec='white', length_includes_head=True, alpha=1)
923
+ # Draw colored foreground arrow
924
+ ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
925
+ linewidth=1.6 * arrow_scale,
926
+ head_width=1.9 * arrow_scale,
927
+ head_length=1.4 * arrow_scale,
928
+ fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
929
+ length_includes_head=True)
930
+ # Set limits and turn off axes for overlay
931
+ ax_overlay.set_xlim([0, img_w - 1])
932
+ ax_overlay.set_ylim([0, img_h - 1])
933
+ ax_overlay.axis('off')
934
+
935
+
936
+ for head_i in range(n_heads):
937
+ if f'head_{head_i}' not in axes: continue # Skip if mosaic doesn't have this head
938
+
939
+ ax = axes[f'head_{head_i}']
940
+
941
+ # Plot attention heatmap
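+ # (demo view: each head panel shows attention averaged over all ticks up to step_i, min-max normalised)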
942
+ attn_up_to_now = attention_tracking[:step_i + 1, head_i].mean(0)
943
+ attn_up_to_now = (attn_up_to_now - attn_up_to_now.min())/(attn_up_to_now.max() - attn_up_to_now.min())
944
+ img_to_plot = cmap_attention(attn_up_to_now)
945
+ ax.imshow(img_to_plot)
946
+ ax.axis('off')
947
+
948
+
949
+
950
+
951
+
952
+
953
+ # --- Finalize and Save Frame ---
954
+ fig.tight_layout(pad=0.1) # Adjust spacing
955
+
956
+ # Render the plot to a numpy array
957
+ canvas = fig.canvas
958
+ canvas.draw()
959
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
960
+ image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
961
+
962
+ frames.append(image_numpy) # Add to list for GIF
963
+
964
+ # Save the final tick's frame as a standalone image
965
+ if step_i==model.iterations-1:
966
+ fig.savefig(os.path.join(index_output_dir, f'frame_{step_i}.png'), dpi=200)
967
+
968
+ plt.close(fig) # Close figure to free memory
969
+ outfilename = os.path.join(index_output_dir, f'{di}_demo.mp4')
970
+ save_frames_to_mp4([fm[:,:,::-1] for fm in frames], outfilename, fps=15, gop_size=1, preset='veryslow')
971
+
972
+
tasks/image_classification/imagenet_classes.py ADDED
@@ -0,0 +1,1007 @@
1
+ from collections import OrderedDict
2
+
3
+
4
+ IMAGENET2012_CLASSES = OrderedDict(
5
+ {
6
+ "n01440764": "tench, Tinca tinca",
7
+ "n01443537": "goldfish, Carassius auratus",
8
+ "n01484850": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
9
+ "n01491361": "tiger shark, Galeocerdo cuvieri",
10
+ "n01494475": "hammerhead, hammerhead shark",
11
+ "n01496331": "electric ray, crampfish, numbfish, torpedo",
12
+ "n01498041": "stingray",
13
+ "n01514668": "cock",
14
+ "n01514859": "hen",
15
+ "n01518878": "ostrich, Struthio camelus",
16
+ "n01530575": "brambling, Fringilla montifringilla",
17
+ "n01531178": "goldfinch, Carduelis carduelis",
18
+ "n01532829": "house finch, linnet, Carpodacus mexicanus",
19
+ "n01534433": "junco, snowbird",
20
+ "n01537544": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
21
+ "n01558993": "robin, American robin, Turdus migratorius",
22
+ "n01560419": "bulbul",
23
+ "n01580077": "jay",
24
+ "n01582220": "magpie",
25
+ "n01592084": "chickadee",
26
+ "n01601694": "water ouzel, dipper",
27
+ "n01608432": "kite",
28
+ "n01614925": "bald eagle, American eagle, Haliaeetus leucocephalus",
29
+ "n01616318": "vulture",
30
+ "n01622779": "great grey owl, great gray owl, Strix nebulosa",
31
+ "n01629819": "European fire salamander, Salamandra salamandra",
32
+ "n01630670": "common newt, Triturus vulgaris",
33
+ "n01631663": "eft",
34
+ "n01632458": "spotted salamander, Ambystoma maculatum",
35
+ "n01632777": "axolotl, mud puppy, Ambystoma mexicanum",
36
+ "n01641577": "bullfrog, Rana catesbeiana",
37
+ "n01644373": "tree frog, tree-frog",
38
+ "n01644900": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
39
+ "n01664065": "loggerhead, loggerhead turtle, Caretta caretta",
40
+ "n01665541": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
41
+ "n01667114": "mud turtle",
42
+ "n01667778": "terrapin",
43
+ "n01669191": "box turtle, box tortoise",
44
+ "n01675722": "banded gecko",
45
+ "n01677366": "common iguana, iguana, Iguana iguana",
46
+ "n01682714": "American chameleon, anole, Anolis carolinensis",
47
+ "n01685808": "whiptail, whiptail lizard",
48
+ "n01687978": "agama",
49
+ "n01688243": "frilled lizard, Chlamydosaurus kingi",
50
+ "n01689811": "alligator lizard",
51
+ "n01692333": "Gila monster, Heloderma suspectum",
52
+ "n01693334": "green lizard, Lacerta viridis",
53
+ "n01694178": "African chameleon, Chamaeleo chamaeleon",
54
+ "n01695060": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
55
+ "n01697457": "African crocodile, Nile crocodile, Crocodylus niloticus",
56
+ "n01698640": "American alligator, Alligator mississipiensis",
57
+ "n01704323": "triceratops",
58
+ "n01728572": "thunder snake, worm snake, Carphophis amoenus",
59
+ "n01728920": "ringneck snake, ring-necked snake, ring snake",
60
+ "n01729322": "hognose snake, puff adder, sand viper",
61
+ "n01729977": "green snake, grass snake",
62
+ "n01734418": "king snake, kingsnake",
63
+ "n01735189": "garter snake, grass snake",
64
+ "n01737021": "water snake",
65
+ "n01739381": "vine snake",
66
+ "n01740131": "night snake, Hypsiglena torquata",
67
+ "n01742172": "boa constrictor, Constrictor constrictor",
68
+ "n01744401": "rock python, rock snake, Python sebae",
69
+ "n01748264": "Indian cobra, Naja naja",
70
+ "n01749939": "green mamba",
71
+ "n01751748": "sea snake",
72
+ "n01753488": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
73
+ "n01755581": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
74
+ "n01756291": "sidewinder, horned rattlesnake, Crotalus cerastes",
75
+ "n01768244": "trilobite",
76
+ "n01770081": "harvestman, daddy longlegs, Phalangium opilio",
77
+ "n01770393": "scorpion",
78
+ "n01773157": "black and gold garden spider, Argiope aurantia",
79
+ "n01773549": "barn spider, Araneus cavaticus",
80
+ "n01773797": "garden spider, Aranea diademata",
81
+ "n01774384": "black widow, Latrodectus mactans",
82
+ "n01774750": "tarantula",
83
+ "n01775062": "wolf spider, hunting spider",
84
+ "n01776313": "tick",
85
+ "n01784675": "centipede",
86
+ "n01795545": "black grouse",
87
+ "n01796340": "ptarmigan",
88
+ "n01797886": "ruffed grouse, partridge, Bonasa umbellus",
89
+ "n01798484": "prairie chicken, prairie grouse, prairie fowl",
90
+ "n01806143": "peacock",
91
+ "n01806567": "quail",
92
+ "n01807496": "partridge",
93
+ "n01817953": "African grey, African gray, Psittacus erithacus",
94
+ "n01818515": "macaw",
95
+ "n01819313": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
96
+ "n01820546": "lorikeet",
97
+ "n01824575": "coucal",
98
+ "n01828970": "bee eater",
99
+ "n01829413": "hornbill",
100
+ "n01833805": "hummingbird",
101
+ "n01843065": "jacamar",
102
+ "n01843383": "toucan",
103
+ "n01847000": "drake",
104
+ "n01855032": "red-breasted merganser, Mergus serrator",
105
+ "n01855672": "goose",
106
+ "n01860187": "black swan, Cygnus atratus",
107
+ "n01871265": "tusker",
108
+ "n01872401": "echidna, spiny anteater, anteater",
109
+ "n01873310": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
110
+ "n01877812": "wallaby, brush kangaroo",
111
+ "n01882714": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
112
+ "n01883070": "wombat",
113
+ "n01910747": "jellyfish",
114
+ "n01914609": "sea anemone, anemone",
115
+ "n01917289": "brain coral",
116
+ "n01924916": "flatworm, platyhelminth",
117
+ "n01930112": "nematode, nematode worm, roundworm",
118
+ "n01943899": "conch",
119
+ "n01944390": "snail",
120
+ "n01945685": "slug",
121
+ "n01950731": "sea slug, nudibranch",
122
+ "n01955084": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
123
+ "n01968897": "chambered nautilus, pearly nautilus, nautilus",
124
+ "n01978287": "Dungeness crab, Cancer magister",
125
+ "n01978455": "rock crab, Cancer irroratus",
126
+ "n01980166": "fiddler crab",
127
+ "n01981276": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
128
+ "n01983481": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
129
+ "n01984695": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
130
+ "n01985128": "crayfish, crawfish, crawdad, crawdaddy",
131
+ "n01986214": "hermit crab",
132
+ "n01990800": "isopod",
133
+ "n02002556": "white stork, Ciconia ciconia",
134
+ "n02002724": "black stork, Ciconia nigra",
135
+ "n02006656": "spoonbill",
136
+ "n02007558": "flamingo",
137
+ "n02009229": "little blue heron, Egretta caerulea",
138
+ "n02009912": "American egret, great white heron, Egretta albus",
139
+ "n02011460": "bittern",
140
+ "n02012849": "crane",
141
+ "n02013706": "limpkin, Aramus pictus",
142
+ "n02017213": "European gallinule, Porphyrio porphyrio",
143
+ "n02018207": "American coot, marsh hen, mud hen, water hen, Fulica americana",
144
+ "n02018795": "bustard",
145
+ "n02025239": "ruddy turnstone, Arenaria interpres",
146
+ "n02027492": "red-backed sandpiper, dunlin, Erolia alpina",
147
+ "n02028035": "redshank, Tringa totanus",
148
+ "n02033041": "dowitcher",
149
+ "n02037110": "oystercatcher, oyster catcher",
150
+ "n02051845": "pelican",
151
+ "n02056570": "king penguin, Aptenodytes patagonica",
152
+ "n02058221": "albatross, mollymawk",
153
+ "n02066245": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
154
+ "n02071294": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
155
+ "n02074367": "dugong, Dugong dugon",
156
+ "n02077923": "sea lion",
157
+ "n02085620": "Chihuahua",
158
+ "n02085782": "Japanese spaniel",
159
+ "n02085936": "Maltese dog, Maltese terrier, Maltese",
160
+ "n02086079": "Pekinese, Pekingese, Peke",
161
+ "n02086240": "Shih-Tzu",
162
+ "n02086646": "Blenheim spaniel",
163
+ "n02086910": "papillon",
164
+ "n02087046": "toy terrier",
165
+ "n02087394": "Rhodesian ridgeback",
166
+ "n02088094": "Afghan hound, Afghan",
167
+ "n02088238": "basset, basset hound",
168
+ "n02088364": "beagle",
169
+ "n02088466": "bloodhound, sleuthhound",
170
+ "n02088632": "bluetick",
171
+ "n02089078": "black-and-tan coonhound",
172
+ "n02089867": "Walker hound, Walker foxhound",
173
+ "n02089973": "English foxhound",
174
+ "n02090379": "redbone",
175
+ "n02090622": "borzoi, Russian wolfhound",
176
+ "n02090721": "Irish wolfhound",
177
+ "n02091032": "Italian greyhound",
178
+ "n02091134": "whippet",
179
+ "n02091244": "Ibizan hound, Ibizan Podenco",
180
+ "n02091467": "Norwegian elkhound, elkhound",
181
+ "n02091635": "otterhound, otter hound",
182
+ "n02091831": "Saluki, gazelle hound",
183
+ "n02092002": "Scottish deerhound, deerhound",
184
+ "n02092339": "Weimaraner",
185
+ "n02093256": "Staffordshire bullterrier, Staffordshire bull terrier",
186
+ "n02093428": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
187
+ "n02093647": "Bedlington terrier",
188
+ "n02093754": "Border terrier",
189
+ "n02093859": "Kerry blue terrier",
190
+ "n02093991": "Irish terrier",
191
+ "n02094114": "Norfolk terrier",
192
+ "n02094258": "Norwich terrier",
193
+ "n02094433": "Yorkshire terrier",
194
+ "n02095314": "wire-haired fox terrier",
195
+ "n02095570": "Lakeland terrier",
196
+ "n02095889": "Sealyham terrier, Sealyham",
197
+ "n02096051": "Airedale, Airedale terrier",
198
+ "n02096177": "cairn, cairn terrier",
199
+ "n02096294": "Australian terrier",
200
+ "n02096437": "Dandie Dinmont, Dandie Dinmont terrier",
201
+ "n02096585": "Boston bull, Boston terrier",
202
+ "n02097047": "miniature schnauzer",
203
+ "n02097130": "giant schnauzer",
204
+ "n02097209": "standard schnauzer",
205
+ "n02097298": "Scotch terrier, Scottish terrier, Scottie",
206
+ "n02097474": "Tibetan terrier, chrysanthemum dog",
207
+ "n02097658": "silky terrier, Sydney silky",
208
+ "n02098105": "soft-coated wheaten terrier",
209
+ "n02098286": "West Highland white terrier",
210
+ "n02098413": "Lhasa, Lhasa apso",
211
+ "n02099267": "flat-coated retriever",
212
+ "n02099429": "curly-coated retriever",
213
+ "n02099601": "golden retriever",
214
+ "n02099712": "Labrador retriever",
215
+ "n02099849": "Chesapeake Bay retriever",
216
+ "n02100236": "German short-haired pointer",
217
+ "n02100583": "vizsla, Hungarian pointer",
218
+ "n02100735": "English setter",
219
+ "n02100877": "Irish setter, red setter",
220
+ "n02101006": "Gordon setter",
221
+ "n02101388": "Brittany spaniel",
222
+ "n02101556": "clumber, clumber spaniel",
223
+ "n02102040": "English springer, English springer spaniel",
224
+ "n02102177": "Welsh springer spaniel",
225
+ "n02102318": "cocker spaniel, English cocker spaniel, cocker",
226
+ "n02102480": "Sussex spaniel",
227
+ "n02102973": "Irish water spaniel",
228
+ "n02104029": "kuvasz",
229
+ "n02104365": "schipperke",
230
+ "n02105056": "groenendael",
231
+ "n02105162": "malinois",
232
+ "n02105251": "briard",
233
+ "n02105412": "kelpie",
234
+ "n02105505": "komondor",
235
+ "n02105641": "Old English sheepdog, bobtail",
236
+ "n02105855": "Shetland sheepdog, Shetland sheep dog, Shetland",
237
+ "n02106030": "collie",
238
+ "n02106166": "Border collie",
239
+ "n02106382": "Bouvier des Flandres, Bouviers des Flandres",
240
+ "n02106550": "Rottweiler",
241
+ "n02106662": "German shepherd, German shepherd dog, German police dog, alsatian",
242
+ "n02107142": "Doberman, Doberman pinscher",
243
+ "n02107312": "miniature pinscher",
244
+ "n02107574": "Greater Swiss Mountain dog",
245
+ "n02107683": "Bernese mountain dog",
246
+ "n02107908": "Appenzeller",
247
+ "n02108000": "EntleBucher",
248
+ "n02108089": "boxer",
249
+ "n02108422": "bull mastiff",
250
+ "n02108551": "Tibetan mastiff",
251
+ "n02108915": "French bulldog",
252
+ "n02109047": "Great Dane",
253
+ "n02109525": "Saint Bernard, St Bernard",
254
+ "n02109961": "Eskimo dog, husky",
255
+ "n02110063": "malamute, malemute, Alaskan malamute",
256
+ "n02110185": "Siberian husky",
257
+ "n02110341": "dalmatian, coach dog, carriage dog",
258
+ "n02110627": "affenpinscher, monkey pinscher, monkey dog",
259
+ "n02110806": "basenji",
260
+ "n02110958": "pug, pug-dog",
261
+ "n02111129": "Leonberg",
262
+ "n02111277": "Newfoundland, Newfoundland dog",
263
+ "n02111500": "Great Pyrenees",
264
+ "n02111889": "Samoyed, Samoyede",
265
+ "n02112018": "Pomeranian",
266
+ "n02112137": "chow, chow chow",
267
+ "n02112350": "keeshond",
268
+ "n02112706": "Brabancon griffon",
269
+ "n02113023": "Pembroke, Pembroke Welsh corgi",
270
+ "n02113186": "Cardigan, Cardigan Welsh corgi",
271
+ "n02113624": "toy poodle",
272
+ "n02113712": "miniature poodle",
273
+ "n02113799": "standard poodle",
274
+ "n02113978": "Mexican hairless",
275
+ "n02114367": "timber wolf, grey wolf, gray wolf, Canis lupus",
276
+ "n02114548": "white wolf, Arctic wolf, Canis lupus tundrarum",
277
+ "n02114712": "red wolf, maned wolf, Canis rufus, Canis niger",
278
+ "n02114855": "coyote, prairie wolf, brush wolf, Canis latrans",
279
+ "n02115641": "dingo, warrigal, warragal, Canis dingo",
280
+ "n02115913": "dhole, Cuon alpinus",
281
+ "n02116738": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
282
+ "n02117135": "hyena, hyaena",
283
+ "n02119022": "red fox, Vulpes vulpes",
284
+ "n02119789": "kit fox, Vulpes macrotis",
285
+ "n02120079": "Arctic fox, white fox, Alopex lagopus",
286
+ "n02120505": "grey fox, gray fox, Urocyon cinereoargenteus",
287
+ "n02123045": "tabby, tabby cat",
288
+ "n02123159": "tiger cat",
289
+ "n02123394": "Persian cat",
290
+ "n02123597": "Siamese cat, Siamese",
291
+ "n02124075": "Egyptian cat",
292
+ "n02125311": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
293
+ "n02127052": "lynx, catamount",
294
+ "n02128385": "leopard, Panthera pardus",
295
+ "n02128757": "snow leopard, ounce, Panthera uncia",
296
+ "n02128925": "jaguar, panther, Panthera onca, Felis onca",
297
+ "n02129165": "lion, king of beasts, Panthera leo",
298
+ "n02129604": "tiger, Panthera tigris",
299
+ "n02130308": "cheetah, chetah, Acinonyx jubatus",
300
+ "n02132136": "brown bear, bruin, Ursus arctos",
301
+ "n02133161": "American black bear, black bear, Ursus americanus, Euarctos americanus",
302
+ "n02134084": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
303
+ "n02134418": "sloth bear, Melursus ursinus, Ursus ursinus",
304
+ "n02137549": "mongoose",
305
+ "n02138441": "meerkat, mierkat",
306
+ "n02165105": "tiger beetle",
307
+ "n02165456": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
308
+ "n02167151": "ground beetle, carabid beetle",
309
+ "n02168699": "long-horned beetle, longicorn, longicorn beetle",
310
+ "n02169497": "leaf beetle, chrysomelid",
311
+ "n02172182": "dung beetle",
312
+ "n02174001": "rhinoceros beetle",
313
+ "n02177972": "weevil",
314
+ "n02190166": "fly",
315
+ "n02206856": "bee",
316
+ "n02219486": "ant, emmet, pismire",
317
+ "n02226429": "grasshopper, hopper",
318
+ "n02229544": "cricket",
319
+ "n02231487": "walking stick, walkingstick, stick insect",
320
+ "n02233338": "cockroach, roach",
321
+ "n02236044": "mantis, mantid",
322
+ "n02256656": "cicada, cicala",
323
+ "n02259212": "leafhopper",
324
+ "n02264363": "lacewing, lacewing fly",
325
+ "n02268443": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
326
+ "n02268853": "damselfly",
327
+ "n02276258": "admiral",
328
+ "n02277742": "ringlet, ringlet butterfly",
329
+ "n02279972": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
330
+ "n02280649": "cabbage butterfly",
331
+ "n02281406": "sulphur butterfly, sulfur butterfly",
332
+ "n02281787": "lycaenid, lycaenid butterfly",
333
+ "n02317335": "starfish, sea star",
334
+ "n02319095": "sea urchin",
335
+ "n02321529": "sea cucumber, holothurian",
336
+ "n02325366": "wood rabbit, cottontail, cottontail rabbit",
337
+ "n02326432": "hare",
338
+ "n02328150": "Angora, Angora rabbit",
339
+ "n02342885": "hamster",
340
+ "n02346627": "porcupine, hedgehog",
341
+ "n02356798": "fox squirrel, eastern fox squirrel, Sciurus niger",
342
+ "n02361337": "marmot",
343
+ "n02363005": "beaver",
344
+ "n02364673": "guinea pig, Cavia cobaya",
345
+ "n02389026": "sorrel",
346
+ "n02391049": "zebra",
347
+ "n02395406": "hog, pig, grunter, squealer, Sus scrofa",
348
+ "n02396427": "wild boar, boar, Sus scrofa",
349
+ "n02397096": "warthog",
350
+ "n02398521": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
351
+ "n02403003": "ox",
352
+ "n02408429": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
353
+ "n02410509": "bison",
354
+ "n02412080": "ram, tup",
355
+ "n02415577": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
356
+ "n02417914": "ibex, Capra ibex",
357
+ "n02422106": "hartebeest",
358
+ "n02422699": "impala, Aepyceros melampus",
359
+ "n02423022": "gazelle",
360
+ "n02437312": "Arabian camel, dromedary, Camelus dromedarius",
361
+ "n02437616": "llama",
362
+ "n02441942": "weasel",
363
+ "n02442845": "mink",
364
+ "n02443114": "polecat, fitch, foulmart, foumart, Mustela putorius",
365
+ "n02443484": "black-footed ferret, ferret, Mustela nigripes",
366
+ "n02444819": "otter",
367
+ "n02445715": "skunk, polecat, wood pussy",
368
+ "n02447366": "badger",
369
+ "n02454379": "armadillo",
370
+ "n02457408": "three-toed sloth, ai, Bradypus tridactylus",
371
+ "n02480495": "orangutan, orang, orangutang, Pongo pygmaeus",
372
+ "n02480855": "gorilla, Gorilla gorilla",
373
+ "n02481823": "chimpanzee, chimp, Pan troglodytes",
374
+ "n02483362": "gibbon, Hylobates lar",
375
+ "n02483708": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
376
+ "n02484975": "guenon, guenon monkey",
377
+ "n02486261": "patas, hussar monkey, Erythrocebus patas",
378
+ "n02486410": "baboon",
379
+ "n02487347": "macaque",
380
+ "n02488291": "langur",
381
+ "n02488702": "colobus, colobus monkey",
382
+ "n02489166": "proboscis monkey, Nasalis larvatus",
383
+ "n02490219": "marmoset",
384
+ "n02492035": "capuchin, ringtail, Cebus capucinus",
385
+ "n02492660": "howler monkey, howler",
386
+ "n02493509": "titi, titi monkey",
387
+ "n02493793": "spider monkey, Ateles geoffroyi",
388
+ "n02494079": "squirrel monkey, Saimiri sciureus",
389
+ "n02497673": "Madagascar cat, ring-tailed lemur, Lemur catta",
390
+ "n02500267": "indri, indris, Indri indri, Indri brevicaudatus",
391
+ "n02504013": "Indian elephant, Elephas maximus",
392
+ "n02504458": "African elephant, Loxodonta africana",
393
+ "n02509815": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
394
+ "n02510455": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
395
+ "n02514041": "barracouta, snoek",
396
+ "n02526121": "eel",
397
+ "n02536864": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
398
+ "n02606052": "rock beauty, Holocanthus tricolor",
399
+ "n02607072": "anemone fish",
400
+ "n02640242": "sturgeon",
401
+ "n02641379": "gar, garfish, garpike, billfish, Lepisosteus osseus",
402
+ "n02643566": "lionfish",
403
+ "n02655020": "puffer, pufferfish, blowfish, globefish",
404
+ "n02666196": "abacus",
405
+ "n02667093": "abaya",
406
+ "n02669723": "academic gown, academic robe, judge's robe",
407
+ "n02672831": "accordion, piano accordion, squeeze box",
408
+ "n02676566": "acoustic guitar",
409
+ "n02687172": "aircraft carrier, carrier, flattop, attack aircraft carrier",
410
+ "n02690373": "airliner",
411
+ "n02692877": "airship, dirigible",
412
+ "n02699494": "altar",
413
+ "n02701002": "ambulance",
414
+ "n02704792": "amphibian, amphibious vehicle",
415
+ "n02708093": "analog clock",
416
+ "n02727426": "apiary, bee house",
417
+ "n02730930": "apron",
418
+ "n02747177": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
419
+ "n02749479": "assault rifle, assault gun",
420
+ "n02769748": "backpack, back pack, knapsack, packsack, rucksack, haversack",
421
+ "n02776631": "bakery, bakeshop, bakehouse",
422
+ "n02777292": "balance beam, beam",
423
+ "n02782093": "balloon",
424
+ "n02783161": "ballpoint, ballpoint pen, ballpen, Biro",
425
+ "n02786058": "Band Aid",
426
+ "n02787622": "banjo",
427
+ "n02788148": "bannister, banister, balustrade, balusters, handrail",
428
+ "n02790996": "barbell",
429
+ "n02791124": "barber chair",
430
+ "n02791270": "barbershop",
431
+ "n02793495": "barn",
432
+ "n02794156": "barometer",
433
+ "n02795169": "barrel, cask",
434
+ "n02797295": "barrow, garden cart, lawn cart, wheelbarrow",
435
+ "n02799071": "baseball",
436
+ "n02802426": "basketball",
437
+ "n02804414": "bassinet",
438
+ "n02804610": "bassoon",
439
+ "n02807133": "bathing cap, swimming cap",
440
+ "n02808304": "bath towel",
441
+ "n02808440": "bathtub, bathing tub, bath, tub",
442
+ "n02814533": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
443
+ "n02814860": "beacon, lighthouse, beacon light, pharos",
444
+ "n02815834": "beaker",
445
+ "n02817516": "bearskin, busby, shako",
446
+ "n02823428": "beer bottle",
447
+ "n02823750": "beer glass",
448
+ "n02825657": "bell cote, bell cot",
449
+ "n02834397": "bib",
450
+ "n02835271": "bicycle-built-for-two, tandem bicycle, tandem",
451
+ "n02837789": "bikini, two-piece",
452
+ "n02840245": "binder, ring-binder",
453
+ "n02841315": "binoculars, field glasses, opera glasses",
454
+ "n02843684": "birdhouse",
455
+ "n02859443": "boathouse",
456
+ "n02860847": "bobsled, bobsleigh, bob",
457
+ "n02865351": "bolo tie, bolo, bola tie, bola",
458
+ "n02869837": "bonnet, poke bonnet",
459
+ "n02870880": "bookcase",
460
+ "n02871525": "bookshop, bookstore, bookstall",
461
+ "n02877765": "bottlecap",
462
+ "n02879718": "bow",
463
+ "n02883205": "bow tie, bow-tie, bowtie",
464
+ "n02892201": "brass, memorial tablet, plaque",
465
+ "n02892767": "brassiere, bra, bandeau",
466
+ "n02894605": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
467
+ "n02895154": "breastplate, aegis, egis",
468
+ "n02906734": "broom",
469
+ "n02909870": "bucket, pail",
470
+ "n02910353": "buckle",
471
+ "n02916936": "bulletproof vest",
472
+ "n02917067": "bullet train, bullet",
473
+ "n02927161": "butcher shop, meat market",
474
+ "n02930766": "cab, hack, taxi, taxicab",
475
+ "n02939185": "caldron, cauldron",
476
+ "n02948072": "candle, taper, wax light",
477
+ "n02950826": "cannon",
478
+ "n02951358": "canoe",
479
+ "n02951585": "can opener, tin opener",
480
+ "n02963159": "cardigan",
481
+ "n02965783": "car mirror",
482
+ "n02966193": "carousel, carrousel, merry-go-round, roundabout, whirligig",
483
+ "n02966687": "carpenter's kit, tool kit",
484
+ "n02971356": "carton",
485
+ "n02974003": "car wheel",
486
+ "n02977058": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
487
+ "n02978881": "cassette",
488
+ "n02979186": "cassette player",
489
+ "n02980441": "castle",
490
+ "n02981792": "catamaran",
491
+ "n02988304": "CD player",
492
+ "n02992211": "cello, violoncello",
493
+ "n02992529": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
494
+ "n02999410": "chain",
495
+ "n03000134": "chainlink fence",
496
+ "n03000247": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
497
+ "n03000684": "chain saw, chainsaw",
498
+ "n03014705": "chest",
499
+ "n03016953": "chiffonier, commode",
500
+ "n03017168": "chime, bell, gong",
501
+ "n03018349": "china cabinet, china closet",
502
+ "n03026506": "Christmas stocking",
503
+ "n03028079": "church, church building",
504
+ "n03032252": "cinema, movie theater, movie theatre, movie house, picture palace",
505
+ "n03041632": "cleaver, meat cleaver, chopper",
506
+ "n03042490": "cliff dwelling",
507
+ "n03045698": "cloak",
508
+ "n03047690": "clog, geta, patten, sabot",
509
+ "n03062245": "cocktail shaker",
510
+ "n03063599": "coffee mug",
511
+ "n03063689": "coffeepot",
512
+ "n03065424": "coil, spiral, volute, whorl, helix",
513
+ "n03075370": "combination lock",
514
+ "n03085013": "computer keyboard, keypad",
515
+ "n03089624": "confectionery, confectionary, candy store",
516
+ "n03095699": "container ship, containership, container vessel",
517
+ "n03100240": "convertible",
518
+ "n03109150": "corkscrew, bottle screw",
519
+ "n03110669": "cornet, horn, trumpet, trump",
520
+ "n03124043": "cowboy boot",
521
+ "n03124170": "cowboy hat, ten-gallon hat",
522
+ "n03125729": "cradle",
523
+ "n03126707": "crane2",
524
+ "n03127747": "crash helmet",
525
+ "n03127925": "crate",
526
+ "n03131574": "crib, cot",
527
+ "n03133878": "Crock Pot",
528
+ "n03134739": "croquet ball",
529
+ "n03141823": "crutch",
530
+ "n03146219": "cuirass",
531
+ "n03160309": "dam, dike, dyke",
532
+ "n03179701": "desk",
533
+ "n03180011": "desktop computer",
534
+ "n03187595": "dial telephone, dial phone",
535
+ "n03188531": "diaper, nappy, napkin",
536
+ "n03196217": "digital clock",
537
+ "n03197337": "digital watch",
538
+ "n03201208": "dining table, board",
539
+ "n03207743": "dishrag, dishcloth",
540
+ "n03207941": "dishwasher, dish washer, dishwashing machine",
541
+ "n03208938": "disk brake, disc brake",
542
+ "n03216828": "dock, dockage, docking facility",
543
+ "n03218198": "dogsled, dog sled, dog sleigh",
544
+ "n03220513": "dome",
545
+ "n03223299": "doormat, welcome mat",
546
+ "n03240683": "drilling platform, offshore rig",
547
+ "n03249569": "drum, membranophone, tympan",
548
+ "n03250847": "drumstick",
549
+ "n03255030": "dumbbell",
550
+ "n03259280": "Dutch oven",
551
+ "n03271574": "electric fan, blower",
552
+ "n03272010": "electric guitar",
553
+ "n03272562": "electric locomotive",
554
+ "n03290653": "entertainment center",
555
+ "n03291819": "envelope",
556
+ "n03297495": "espresso maker",
557
+ "n03314780": "face powder",
558
+ "n03325584": "feather boa, boa",
559
+ "n03337140": "file, file cabinet, filing cabinet",
560
+ "n03344393": "fireboat",
561
+ "n03345487": "fire engine, fire truck",
562
+ "n03347037": "fire screen, fireguard",
563
+ "n03355925": "flagpole, flagstaff",
564
+ "n03372029": "flute, transverse flute",
565
+ "n03376595": "folding chair",
566
+ "n03379051": "football helmet",
567
+ "n03384352": "forklift",
568
+ "n03388043": "fountain",
569
+ "n03388183": "fountain pen",
570
+ "n03388549": "four-poster",
571
+ "n03393912": "freight car",
572
+ "n03394916": "French horn, horn",
573
+ "n03400231": "frying pan, frypan, skillet",
574
+ "n03404251": "fur coat",
575
+ "n03417042": "garbage truck, dustcart",
576
+ "n03424325": "gasmask, respirator, gas helmet",
577
+ "n03425413": "gas pump, gasoline pump, petrol pump, island dispenser",
578
+ "n03443371": "goblet",
579
+ "n03444034": "go-kart",
580
+ "n03445777": "golf ball",
581
+ "n03445924": "golfcart, golf cart",
582
+ "n03447447": "gondola",
583
+ "n03447721": "gong, tam-tam",
584
+ "n03450230": "gown",
585
+ "n03452741": "grand piano, grand",
586
+ "n03457902": "greenhouse, nursery, glasshouse",
587
+ "n03459775": "grille, radiator grille",
588
+ "n03461385": "grocery store, grocery, food market, market",
589
+ "n03467068": "guillotine",
590
+ "n03476684": "hair slide",
591
+ "n03476991": "hair spray",
592
+ "n03478589": "half track",
593
+ "n03481172": "hammer",
594
+ "n03482405": "hamper",
595
+ "n03483316": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
596
+ "n03485407": "hand-held computer, hand-held microcomputer",
597
+ "n03485794": "handkerchief, hankie, hanky, hankey",
598
+ "n03492542": "hard disc, hard disk, fixed disk",
599
+ "n03494278": "harmonica, mouth organ, harp, mouth harp",
600
+ "n03495258": "harp",
601
+ "n03496892": "harvester, reaper",
602
+ "n03498962": "hatchet",
603
+ "n03527444": "holster",
604
+ "n03529860": "home theater, home theatre",
605
+ "n03530642": "honeycomb",
606
+ "n03532672": "hook, claw",
607
+ "n03534580": "hoopskirt, crinoline",
608
+ "n03535780": "horizontal bar, high bar",
609
+ "n03538406": "horse cart, horse-cart",
610
+ "n03544143": "hourglass",
611
+ "n03584254": "iPod",
612
+ "n03584829": "iron, smoothing iron",
613
+ "n03590841": "jack-o'-lantern",
614
+ "n03594734": "jean, blue jean, denim",
615
+ "n03594945": "jeep, landrover",
616
+ "n03595614": "jersey, T-shirt, tee shirt",
617
+ "n03598930": "jigsaw puzzle",
618
+ "n03599486": "jinrikisha, ricksha, rickshaw",
619
+ "n03602883": "joystick",
620
+ "n03617480": "kimono",
621
+ "n03623198": "knee pad",
622
+ "n03627232": "knot",
623
+ "n03630383": "lab coat, laboratory coat",
624
+ "n03633091": "ladle",
625
+ "n03637318": "lampshade, lamp shade",
626
+ "n03642806": "laptop, laptop computer",
627
+ "n03649909": "lawn mower, mower",
628
+ "n03657121": "lens cap, lens cover",
629
+ "n03658185": "letter opener, paper knife, paperknife",
630
+ "n03661043": "library",
631
+ "n03662601": "lifeboat",
632
+ "n03666591": "lighter, light, igniter, ignitor",
633
+ "n03670208": "limousine, limo",
634
+ "n03673027": "liner, ocean liner",
635
+ "n03676483": "lipstick, lip rouge",
636
+ "n03680355": "Loafer",
637
+ "n03690938": "lotion",
638
+ "n03691459": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
639
+ "n03692522": "loupe, jeweler's loupe",
640
+ "n03697007": "lumbermill, sawmill",
641
+ "n03706229": "magnetic compass",
642
+ "n03709823": "mailbag, postbag",
643
+ "n03710193": "mailbox, letter box",
644
+ "n03710637": "maillot",
645
+ "n03710721": "maillot, tank suit",
646
+ "n03717622": "manhole cover",
647
+ "n03720891": "maraca",
648
+ "n03721384": "marimba, xylophone",
649
+ "n03724870": "mask",
650
+ "n03729826": "matchstick",
651
+ "n03733131": "maypole",
652
+ "n03733281": "maze, labyrinth",
653
+ "n03733805": "measuring cup",
654
+ "n03742115": "medicine chest, medicine cabinet",
655
+ "n03743016": "megalith, megalithic structure",
656
+ "n03759954": "microphone, mike",
657
+ "n03761084": "microwave, microwave oven",
658
+ "n03763968": "military uniform",
659
+ "n03764736": "milk can",
660
+ "n03769881": "minibus",
661
+ "n03770439": "miniskirt, mini",
662
+ "n03770679": "minivan",
663
+ "n03773504": "missile",
664
+ "n03775071": "mitten",
665
+ "n03775546": "mixing bowl",
666
+ "n03776460": "mobile home, manufactured home",
667
+ "n03777568": "Model T",
668
+ "n03777754": "modem",
669
+ "n03781244": "monastery",
670
+ "n03782006": "monitor",
671
+ "n03785016": "moped",
672
+ "n03786901": "mortar",
673
+ "n03787032": "mortarboard",
674
+ "n03788195": "mosque",
675
+ "n03788365": "mosquito net",
676
+ "n03791053": "motor scooter, scooter",
677
+ "n03792782": "mountain bike, all-terrain bike, off-roader",
678
+ "n03792972": "mountain tent",
679
+ "n03793489": "mouse, computer mouse",
680
+ "n03794056": "mousetrap",
681
+ "n03796401": "moving van",
682
+ "n03803284": "muzzle",
683
+ "n03804744": "nail",
684
+ "n03814639": "neck brace",
685
+ "n03814906": "necklace",
686
+ "n03825788": "nipple",
687
+ "n03832673": "notebook, notebook computer",
688
+ "n03837869": "obelisk",
689
+ "n03838899": "oboe, hautboy, hautbois",
690
+ "n03840681": "ocarina, sweet potato",
691
+ "n03841143": "odometer, hodometer, mileometer, milometer",
692
+ "n03843555": "oil filter",
693
+ "n03854065": "organ, pipe organ",
694
+ "n03857828": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
695
+ "n03866082": "overskirt",
696
+ "n03868242": "oxcart",
697
+ "n03868863": "oxygen mask",
698
+ "n03871628": "packet",
699
+ "n03873416": "paddle, boat paddle",
700
+ "n03874293": "paddlewheel, paddle wheel",
701
+ "n03874599": "padlock",
702
+ "n03876231": "paintbrush",
703
+ "n03877472": "pajama, pyjama, pj's, jammies",
704
+ "n03877845": "palace",
705
+ "n03884397": "panpipe, pandean pipe, syrinx",
706
+ "n03887697": "paper towel",
707
+ "n03888257": "parachute, chute",
708
+ "n03888605": "parallel bars, bars",
709
+ "n03891251": "park bench",
710
+ "n03891332": "parking meter",
711
+ "n03895866": "passenger car, coach, carriage",
712
+ "n03899768": "patio, terrace",
713
+ "n03902125": "pay-phone, pay-station",
714
+ "n03903868": "pedestal, plinth, footstall",
715
+ "n03908618": "pencil box, pencil case",
716
+ "n03908714": "pencil sharpener",
717
+ "n03916031": "perfume, essence",
718
+ "n03920288": "Petri dish",
719
+ "n03924679": "photocopier",
720
+ "n03929660": "pick, plectrum, plectron",
721
+ "n03929855": "pickelhaube",
722
+ "n03930313": "picket fence, paling",
723
+ "n03930630": "pickup, pickup truck",
724
+ "n03933933": "pier",
725
+ "n03935335": "piggy bank, penny bank",
726
+ "n03937543": "pill bottle",
727
+ "n03938244": "pillow",
728
+ "n03942813": "ping-pong ball",
729
+ "n03944341": "pinwheel",
730
+ "n03947888": "pirate, pirate ship",
731
+ "n03950228": "pitcher, ewer",
732
+ "n03954731": "plane, carpenter's plane, woodworking plane",
733
+ "n03956157": "planetarium",
734
+ "n03958227": "plastic bag",
735
+ "n03961711": "plate rack",
736
+ "n03967562": "plow, plough",
737
+ "n03970156": "plunger, plumber's helper",
738
+ "n03976467": "Polaroid camera, Polaroid Land camera",
739
+ "n03976657": "pole",
740
+ "n03977966": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
741
+ "n03980874": "poncho",
742
+ "n03982430": "pool table, billiard table, snooker table",
743
+ "n03983396": "pop bottle, soda bottle",
744
+ "n03991062": "pot, flowerpot",
745
+ "n03992509": "potter's wheel",
746
+ "n03995372": "power drill",
747
+ "n03998194": "prayer rug, prayer mat",
748
+ "n04004767": "printer",
749
+ "n04005630": "prison, prison house",
750
+ "n04008634": "projectile, missile",
751
+ "n04009552": "projector",
752
+ "n04019541": "puck, hockey puck",
753
+ "n04023962": "punching bag, punch bag, punching ball, punchball",
754
+ "n04026417": "purse",
755
+ "n04033901": "quill, quill pen",
756
+ "n04033995": "quilt, comforter, comfort, puff",
757
+ "n04037443": "racer, race car, racing car",
758
+ "n04039381": "racket, racquet",
759
+ "n04040759": "radiator",
760
+ "n04041544": "radio, wireless",
761
+ "n04044716": "radio telescope, radio reflector",
762
+ "n04049303": "rain barrel",
763
+ "n04065272": "recreational vehicle, RV, R.V.",
764
+ "n04067472": "reel",
765
+ "n04069434": "reflex camera",
766
+ "n04070727": "refrigerator, icebox",
767
+ "n04074963": "remote control, remote",
768
+ "n04081281": "restaurant, eating house, eating place, eatery",
769
+ "n04086273": "revolver, six-gun, six-shooter",
770
+ "n04090263": "rifle",
771
+ "n04099969": "rocking chair, rocker",
772
+ "n04111531": "rotisserie",
773
+ "n04116512": "rubber eraser, rubber, pencil eraser",
774
+ "n04118538": "rugby ball",
775
+ "n04118776": "rule, ruler",
776
+ "n04120489": "running shoe",
777
+ "n04125021": "safe",
778
+ "n04127249": "safety pin",
779
+ "n04131690": "saltshaker, salt shaker",
780
+ "n04133789": "sandal",
781
+ "n04136333": "sarong",
782
+ "n04141076": "sax, saxophone",
783
+ "n04141327": "scabbard",
784
+ "n04141975": "scale, weighing machine",
785
+ "n04146614": "school bus",
786
+ "n04147183": "schooner",
787
+ "n04149813": "scoreboard",
788
+ "n04152593": "screen, CRT screen",
789
+ "n04153751": "screw",
790
+ "n04154565": "screwdriver",
791
+ "n04162706": "seat belt, seatbelt",
792
+ "n04179913": "sewing machine",
793
+ "n04192698": "shield, buckler",
794
+ "n04200800": "shoe shop, shoe-shop, shoe store",
795
+ "n04201297": "shoji",
796
+ "n04204238": "shopping basket",
797
+ "n04204347": "shopping cart",
798
+ "n04208210": "shovel",
799
+ "n04209133": "shower cap",
800
+ "n04209239": "shower curtain",
801
+ "n04228054": "ski",
802
+ "n04229816": "ski mask",
803
+ "n04235860": "sleeping bag",
804
+ "n04238763": "slide rule, slipstick",
805
+ "n04239074": "sliding door",
806
+ "n04243546": "slot, one-armed bandit",
807
+ "n04251144": "snorkel",
808
+ "n04252077": "snowmobile",
809
+ "n04252225": "snowplow, snowplough",
810
+ "n04254120": "soap dispenser",
811
+ "n04254680": "soccer ball",
812
+ "n04254777": "sock",
813
+ "n04258138": "solar dish, solar collector, solar furnace",
814
+ "n04259630": "sombrero",
815
+ "n04263257": "soup bowl",
816
+ "n04264628": "space bar",
817
+ "n04265275": "space heater",
818
+ "n04266014": "space shuttle",
819
+ "n04270147": "spatula",
820
+ "n04273569": "speedboat",
821
+ "n04275548": "spider web, spider's web",
822
+ "n04277352": "spindle",
823
+ "n04285008": "sports car, sport car",
824
+ "n04286575": "spotlight, spot",
825
+ "n04296562": "stage",
826
+ "n04310018": "steam locomotive",
827
+ "n04311004": "steel arch bridge",
828
+ "n04311174": "steel drum",
829
+ "n04317175": "stethoscope",
830
+ "n04325704": "stole",
831
+ "n04326547": "stone wall",
832
+ "n04328186": "stopwatch, stop watch",
833
+ "n04330267": "stove",
834
+ "n04332243": "strainer",
835
+ "n04335435": "streetcar, tram, tramcar, trolley, trolley car",
836
+ "n04336792": "stretcher",
837
+ "n04344873": "studio couch, day bed",
838
+ "n04346328": "stupa, tope",
839
+ "n04347754": "submarine, pigboat, sub, U-boat",
840
+ "n04350905": "suit, suit of clothes",
841
+ "n04355338": "sundial",
842
+ "n04355933": "sunglass",
843
+ "n04356056": "sunglasses, dark glasses, shades",
844
+ "n04357314": "sunscreen, sunblock, sun blocker",
845
+ "n04366367": "suspension bridge",
846
+ "n04367480": "swab, swob, mop",
847
+ "n04370456": "sweatshirt",
848
+ "n04371430": "swimming trunks, bathing trunks",
849
+ "n04371774": "swing",
850
+ "n04372370": "switch, electric switch, electrical switch",
851
+ "n04376876": "syringe",
852
+ "n04380533": "table lamp",
853
+ "n04389033": "tank, army tank, armored combat vehicle, armoured combat vehicle",
854
+ "n04392985": "tape player",
855
+ "n04398044": "teapot",
856
+ "n04399382": "teddy, teddy bear",
857
+ "n04404412": "television, television system",
858
+ "n04409515": "tennis ball",
859
+ "n04417672": "thatch, thatched roof",
860
+ "n04418357": "theater curtain, theatre curtain",
861
+ "n04423845": "thimble",
862
+ "n04428191": "thresher, thrasher, threshing machine",
863
+ "n04429376": "throne",
864
+ "n04435653": "tile roof",
865
+ "n04442312": "toaster",
866
+ "n04443257": "tobacco shop, tobacconist shop, tobacconist",
867
+ "n04447861": "toilet seat",
868
+ "n04456115": "torch",
869
+ "n04458633": "totem pole",
870
+ "n04461696": "tow truck, tow car, wrecker",
871
+ "n04462240": "toyshop",
872
+ "n04465501": "tractor",
873
+ "n04467665": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
874
+ "n04476259": "tray",
875
+ "n04479046": "trench coat",
876
+ "n04482393": "tricycle, trike, velocipede",
877
+ "n04483307": "trimaran",
878
+ "n04485082": "tripod",
879
+ "n04486054": "triumphal arch",
880
+ "n04487081": "trolleybus, trolley coach, trackless trolley",
881
+ "n04487394": "trombone",
882
+ "n04493381": "tub, vat",
883
+ "n04501370": "turnstile",
884
+ "n04505470": "typewriter keyboard",
885
+ "n04507155": "umbrella",
886
+ "n04509417": "unicycle, monocycle",
887
+ "n04515003": "upright, upright piano",
888
+ "n04517823": "vacuum, vacuum cleaner",
889
+ "n04522168": "vase",
890
+ "n04523525": "vault",
891
+ "n04525038": "velvet",
892
+ "n04525305": "vending machine",
893
+ "n04532106": "vestment",
894
+ "n04532670": "viaduct",
895
+ "n04536866": "violin, fiddle",
896
+ "n04540053": "volleyball",
897
+ "n04542943": "waffle iron",
898
+ "n04548280": "wall clock",
899
+ "n04548362": "wallet, billfold, notecase, pocketbook",
900
+ "n04550184": "wardrobe, closet, press",
901
+ "n04552348": "warplane, military plane",
902
+ "n04553703": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
903
+ "n04554684": "washer, automatic washer, washing machine",
904
+ "n04557648": "water bottle",
905
+ "n04560804": "water jug",
906
+ "n04562935": "water tower",
907
+ "n04579145": "whiskey jug",
908
+ "n04579432": "whistle",
909
+ "n04584207": "wig",
910
+ "n04589890": "window screen",
911
+ "n04590129": "window shade",
912
+ "n04591157": "Windsor tie",
913
+ "n04591713": "wine bottle",
914
+ "n04592741": "wing",
915
+ "n04596742": "wok",
916
+ "n04597913": "wooden spoon",
917
+ "n04599235": "wool, woolen, woollen",
918
+ "n04604644": "worm fence, snake fence, snake-rail fence, Virginia fence",
919
+ "n04606251": "wreck",
920
+ "n04612504": "yawl",
921
+ "n04613696": "yurt",
922
+ "n06359193": "web site, website, internet site, site",
923
+ "n06596364": "comic book",
924
+ "n06785654": "crossword puzzle, crossword",
925
+ "n06794110": "street sign",
926
+ "n06874185": "traffic light, traffic signal, stoplight",
927
+ "n07248320": "book jacket, dust cover, dust jacket, dust wrapper",
928
+ "n07565083": "menu",
929
+ "n07579787": "plate",
930
+ "n07583066": "guacamole",
931
+ "n07584110": "consomme",
932
+ "n07590611": "hot pot, hotpot",
933
+ "n07613480": "trifle",
934
+ "n07614500": "ice cream, icecream",
935
+ "n07615774": "ice lolly, lolly, lollipop, popsicle",
936
+ "n07684084": "French loaf",
937
+ "n07693725": "bagel, beigel",
938
+ "n07695742": "pretzel",
939
+ "n07697313": "cheeseburger",
940
+ "n07697537": "hotdog, hot dog, red hot",
941
+ "n07711569": "mashed potato",
942
+ "n07714571": "head cabbage",
943
+ "n07714990": "broccoli",
944
+ "n07715103": "cauliflower",
945
+ "n07716358": "zucchini, courgette",
946
+ "n07716906": "spaghetti squash",
947
+ "n07717410": "acorn squash",
948
+ "n07717556": "butternut squash",
949
+ "n07718472": "cucumber, cuke",
950
+ "n07718747": "artichoke, globe artichoke",
951
+ "n07720875": "bell pepper",
952
+ "n07730033": "cardoon",
953
+ "n07734744": "mushroom",
954
+ "n07742313": "Granny Smith",
955
+ "n07745940": "strawberry",
956
+ "n07747607": "orange",
957
+ "n07749582": "lemon",
958
+ "n07753113": "fig",
959
+ "n07753275": "pineapple, ananas",
960
+ "n07753592": "banana",
961
+ "n07754684": "jackfruit, jak, jack",
962
+ "n07760859": "custard apple",
963
+ "n07768694": "pomegranate",
964
+ "n07802026": "hay",
965
+ "n07831146": "carbonara",
966
+ "n07836838": "chocolate sauce, chocolate syrup",
967
+ "n07860988": "dough",
968
+ "n07871810": "meat loaf, meatloaf",
969
+ "n07873807": "pizza, pizza pie",
970
+ "n07875152": "potpie",
971
+ "n07880968": "burrito",
972
+ "n07892512": "red wine",
973
+ "n07920052": "espresso",
974
+ "n07930864": "cup",
975
+ "n07932039": "eggnog",
976
+ "n09193705": "alp",
977
+ "n09229709": "bubble",
978
+ "n09246464": "cliff, drop, drop-off",
979
+ "n09256479": "coral reef",
980
+ "n09288635": "geyser",
981
+ "n09332890": "lakeside, lakeshore",
982
+ "n09399592": "promontory, headland, head, foreland",
983
+ "n09421951": "sandbar, sand bar",
984
+ "n09428293": "seashore, coast, seacoast, sea-coast",
985
+ "n09468604": "valley, vale",
986
+ "n09472597": "volcano",
987
+ "n09835506": "ballplayer, baseball player",
988
+ "n10148035": "groom, bridegroom",
989
+ "n10565667": "scuba diver",
990
+ "n11879895": "rapeseed",
991
+ "n11939491": "daisy",
992
+ "n12057211": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
993
+ "n12144580": "corn",
994
+ "n12267677": "acorn",
995
+ "n12620546": "hip, rose hip, rosehip",
996
+ "n12768682": "buckeye, horse chestnut, conker",
997
+ "n12985857": "coral fungus",
998
+ "n12998815": "agaric",
999
+ "n13037406": "gyromitra",
1000
+ "n13040303": "stinkhorn, carrion fungus",
1001
+ "n13044778": "earthstar",
1002
+ "n13052670": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
1003
+ "n13054560": "bolete",
1004
+ "n13133613": "ear, spike, capitulum",
1005
+ "n15075141": "toilet tissue, toilet paper, bathroom tissue",
1006
+ }
1007
+ )
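This synset-to-name mapping is what the training and analysis code consumes (e.g. `tasks/image_classification/train.py` imports `IMAGENET2012_CLASSES`). Below is a minimal, hypothetical lookup sketch, not part of the repository, assuming the mapping preserves declaration order so that class index 0 corresponds to the first synset listed:

```python
# Hypothetical usage sketch (not in the repo): map a predicted class index
# back to a readable label, assuming IMAGENET2012_CLASSES preserves the
# declaration order above (indices 0..999).
from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES

synsets = list(IMAGENET2012_CLASSES.keys())    # e.g. 'n15075141'
labels = list(IMAGENET2012_CLASSES.values())   # e.g. 'toilet tissue, toilet paper, ...'

def index_to_label(class_index: int) -> str:
    """Return the first human-readable name for a class index."""
    return labels[class_index].split(',')[0]

print(index_to_label(999))  # 'toilet tissue' under the ordering assumption
```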
tasks/image_classification/plotting.py ADDED
@@ -0,0 +1,494 @@
1
+
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ import os
6
+ import imageio
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib as mpl
9
+ from matplotlib import patheffects
10
+ mpl.use('Agg')
11
+ import seaborn as sns
12
+ import numpy as np
13
+ from tqdm.auto import tqdm
14
+ sns.set_style('darkgrid')
15
+
16
+ from tqdm.auto import tqdm
17
+ from scipy import ndimage
18
+ import umap
19
+ from scipy.special import softmax
20
+
21
+ import subprocess as sp
22
+ import cv2 # Still potentially useful for color conversion checks if needed
23
+ import os
24
+
25
+ def save_frames_to_mp4(frames, output_filename, fps=15.0, gop_size=None, crf=23, preset='medium', pix_fmt='yuv420p'):
26
+ """
27
+ Saves a list of NumPy array frames to an MP4 video file using FFmpeg via subprocess.
28
+
29
+ Includes fix for odd frame dimensions by padding to the nearest even number using -vf pad.
30
+
31
+ Requires FFmpeg to be installed and available in the system PATH.
32
+
33
+ Args:
34
+ frames (list): A list of NumPy arrays representing the video frames.
35
+ Expected format: uint8, (height, width, 3) for BGR color
36
+ or (height, width) for grayscale. Should be consistent.
37
+ output_filename (str): The path and name for the output MP4 file.
38
+ fps (float, optional): Frames per second for the output video. Defaults to 15.0.
39
+ gop_size (int, optional): Group of Pictures (GOP) size. This determines the
40
+ maximum interval between keyframes. Lower values
41
+ mean more frequent keyframes (better seeking, larger file).
42
+ Defaults to int(fps) (approx 1 keyframe per second).
43
+ crf (int, optional): Constant Rate Factor for H.264 encoding. Lower values mean
44
+ better quality and larger files. Typical range: 18-28.
45
+ Defaults to 23.
46
+ preset (str, optional): FFmpeg encoding speed preset. Affects encoding time
47
+ and compression efficiency. Options include 'ultrafast',
48
+ 'superfast', 'veryfast', 'faster', 'fast', 'medium',
49
+ 'slow', 'slower', 'veryslow'. Defaults to 'medium'.
50
+ """
51
+ if not frames:
52
+ print("Error: The 'frames' list is empty. No video to save.")
53
+ return
54
+
55
+ # --- Determine Parameters from First Frame ---
56
+ try:
57
+ first_frame = frames[0]
58
+ print(f"Info: First frame shape: {first_frame.shape}")
59
+ if not isinstance(first_frame, np.ndarray):
60
+ print(f"Error: Frame 0 is not a NumPy array (type: {type(first_frame)}).")
61
+ return
62
+
63
+ frame_height, frame_width = first_frame.shape[:2]
64
+ frame_size_str = f"{frame_width}x{frame_height}"
65
+
66
+ # Determine input pixel format based on first frame's shape
67
+ if len(first_frame.shape) == 3 and first_frame.shape[2] == 3:
68
+ input_pixel_format = 'bgr24' # Assume OpenCV's default BGR uint8
69
+ expected_dims = 3
70
+ print(f"Info: Detected color frames (shape: {first_frame.shape}). Expecting BGR input.")
71
+ elif len(first_frame.shape) == 2:
72
+ input_pixel_format = 'gray'
73
+ expected_dims = 2
74
+ print(f"Info: Detected grayscale frames (shape: {first_frame.shape}).")
75
+ else:
76
+ print(f"Error: Unsupported frame shape {first_frame.shape}. Must be (h, w) or (h, w, 3).")
77
+ return
78
+
79
+ if first_frame.dtype != np.uint8:
80
+ print(f"Warning: First frame dtype is {first_frame.dtype}. Will attempt conversion to uint8.")
81
+
82
+ except IndexError:
83
+ print("Error: Could not access the first frame to determine dimensions.")
84
+ return
85
+ except Exception as e:
86
+ print(f"Error processing first frame: {e}")
87
+ return
88
+
89
+ # --- Set GOP size default if not provided ---
90
+ if gop_size is None:
91
+ gop_size = int(fps)
92
+ print(f"Info: GOP size not specified, defaulting to {gop_size} (approx 1 keyframe/sec).")
93
+
94
+ # --- Construct FFmpeg Command ---
95
+ # ADDED -vf pad filter to ensure even dimensions for libx264/yuv420p
96
+ # It calculates the nearest even dimensions >= original dimensions
97
+ # Example: 1600x1351 -> 1600x1352
98
+ pad_filter = "pad=ceil(iw/2)*2:ceil(ih/2)*2"
99
+
100
+ command = [
101
+ 'ffmpeg',
102
+ '-y',
103
+ '-f', 'rawvideo',
104
+ '-vcodec', 'rawvideo',
105
+ '-pix_fmt', input_pixel_format,
106
+ '-s', frame_size_str,
107
+ '-r', str(float(fps)),
108
+ '-i', '-',
109
+ '-vf', pad_filter, # <--- ADDED VIDEO FILTER HERE
110
+ '-c:v', 'libx264',
111
+ '-pix_fmt', pix_fmt,
112
+ '-preset', preset,
113
+ '-crf', str(crf),
114
+ '-g', str(gop_size),
115
+ '-movflags', '+faststart',
116
+ output_filename
117
+ ]
118
+
119
+ print(f"\n--- Starting FFmpeg ---")
120
+ print(f"Output File: {output_filename}")
121
+ print(f"Parameters: FPS={fps}, Size={frame_size_str}, GOP={gop_size}, CRF={crf}, Preset={preset}")
122
+ print(f"Applying Filter: -vf {pad_filter} (Ensures even dimensions)")
123
+ # print(f"FFmpeg Command: {' '.join(command)}") # Uncomment for debugging
124
+
125
+ # --- Execute FFmpeg via Subprocess ---
126
+ try:
127
+ process = sp.Popen(command, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE)
128
+
129
+ print(f"\nWriting {len(frames)} frames to FFmpeg...")
130
+ progress_interval = max(1, len(frames) // 10) # Print progress roughly 10 times
131
+
132
+ for i, frame in enumerate(frames):
133
+ # Basic validation and conversion for each frame
134
+ if not isinstance(frame, np.ndarray):
135
+ print(f"Warning: Frame {i} is not a numpy array (type: {type(frame)}). Skipping.")
136
+ continue
137
+ if frame.shape[0] != frame_height or frame.shape[1] != frame_width:
138
+ print(f"Warning: Frame {i} has different dimensions {frame.shape[:2]}! Expected ({frame_height},{frame_width}). Skipping.")
139
+ continue
140
+
141
+ current_dims = len(frame.shape)
142
+ if current_dims != expected_dims:
143
+ print(f"Warning: Frame {i} has inconsistent dimensions ({current_dims}D vs expected {expected_dims}D). Skipping.")
144
+ continue
145
+ if expected_dims == 3 and frame.shape[2] != 3:
146
+ print(f"Warning: Frame {i} is color but doesn't have 3 channels ({frame.shape}). Skipping.")
147
+ continue
148
+
149
+ if frame.dtype != np.uint8:
150
+ try:
151
+ frame = np.clip(frame, 0, 255).astype(np.uint8)
152
+ except Exception as clip_err:
153
+ print(f"Error clipping/converting frame {i} dtype: {clip_err}. Skipping.")
154
+ continue
155
+
156
+ # Write frame bytes to FFmpeg's stdin
157
+ try:
158
+ process.stdin.write(frame.tobytes())
159
+ except (OSError, BrokenPipeError) as pipe_err:
160
+ print(f"\nError writing frame {i} to FFmpeg stdin: {pipe_err}")
161
+ print("FFmpeg process likely terminated prematurely. Check FFmpeg errors below.")
162
+ try:
163
+ # Immediately try to read stderr if pipe breaks
164
+ stderr_output_on_error = process.stderr.read()
165
+ if stderr_output_on_error:
166
+ print("\n--- FFmpeg stderr output on error ---")
167
+ print(stderr_output_on_error.decode(errors='ignore'))
168
+ print("--- End FFmpeg stderr ---")
169
+ except Exception as read_err:
170
+ print(f"(Could not read stderr after pipe error: {read_err})")
171
+ return
172
+ except Exception as write_err:
173
+ print(f"Unexpected error writing frame {i}: {write_err}. Skipping.")
174
+ continue
175
+
176
+ if (i + 1) % progress_interval == 0 or (i + 1) == len(frames):
177
+ print(f" Processed frame {i + 1}/{len(frames)}")
178
+
179
+ print("\nFinished writing frames. Closing FFmpeg stdin and waiting for completion...")
180
+ process.stdin.close()
181
+ stdout, stderr = process.communicate()
182
+ return_code = process.wait()
183
+
184
+ print("\n--- FFmpeg Final Status ---")
185
+ if return_code == 0:
186
+ print(f"FFmpeg process completed successfully.")
187
+ print(f"Video saved as: {output_filename}")
188
+ else:
189
+ print(f"FFmpeg process failed with return code {return_code}.")
190
+ print("--- FFmpeg Standard Error Output: ---")
191
+ print(stderr.decode(errors='replace')) # Print stderr captured by communicate()
192
+ print("--- End FFmpeg Output ---")
193
+ print("Review the FFmpeg error message above for details (e.g., dimension errors, parameter issues).")
194
+
195
+ except FileNotFoundError:
196
+ print("\n--- FATAL ERROR ---")
197
+ print("Error: 'ffmpeg' command not found.")
198
+ print("Please ensure FFmpeg is installed and its directory is included in your system's PATH environment variable.")
199
+ print("Download from: https://ffmpeg.org/")
200
+ print("-------------------")
201
+ except Exception as e:
202
+ print(f"\nAn unexpected error occurred during FFmpeg execution: {e}")
203
+
204
+ def find_island_centers(array_2d, threshold):
205
+ """
206
+ Finds the center of mass of each island (connected component) in a 2D array.
207
+
208
+ Args:
209
+ array_2d: A 2D numpy array of values.
210
+ threshold: The threshold to binarize the array.
211
+
212
+ Returns:
213
+ A list of tuples (y, x) representing the center of mass of each island.
214
+ """
215
+ binary_image = array_2d > threshold
216
+ labeled_image, num_labels = ndimage.label(binary_image)
217
+ centers = []
218
+ areas = [] # Store the area of each island
219
+ for i in range(1, num_labels + 1):
220
+ island = (labeled_image == i)
221
+ total_mass = np.sum(array_2d[island])
222
+ if total_mass > 0:
223
+ y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
224
+ x_center = np.average(x_coords[island], weights=array_2d[island])
225
+ y_center = np.average(y_coords[island], weights=array_2d[island])
226
+ centers.append((round(y_center, 4), round(x_center, 4)))
227
+ areas.append(np.sum(island)) # Calculate area of the island
228
+ return centers, areas
229
+
230
+ def plot_neural_dynamics(post_activations_history, N_to_plot, save_location, axis_snap=False, N_per_row=5, which_neurons_mid=None, mid_colours=None, use_most_active_neurons=False):
231
+ assert N_to_plot%N_per_row==0, f'For nice visualisation, N_to_plot={N_to_plot} must be a multiple of N_per_row={N_per_row}'
232
+ assert post_activations_history.shape[-1] >= N_to_plot
233
+ figscale = 2
234
+ aspect_ratio = 3
235
+ mosaic = np.array([[f'{i}'] for i in range(N_to_plot)]).flatten().reshape(-1, N_per_row)
236
+ fig_synch, axes_synch = plt.subplot_mosaic(mosaic=mosaic, figsize=(figscale*mosaic.shape[1]*aspect_ratio*0.2, figscale*mosaic.shape[0]*0.2))
237
+ fig_mid, axes_mid = plt.subplot_mosaic(mosaic=mosaic, figsize=(figscale*mosaic.shape[1]*aspect_ratio*0.2, figscale*mosaic.shape[0]*0.2), dpi=200)
238
+
239
+ palette = sns.color_palette("husl", 8)
240
+
241
+ which_neurons_synch = np.arange(N_to_plot)
242
+ # which_neurons_mid = np.arange(N_to_plot, N_to_plot*2) if post_activations_history.shape[-1] >= 2*N_to_plot else np.random.choice(np.arange(post_activations_history.shape[-1]), size=N_to_plot, replace=True)
243
+ random_indices = np.random.choice(np.arange(post_activations_history.shape[-1]), size=N_to_plot, replace=post_activations_history.shape[-1] < N_to_plot)
244
+ if use_most_active_neurons:
245
+ metric = np.abs(np.fft.rfft(post_activations_history, axis=0))[3:].mean(0).std(0)
246
+ random_indices = np.argsort(metric)[-N_to_plot:]
247
+ np.random.shuffle(random_indices)
248
+ which_neurons_mid = which_neurons_mid if which_neurons_mid is not None else random_indices
249
+
250
+ if mid_colours is None:
251
+ mid_colours = [palette[np.random.randint(0, 8)] for ndx in range(N_to_plot)]
252
+ with tqdm(total=N_to_plot, initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
253
+ pbar_inner.set_description('Plotting neural dynamics')
254
+ for ndx in range(N_to_plot):
255
+
256
+ ax_s = axes_synch[f'{ndx}']
257
+ ax_m = axes_mid[f'{ndx}']
258
+
259
+ traces_s = post_activations_history[:,:,which_neurons_synch[ndx]].T
260
+ traces_m = post_activations_history[:,:,which_neurons_mid[ndx]].T
261
+ c_s = palette[np.random.randint(0, 8)]
262
+ c_m = mid_colours[ndx]
263
+
264
+ for traces_s_here, traces_m_here in zip(traces_s, traces_m):
265
+ ax_s.plot(np.arange(len(traces_s_here)), traces_s_here, linestyle='-', color=c_s, alpha=0.05, linewidth=0.6)
266
+ ax_m.plot(np.arange(len(traces_m_here)), traces_m_here, linestyle='-', color=c_m, alpha=0.05, linewidth=0.6)
267
+
268
+
269
+ ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color='white', alpha=1, linewidth=2.5)
270
+ ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color=c_s, alpha=1, linewidth=1.3)
271
+ ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color='black', alpha=1, linewidth=0.3)
272
+ ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color='white', alpha=1, linewidth=2.5)
273
+ ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color=c_m, alpha=1, linewidth=1.3)
274
+ ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color='black', alpha=1, linewidth=0.3)
275
+ if axis_snap and np.all(np.isfinite(traces_s[0])):
276
+ ax_s.set_ylim([np.min(traces_s[0])-np.ptp(traces_s[0])*0.05, np.max(traces_s[0])+np.ptp(traces_s[0])*0.05])
277
+ ax_m.set_ylim([np.min(traces_m[0])-np.ptp(traces_m[0])*0.05, np.max(traces_m[0])+np.ptp(traces_m[0])*0.05])
278
+
279
+
280
+ ax_s.grid(False)
281
+ ax_m.grid(False)
282
+ ax_s.set_xlim([0, len(traces_s[0])-1])
283
+ ax_m.set_xlim([0, len(traces_m[0])-1])
284
+
285
+ ax_s.set_xticklabels([])
286
+ ax_s.set_yticklabels([])
287
+
288
+ ax_m.set_xticklabels([])
289
+ ax_m.set_yticklabels([])
290
+ pbar_inner.update(1)
291
+ fig_synch.tight_layout(pad=0.05)
292
+ fig_mid.tight_layout(pad=0.05)
293
+ if save_location is not None:
294
+ fig_synch.savefig(f'{save_location}/neural_dynamics_synch.pdf', dpi=200)
295
+ fig_synch.savefig(f'{save_location}/neural_dynamics_synch.png', dpi=200)
296
+ fig_mid.savefig(f'{save_location}/neural_dynamics_other.pdf', dpi=200)
297
+ fig_mid.savefig(f'{save_location}/neural_dynamics_other.png', dpi=200)
298
+ plt.close(fig_synch)
299
+ plt.close(fig_mid)
300
+ return fig_synch, fig_mid, which_neurons_mid, mid_colours
301
+
302
+
303
+
304
+ def make_classification_gif(image, target, predictions, certainties, post_activations, attention_tracking, class_labels, save_location):
305
+ cmap_viridis = sns.color_palette('viridis', as_cmap=True)
306
+ cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
307
+ figscale = 2
308
+ with tqdm(total=post_activations.shape[0]+1, initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
309
+ pbar_inner.set_description('Computing UMAP')
310
+
311
+
312
+ low = np.percentile(post_activations, 1, axis=0, keepdims=True)
313
+ high = np.percentile(post_activations, 99, axis=0, keepdims=True)
314
+ post_activations_normed = np.clip((post_activations - low)/(high - low), 0, 1)
315
+ metric = 'cosine'
316
+ reducer = umap.UMAP(n_components=2,
317
+ n_neighbors=100,
318
+ min_dist=3,
319
+ spread=3.0,
320
+ metric=metric,
321
+ random_state=None,
322
+ # low_memory=True,
323
+ ) if post_activations.shape[-1] > 2048 else umap.UMAP(n_components=2,
324
+ n_neighbors=20,
325
+ min_dist=1,
326
+ spread=1.0,
327
+ metric=metric,
328
+ random_state=None,
329
+ # low_memory=True,
330
+ )
331
+ positions = reducer.fit_transform(post_activations_normed.T)
332
+
333
+ x_umap = positions[:, 0]
334
+ y_umap = positions[:, 1]
335
+
336
+ pbar_inner.update(1)
337
+ pbar_inner.set_description('Iterating through to build frames')
338
+
339
+
340
+
341
+ frames = []
342
+ route_steps = {}
343
+ route_colours = []
344
+
345
+ n_steps = len(post_activations)
346
+ n_heads = attention_tracking.shape[1]
347
+ step_linspace = np.linspace(0, 1, n_steps)
348
+
349
+ for stepi in np.arange(0, n_steps, 1):
350
+ pbar_inner.set_description('Making frames for gif')
351
+
352
+
353
+ attention_now = attention_tracking[max(0, stepi-5):stepi+1].mean(0) # Average the last few steps so the attention maps look smoother
354
+ # attention_now[:,0,0] = 0 # Corners can be weird looking
355
+ # attention_now[:,0,-1] = 0
356
+ # attention_now[:,-1,0] = 0
357
+ # attention_now[:,-1,-1] = 0
358
+ # attention_now = (attention_tracking[:stepi+1, 0] * decay).sum(0)/(decay.sum(0))
359
+ certainties_now = certainties[1, :stepi+1]
360
+ attention_interp = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), image.shape[:2], mode='bilinear')[0]
361
+ attention_interp = (attention_interp.flatten(1) - attention_interp.flatten(1).min(-1, keepdim=True)[0])/(attention_interp.flatten(1).max(-1, keepdim=True)[0] - attention_interp.flatten(1).min(-1, keepdim=True)[0])
362
+ attention_interp = attention_interp.reshape(n_heads, image.shape[0], image.shape[1])
363
+
364
+
365
+ colour = list(cmap_spectral(step_linspace[stepi]))
366
+ route_colours.append(colour)
367
+ for headi in range(min(8, n_heads)):
368
+ com_attn = np.copy(attention_interp[headi])
369
+ com_attn[com_attn < np.percentile(com_attn, 97)] = 0.0
370
+ if headi not in route_steps:
371
+ A = attention_interp[headi].detach().cpu().numpy()
372
+ centres, areas = find_island_centers(A, threshold=0.7)
373
+ route_steps[headi] = [centres[np.argmax(areas)]]
374
+ else:
375
+ A = attention_interp[headi].detach().cpu().numpy()
376
+ centres, areas = find_island_centers(A, threshold=0.7)
377
+ route_steps[headi] = route_steps[headi] + [centres[np.argmax(areas)]]
378
+
379
+ mosaic = [['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay'],
380
+ ['head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
381
+ ['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay'],
382
+ ['head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
383
+ ['probabilities', 'probabilities','certainty', 'certainty'],
384
+ ['umap', 'umap', 'umap', 'umap'],
385
+ ['umap', 'umap', 'umap', 'umap'],
386
+ ['umap', 'umap', 'umap', 'umap'],
387
+
388
+ ]
389
+
390
+
391
+ img_aspect = image.shape[0]/image.shape[1]
392
+ # print(img_aspect)
393
+ aspect_ratio = (4*figscale, 8*figscale*img_aspect)
394
+ fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
395
+ for ax in axes.values():
396
+ ax.axis('off')
397
+
398
+
399
+ axes['certainty'].plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1, label='1-(normalised entropy)')
400
+ for ii, (x, y) in enumerate(zip(np.arange(len(certainties_now)), certainties_now)):
401
+ is_correct = predictions[:, ii].argmax(-1)==target
402
+ if is_correct: axes['certainty'].axvspan(ii, ii + 1, facecolor='limegreen', edgecolor=None, lw=0, alpha=0.3)
403
+ else:
404
+ axes['certainty'].axvspan(ii, ii + 1, facecolor='orchid', edgecolor=None, lw=0, alpha=0.3)
405
+ axes['certainty'].plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
406
+ axes['certainty'].axis('off')
407
+ axes['certainty'].set_ylim([-0.05, 1.05])
408
+ axes['certainty'].set_xlim([0, certainties.shape[-1]+1])
409
+
410
+ ps = torch.softmax(torch.from_numpy(predictions[:, stepi]), -1)
411
+ k = 15 if len(class_labels) > 15 else len(class_labels)
412
+ topk = torch.topk (ps, k, dim = 0, largest=True).indices.detach().cpu().numpy()
413
+ top_classes = np.array(class_labels)[topk]
414
+ true_class = target
415
+ colours = [('b' if ci != true_class else 'g') for ci in topk]
416
+ bar_heights = ps[topk].detach().cpu().numpy()
417
+
418
+
419
+ axes['probabilities'].bar(np.arange(len(bar_heights))[::-1], bar_heights, color=np.array(colours), alpha=1)
420
+ axes['probabilities'].set_ylim([0, 1])
421
+ axes['probabilities'].axis('off')
422
+
423
+
424
+ for i, (name) in enumerate(top_classes):
425
+ prob = ps[topk[i]]  # probability of this top-k class (currently unused)
426
+ is_correct = name==class_labels[true_class]
427
+ fg_color = 'darkgreen' if is_correct else 'crimson'
428
+ text_str = f'{name[:40]}'
429
+ axes['probabilities'].text(
430
+ 0.05,
431
+ 0.95 - i * 0.055, # Adjust vertical position for each line
432
+ text_str,
433
+ transform=axes['probabilities'].transAxes,
434
+ verticalalignment='top',
435
+ fontsize=8, # Increased font size
436
+ color=fg_color,
437
+ alpha=0.5,
438
+ path_effects=[
439
+ patheffects.Stroke(linewidth=3, foreground='aliceblue'),
440
+ patheffects.Normal()
441
+ ])
442
+
443
+
444
+
445
+ attention_now = attention_tracking[max(0, stepi-5):stepi+1].mean(0) # Average the last few steps so the attention maps look smoother
446
+ # attention_now = (attention_tracking[:stepi+1, 0] * decay).sum(0)/(decay.sum(0))
447
+ certainties_now = certainties[1, :stepi+1]
448
+ attention_interp = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), image.shape[:2], mode='nearest')[0]
449
+ attention_interp = (attention_interp.flatten(1) - attention_interp.flatten(1).min(-1, keepdim=True)[0])/(attention_interp.flatten(1).max(-1, keepdim=True)[0] - attention_interp.flatten(1).min(-1, keepdim=True)[0])
450
+ attention_interp = attention_interp.reshape(n_heads, image.shape[0], image.shape[1])
451
+
452
+ for hi in range(min(8, n_heads)):
453
+ ax = axes[f'head_{hi}']
454
+ img_to_plot = cmap_viridis(attention_interp[hi].detach().cpu().numpy())
455
+ ax.imshow(img_to_plot)
456
+
457
+ ax_overlay = axes[f'head_{hi}_overlay']
458
+
459
+ these_route_steps = route_steps[hi]
460
+ y_coords, x_coords = zip(*these_route_steps)
461
+ y_coords = image.shape[-2] - np.array(list(y_coords))-1
462
+
463
+ ax_overlay.imshow(np.flip(image, axis=0), origin='lower')
464
+ # ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
465
+ arrow_scale = 1.5 if image.shape[0] > 32 else 0.8
466
+ for i in range(len(these_route_steps)-1):
467
+ dx = x_coords[i+1] - x_coords[i]
468
+ dy = y_coords[i+1] - y_coords[i]
469
+
470
+ ax_overlay.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=1.6*arrow_scale*1.3, head_width=1.9*arrow_scale*1.3, head_length=1.4*arrow_scale*1.45, fc='white', ec='white', length_includes_head = True, alpha=1)
471
+ ax_overlay.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=1.6*arrow_scale, head_width=1.9*arrow_scale, head_length=1.4*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
472
+
473
+ ax_overlay.set_xlim([0,image.shape[1]-1])
474
+ ax_overlay.set_ylim([0,image.shape[0]-1])
475
+ ax_overlay.axis('off')
476
+
477
+
478
+ z = post_activations_normed[stepi]
479
+
480
+ axes['umap'].scatter(x_umap, y_umap, s=30, c=cmap_spectral(z))
481
+
482
+ fig.tight_layout(pad=0.1)
483
+
484
+
485
+
486
+ canvas = fig.canvas
487
+ canvas.draw()
488
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
489
+ image_numpy = (image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3])
490
+ frames.append(image_numpy)
491
+ plt.close(fig)
492
+ pbar_inner.update(1)
493
+ pbar_inner.set_description('Saving gif')
494
+ imageio.mimsave(save_location, frames, fps=15, loop=100)
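`save_frames_to_mp4` pipes raw frames into an FFmpeg subprocess and pads odd dimensions to even values, as its docstring describes. A hypothetical smoke test, not part of the repository and assuming `ffmpeg` is on the PATH:

```python
# Hypothetical smoke test (not in the repo) for save_frames_to_mp4.
# Assumes ffmpeg is installed and on PATH, as the function requires.
import numpy as np
from tasks.image_classification.plotting import save_frames_to_mp4

# 30 random BGR uint8 frames with deliberately odd dimensions (101x75)
# to exercise the even-dimension padding filter.
frames = [np.random.randint(0, 256, size=(101, 75, 3), dtype=np.uint8) for _ in range(30)]
save_frames_to_mp4(frames, 'smoke_test.mp4', fps=15.0, crf=23, preset='veryfast')
```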
tasks/image_classification/scripts/train_cifar10.sh ADDED
@@ -0,0 +1,286 @@
1
+ python -m tasks.image_classification.train \
2
+ --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \
3
+ --model ctm \
4
+ --dataset cifar10 \
5
+ --d_model 256 \
6
+ --d_input 64 \
7
+ --synapse_depth 5 \
8
+ --heads 16 \
9
+ --n_synch_out 256 \
10
+ --n_synch_action 512 \
11
+ --n_random_pairing_self 0 \
12
+ --neuron_select_type random-pairing \
13
+ --iterations 50 \
14
+ --memory_length 15 \
15
+ --deep_memory \
16
+ --memory_hidden_dims 64 \
17
+ --dropout 0.0 \
18
+ --dropout_nlm 0 \
19
+ --no-do_normalisation \
20
+ --positional_embedding_type none \
21
+ --backbone_type resnet18-1 \
22
+ --training_iterations 600001 \
23
+ --warmup_steps 1000 \
24
+ --use_scheduler \
25
+ --scheduler_type cosine \
26
+ --weight_decay 0.0001 \
27
+ --save_every 1000 \
28
+ --track_every 2000 \
29
+ --n_test_batches 50 \
30
+ --num_workers_train 8 \
31
+ --batch_size 512 \
32
+ --batch_size_test 512 \
33
+ --lr 1e-4 \
34
+ --device 0 \
35
+ --seed 1
36
+
37
+
38
+ python -m tasks.image_classification.train \
39
+ --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \
40
+ --model ctm \
41
+ --dataset cifar10 \
42
+ --d_model 256 \
43
+ --d_input 64 \
44
+ --synapse_depth 5 \
45
+ --heads 16 \
46
+ --n_synch_out 256 \
47
+ --n_synch_action 512 \
48
+ --n_random_pairing_self 0 \
49
+ --neuron_select_type random-pairing \
50
+ --iterations 50 \
51
+ --memory_length 15 \
52
+ --deep_memory \
53
+ --memory_hidden_dims 64 \
54
+ --dropout 0.0 \
55
+ --dropout_nlm 0 \
56
+ --no-do_normalisation \
57
+ --positional_embedding_type none \
58
+ --backbone_type resnet18-1 \
59
+ --training_iterations 600001 \
60
+ --warmup_steps 1000 \
61
+ --use_scheduler \
62
+ --scheduler_type cosine \
63
+ --weight_decay 0.0001 \
64
+ --save_every 1000 \
65
+ --track_every 2000 \
66
+ --n_test_batches 50 \
67
+ --num_workers_train 8 \
68
+ --batch_size 512 \
69
+ --batch_size_test 512 \
70
+ --lr 1e-4 \
71
+ --device 0 \
72
+ --seed 2
73
+
74
+ python -m tasks.image_classification.train \
75
+ --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \
76
+ --model ctm \
77
+ --dataset cifar10 \
78
+ --d_model 256 \
79
+ --d_input 64 \
80
+ --synapse_depth 5 \
81
+ --heads 16 \
82
+ --n_synch_out 256 \
83
+ --n_synch_action 512 \
84
+ --n_random_pairing_self 0 \
85
+ --neuron_select_type random-pairing \
86
+ --iterations 50 \
87
+ --memory_length 15 \
88
+ --deep_memory \
89
+ --memory_hidden_dims 64 \
90
+ --dropout 0.0 \
91
+ --dropout_nlm 0 \
92
+ --no-do_normalisation \
93
+ --positional_embedding_type none \
94
+ --backbone_type resnet18-1 \
95
+ --training_iterations 600001 \
96
+ --warmup_steps 1000 \
97
+ --use_scheduler \
98
+ --scheduler_type cosine \
99
+ --weight_decay 0.0001 \
100
+ --save_every 1000 \
101
+ --track_every 2000 \
102
+ --n_test_batches 50 \
103
+ --num_workers_train 8 \
104
+ --batch_size 512 \
105
+ --batch_size_test 512 \
106
+ --lr 1e-4 \
107
+ --device 0 \
108
+ --seed 42
109
+
110
+
111
+
112
+
113
+
114
+
115
+ python -m tasks.image_classification.train \
116
+ --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \
117
+ --dataset cifar10 \
118
+ --model lstm \
119
+ --num_layers 2 \
120
+ --d_model 256 \
121
+ --d_input 64 \
122
+ --heads 16 \
123
+ --iterations 50 \
124
+ --dropout 0.0 \
125
+ --positional_embedding_type none \
126
+ --backbone_type resnet18-1 \
127
+ --training_iterations 600001 \
128
+ --warmup_steps 2000 \
129
+ --use_scheduler \
130
+ --scheduler_type cosine \
131
+ --weight_decay 0.0001 \
132
+ --save_every 1000 \
133
+ --track_every 2000 \
134
+ --n_test_batches 50 \
135
+ --reload \
136
+ --num_workers_train 8 \
137
+ --batch_size 512 \
138
+ --batch_size_test 512 \
139
+ --lr 1e-4 \
140
+ --device 0 \
141
+ --seed 1 \
142
+ --no-reload
143
+
144
+
145
+ python -m tasks.image_classification.train \
146
+ --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \
147
+ --dataset cifar10 \
148
+ --model lstm \
149
+ --num_layers 2 \
150
+ --d_model 256 \
151
+ --d_input 64 \
152
+ --heads 16 \
153
+ --iterations 50 \
154
+ --dropout 0.0 \
155
+ --positional_embedding_type none \
156
+ --backbone_type resnet18-1 \
157
+ --training_iterations 600001 \
158
+ --warmup_steps 2000 \
159
+ --use_scheduler \
160
+ --scheduler_type cosine \
161
+ --weight_decay 0.0001 \
162
+ --save_every 1000 \
163
+ --track_every 2000 \
164
+ --n_test_batches 50 \
165
+ --reload \
166
+ --num_workers_train 8 \
167
+ --batch_size 512 \
168
+ --batch_size_test 512 \
169
+ --lr 1e-4 \
170
+ --device 0 \
171
+ --seed 2 \
172
+ --no-reload
173
+
174
+
175
+ python -m tasks.image_classification.train \
176
+ --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \
177
+ --dataset cifar10 \
178
+ --model lstm \
179
+ --num_layers 2 \
180
+ --d_model 256 \
181
+ --d_input 64 \
182
+ --heads 16 \
183
+ --iterations 50 \
184
+ --dropout 0.0 \
185
+ --positional_embedding_type none \
186
+ --backbone_type resnet18-1 \
187
+ --training_iterations 600001 \
188
+ --warmup_steps 2000 \
189
+ --use_scheduler \
190
+ --scheduler_type cosine \
191
+ --weight_decay 0.0001 \
192
+ --save_every 1000 \
193
+ --track_every 2000 \
194
+ --n_test_batches 50 \
195
+ --reload \
196
+ --num_workers_train 8 \
197
+ --batch_size 512 \
198
+ --batch_size_test 512 \
199
+ --lr 1e-4 \
200
+ --device 0 \
201
+ --seed 42 \
202
+ --no-reload
203
+
204
+
205
+
206
+
207
+
208
+ python -m tasks.image_classification.train \
209
+ --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=1 \
210
+ --dataset cifar10 \
211
+ --model ff \
212
+ --d_model 256 \
213
+ --memory_hidden_dims 64 \
214
+ --dropout 0.0 \
215
+ --dropout_nlm 0 \
216
+ --backbone_type resnet18-1 \
217
+ --training_iterations 600001 \
218
+ --warmup_steps 1000 \
219
+ --use_scheduler \
220
+ --scheduler_type cosine \
221
+ --weight_decay 0.0001 \
222
+ --save_every 1000 \
223
+ --track_every 2000 \
224
+ --n_test_batches 50 \
225
+ --num_workers_train 8 \
226
+ --batch_size 512 \
227
+ --batch_size_test 512 \
228
+ --lr 1e-4 \
229
+ --device 0 \
230
+ --seed 1
231
+
232
+
233
+ python -m tasks.image_classification.train \
234
+ --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=2 \
235
+ --dataset cifar10 \
236
+ --model ff \
237
+ --d_model 256 \
238
+ --memory_hidden_dims 64 \
239
+ --dropout 0.0 \
240
+ --dropout_nlm 0 \
241
+ --backbone_type resnet18-1 \
242
+ --training_iterations 600001 \
243
+ --warmup_steps 1000 \
244
+ --use_scheduler \
245
+ --scheduler_type cosine \
246
+ --weight_decay 0.0001 \
247
+ --save_every 1000 \
248
+ --track_every 2000 \
249
+ --n_test_batches 50 \
250
+ --num_workers_train 8 \
251
+ --batch_size 512 \
252
+ --batch_size_test 512 \
253
+ --lr 1e-4 \
254
+ --device 0 \
255
+ --seed 2
256
+
257
+ python -m tasks.image_classification.train \
258
+ --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=42 \
259
+ --dataset cifar10 \
260
+ --model ff \
261
+ --d_model 256 \
262
+ --memory_hidden_dims 64 \
263
+ --dropout 0.0 \
264
+ --dropout_nlm 0 \
265
+ --backbone_type resnet18-1 \
266
+ --training_iterations 600001 \
267
+ --warmup_steps 1000 \
268
+ --use_scheduler \
269
+ --scheduler_type cosine \
270
+ --weight_decay 0.0001 \
271
+ --save_every 1000 \
272
+ --track_every 2000 \
273
+ --n_test_batches 50 \
274
+ --num_workers_train 8 \
275
+ --batch_size 512 \
276
+ --batch_size_test 512 \
277
+ --lr 1e-4 \
278
+ --device 0 \
279
+ --seed 42
280
+
281
+
282
+
283
+
284
+
285
+
286
+
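The script above repeats one CTM configuration per seed. A hypothetical Python sweep driver, not part of the repository, that launches the same three seeds sequentially (flag values mirror the first command above; the log-directory naming here is illustrative only):

```python
# Hypothetical sweep driver (not in the repo): run the CTM CIFAR-10 config
# for seeds 1, 2 and 42, mirroring the flags of the first command above.
import subprocess

BASE_ARGS = [
    "python", "-m", "tasks.image_classification.train",
    "--model", "ctm", "--dataset", "cifar10",
    "--d_model", "256", "--d_input", "64", "--synapse_depth", "5", "--heads", "16",
    "--n_synch_out", "256", "--n_synch_action", "512", "--n_random_pairing_self", "0",
    "--neuron_select_type", "random-pairing", "--iterations", "50", "--memory_length", "15",
    "--deep_memory", "--memory_hidden_dims", "64", "--dropout", "0.0", "--dropout_nlm", "0",
    "--no-do_normalisation", "--positional_embedding_type", "none",
    "--backbone_type", "resnet18-1", "--training_iterations", "600001",
    "--warmup_steps", "1000", "--use_scheduler", "--scheduler_type", "cosine",
    "--weight_decay", "0.0001", "--save_every", "1000", "--track_every", "2000",
    "--n_test_batches", "50", "--num_workers_train", "8",
    "--batch_size", "512", "--batch_size_test", "512", "--lr", "1e-4", "--device", "0",
]

for seed in (1, 2, 42):
    log_dir = f"logs/cifar10-versus-humans/ctm/seed={seed}"  # illustrative naming
    subprocess.run(BASE_ARGS + ["--log_dir", log_dir, "--seed", str(seed)], check=True)
```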
tasks/image_classification/scripts/train_imagenet.sh ADDED
@@ -0,0 +1,38 @@
1
+ torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed \
2
+ --log_dir logs/imagenet/d=4096--i=1024--heads=16--sd=8--nlm=64--synch=8192-2048-32-h=64-random-pairing--iters=50x25--backbone=152x4 \
3
+ --model ctm \
4
+ --dataset imagenet \
5
+ --d_model 4096 \
6
+ --d_input 1024 \
7
+ --synapse_depth 8 \
8
+ --heads 16 \
9
+ --n_synch_out 8196 \
10
+ --n_synch_action 2048 \
11
+ --n_random_pairing_self 32 \
12
+ --neuron_select_type random-pairing \
13
+ --iterations 50 \
14
+ --memory_length 25 \
15
+ --deep_memory \
16
+ --memory_hidden_dims 64 \
17
+ --dropout 0.2 \
18
+ --dropout_nlm 0 \
19
+ --no-do_normalisation \
20
+ --positional_embedding_type none \
21
+ --backbone_type resnet152-4 \
22
+ --batch_size 64 \
23
+ --batch_size_test 64 \
24
+ --n_test_batches 200 \
25
+ --lr 5e-4 \
26
+ --gradient_clipping 20 \
27
+ --training_iterations 500001 \
28
+ --save_every 1000 \
29
+ --track_every 5000 \
30
+ --warmup_steps 10000 \
31
+ --use_scheduler \
32
+ --scheduler_type cosine \
33
+ --weight_decay 0.0 \
34
+ --seed 1 \
35
+ --use_amp \
36
+ --reload \
37
+ --num_workers_train 8 \
38
+ --use_custom_sampler
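The distributed ImageNet run uses 8 processes (`--nproc_per_node=8`) with a per-process batch size of 64, so under the usual DDP data sharding (an assumption about how `train_distributed.py` splits batches across ranks) the effective global batch is 64 × 8 = 512. A tiny, hypothetical sketch of that bookkeeping when changing GPU counts:

```python
# Hypothetical helper (not in the repo): keep the effective global batch
# size fixed when changing the process count passed to torchrun.
def per_process_batch(global_batch: int, nproc: int) -> int:
    assert global_batch % nproc == 0, "choose a divisible global batch size"
    return global_batch // nproc

print(per_process_batch(512, 8))  # 64, matching the script above
print(per_process_batch(512, 4))  # 128 if running on 4 GPUs instead
```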
tasks/image_classification/train.py ADDED
@@ -0,0 +1,685 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import seaborn as sns
8
+ sns.set_style('darkgrid')
9
+ import torch
10
+ if torch.cuda.is_available():
11
+ # For faster
12
+ torch.set_float32_matmul_precision('high')
13
+ import torch.nn as nn
14
+ from tqdm.auto import tqdm
15
+
16
+ from data.custom_datasets import ImageNet
17
+ from torchvision import datasets
18
+ from torchvision import transforms
19
+ from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
20
+ from models.ctm import ContinuousThoughtMachine
21
+ from models.lstm import LSTMBaseline
22
+ from models.ff import FFBaseline
23
+ from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
24
+ from utils.housekeeping import set_seed, zip_python_code
25
+ from utils.losses import image_classification_loss # Used by CTM, LSTM
26
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
27
+
28
+ from autoclip.torch import QuantileClip
29
+
30
+ import gc
31
+ import torchvision
32
+ torchvision.disable_beta_transforms_warning()
33
+
34
+
35
+ import warnings
36
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
37
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
38
+ warnings.filterwarnings(
39
+ "ignore",
40
+ "Corrupt EXIF data",
41
+ UserWarning,
42
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
43
+ )
44
+ warnings.filterwarnings(
45
+ "ignore",
46
+ "UserWarning: Metadata Warning",
47
+ UserWarning,
48
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
49
+ )
50
+ warnings.filterwarnings(
51
+ "ignore",
52
+ "UserWarning: Truncated File Read",
53
+ UserWarning,
54
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
55
+ )
56
+
57
+
58
+ def parse_args():
59
+ parser = argparse.ArgumentParser()
60
+
61
+ # Model Selection
62
+ parser.add_argument('--model', type=str, default='ctm', choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
63
+
64
+ # Model Architecture
65
+ # Common
66
+ parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
67
+ parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
68
+ parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featureiser.')
69
+ # CTM / LSTM specific
70
+ parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
71
+ parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
72
+ parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
73
+ parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
74
+ choices=['none',
75
+ 'learnable-fourier',
76
+ 'multi-learnable-fourier',
77
+ 'custom-rotational'])
78
+ # CTM specific
79
+ parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
80
+ parser.add_argument('--n_synch_out', type=int, default=512, help='Number of neurons to use for output synch (CTM only).')
81
+ parser.add_argument('--n_synch_action', type=int, default=512, help='Number of neurons to use for observation/action synch (CTM only).')
82
+ parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
83
+ parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
84
+ parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
85
+ parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
86
+ parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
87
+ parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
88
+ parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
89
+ # LSTM specific
90
+ parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
91
+
92
+ # Training
93
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training.')
94
+ parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing.')
95
+ parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
96
+ parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
97
+ parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
98
+ parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
99
+ parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
100
+ parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
101
+ parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
102
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
103
+ parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
104
+ parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient norm clipping value (-1 to disable).')
105
+ parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components (backbone, synapses if CTM).')
106
+ parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')
107
+
108
+ # Housekeeping
109
+ parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
110
+ parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
111
+ parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
112
+ parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
113
+ parser.add_argument('--seed', type=int, default=412, help='Random seed.')
114
+ parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
115
+ parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
116
+ parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
117
+ parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
118
+ parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
119
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
120
+ parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
121
+
122
+
123
+ args = parser.parse_args()
124
+ return args
125
+
126
+
127
+ def get_dataset(dataset, root):
128
+ if dataset=='imagenet':
129
+ dataset_mean = [0.485, 0.456, 0.406]
130
+ dataset_std = [0.229, 0.224, 0.225]
131
+
132
+ normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
133
+ train_transform = transforms.Compose([
134
+ transforms.RandomResizedCrop(224),
135
+ transforms.RandomHorizontalFlip(),
136
+ transforms.ToTensor(),
137
+ normalize])
138
+ test_transform = transforms.Compose([
139
+ transforms.Resize(256),
140
+ transforms.CenterCrop(224),
141
+ transforms.ToTensor(),
142
+ normalize])
143
+
144
+ class_labels = list(IMAGENET2012_CLASSES.values())
145
+
146
+ train_data = ImageNet(which_split='train', transform=train_transform)
147
+ test_data = ImageNet(which_split='validation', transform=test_transform)
148
+ elif dataset=='cifar10':
149
+ dataset_mean = [0.49139968, 0.48215827, 0.44653124]
150
+ dataset_std = [0.24703233, 0.24348505, 0.26158768]
151
+ normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
152
+ train_transform = transforms.Compose(
153
+ [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
154
+ transforms.ToTensor(),
155
+ normalize,
156
+ ])
157
+
158
+ test_transform = transforms.Compose(
159
+ [transforms.ToTensor(),
160
+ normalize,
161
+ ])
162
+ train_data = datasets.CIFAR10(root, train=True, transform=train_transform, download=True)
163
+ test_data = datasets.CIFAR10(root, train=False, transform=test_transform, download=True)
164
+ class_labels = ['air', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
165
+ elif dataset=='cifar100':
166
+ dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
167
+ dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
168
+ normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
169
+
170
+ train_transform = transforms.Compose(
171
+ [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
172
+ transforms.ToTensor(),
173
+ normalize,
174
+ ])
175
+ test_transform = transforms.Compose(
176
+ [transforms.ToTensor(),
177
+ normalize,
178
+ ])
179
+ train_data = datasets.CIFAR100(root, train=True, transform=train_transform, download=True)
180
+ test_data = datasets.CIFAR100(root, train=False, transform=test_transform, download=True)
181
+ idx_order = np.argsort(np.array(list(train_data.class_to_idx.values())))
182
+ class_labels = list(np.array(list(train_data.class_to_idx.keys()))[idx_order])
183
+ else:
184
+ raise NotImplementedError
185
+
186
+ return train_data, test_data, class_labels, dataset_mean, dataset_std
187
+
188
+
189
+
190
+ if __name__=='__main__':
191
+
192
+ # Housekeeping
193
+ args = parse_args()
194
+
195
+ set_seed(args.seed, False)
196
+ if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
197
+
198
+ assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
199
+
200
+ # Data
201
+ train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
202
+
203
+ num_workers_test = 1 # Defaulting to 1, change if needed
204
+ trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train)
205
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
206
+
207
+ prediction_reshaper = [-1] # Problem specific
208
+ args.out_dims = len(class_labels)
209
+
210
+ # For total reproducibility
211
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
212
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
213
+ print(args, file=f)
214
+
215
+ # Configure device string
216
+ device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
217
+ print(f'Running model {args.model} on {device}')
218
+
219
+ # Build model conditionally
220
+ model = None
221
+ if args.model == 'ctm':
222
+ model = ContinuousThoughtMachine(
223
+ iterations=args.iterations,
224
+ d_model=args.d_model,
225
+ d_input=args.d_input,
226
+ heads=args.heads,
227
+ n_synch_out=args.n_synch_out,
228
+ n_synch_action=args.n_synch_action,
229
+ synapse_depth=args.synapse_depth,
230
+ memory_length=args.memory_length,
231
+ deep_nlms=args.deep_memory,
232
+ memory_hidden_dims=args.memory_hidden_dims,
233
+ do_layernorm_nlm=args.do_normalisation,
234
+ backbone_type=args.backbone_type,
235
+ positional_embedding_type=args.positional_embedding_type,
236
+ out_dims=args.out_dims,
237
+ prediction_reshaper=prediction_reshaper,
238
+ dropout=args.dropout,
239
+ dropout_nlm=args.dropout_nlm,
240
+ neuron_select_type=args.neuron_select_type,
241
+ n_random_pairing_self=args.n_random_pairing_self,
242
+ ).to(device)
243
+ elif args.model == 'lstm':
244
+ model = LSTMBaseline(
245
+ num_layers=args.num_layers,
246
+ iterations=args.iterations,
247
+ d_model=args.d_model,
248
+ d_input=args.d_input,
249
+ heads=args.heads,
250
+ backbone_type=args.backbone_type,
251
+ positional_embedding_type=args.positional_embedding_type,
252
+ out_dims=args.out_dims,
253
+ prediction_reshaper=prediction_reshaper,
254
+ dropout=args.dropout,
255
+ ).to(device)
256
+ elif args.model == 'ff':
257
+ model = FFBaseline(
258
+ d_model=args.d_model,
259
+ backbone_type=args.backbone_type,
260
+ out_dims=args.out_dims,
261
+ dropout=args.dropout,
262
+ ).to(device)
263
+ else:
264
+ raise ValueError(f"Unknown model type: {args.model}")
265
+
266
+
267
+ # For lazy modules so that we can get param count
268
+ pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
269
+ model(pseudo_inputs)
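+ # The single dummy forward pass above materialises any lazily-initialised (nn.Lazy*) modules, so the parameter count printed below is accurate.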
270
+
271
+ model.train()
272
+
273
+
274
+ print(f'Total params: {sum(p.numel() for p in model.parameters())}')
275
+ decay_params = []
276
+ no_decay_params = []
277
+ no_decay_names = []
278
+ for name, param in model.named_parameters():
279
+ if not param.requires_grad:
280
+ continue # Skip parameters that don't require gradients
281
+ if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
282
+ no_decay_params.append(param)
283
+ no_decay_names.append(name)
284
+ else:
285
+ decay_params.append(param)
286
+ if len(no_decay_names):
287
+ print(f'WARNING, excluding: {no_decay_names}')
288
+
289
+ # Optimizer and scheduler (Common setup)
290
+ if len(no_decay_names) and args.weight_decay!=0:
291
+ optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
292
+ {'params': no_decay_params, 'weight_decay':0}],
293
+ lr=args.lr,
294
+ eps=1e-8 if not args.use_amp else 1e-6)
295
+ else:
296
+ optimizer = torch.optim.AdamW(model.parameters(),
297
+ lr=args.lr,
298
+ eps=1e-8 if not args.use_amp else 1e-6,
299
+ weight_decay=args.weight_decay)
300
+
301
+
302
+ warmup_schedule = warmup(args.warmup_steps)
303
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
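+ # Fallback scheduler: warmup-only LambdaLR; it is replaced below when --use_scheduler is enabled.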
304
+ if args.use_scheduler:
305
+ if args.scheduler_type == 'multistep':
306
+ scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
307
+ elif args.scheduler_type == 'cosine':
308
+ scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
309
+ else:
310
+ raise NotImplementedError
311
+
312
+
313
+ # Metrics tracking
314
+ start_iter = 0
315
+ train_losses = []
316
+ test_losses = []
317
+ train_accuracies = []
318
+ test_accuracies = []
319
+ iters = []
320
+ # Conditional metrics for CTM/LSTM
321
+ train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
322
+ test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
323
+
324
+ scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
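+ # With --use_amp disabled the GradScaler is a no-op, so the scale/step/update calls below behave like a plain optimizer step.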
325
+
326
+ # Reloading logic
327
+ if args.reload:
328
+ checkpoint_path = f'{args.log_dir}/checkpoint.pt'
329
+ if os.path.isfile(checkpoint_path):
330
+ print(f'Reloading from: {checkpoint_path}')
331
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
332
+ if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
333
+ load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
334
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
335
+
336
+ if not args.reload_model_only:
337
+ print('Reloading optimizer etc.')
338
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
339
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
340
+ scaler.load_state_dict(checkpoint['scaler_state_dict'])
341
+ start_iter = checkpoint['iteration']
342
+ # Load common metrics
343
+ train_losses = checkpoint['train_losses']
344
+ test_losses = checkpoint['test_losses']
345
+ train_accuracies = checkpoint['train_accuracies']
346
+ test_accuracies = checkpoint['test_accuracies']
347
+ iters = checkpoint['iters']
348
+
349
+ # Load conditional metrics if they exist in checkpoint and are expected for current model
350
+ if args.model in ['ctm', 'lstm']:
351
+ train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
352
+ test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
353
+
354
+ else:
355
+ print('Only reloading model!')
356
+
357
+ if 'torch_rng_state' in checkpoint:
358
+ # Reset seeds
359
+ torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
360
+ np.random.set_state(checkpoint['numpy_rng_state'])
361
+ random.setstate(checkpoint['random_rng_state'])
362
+
363
+ del checkpoint
364
+ gc.collect()
365
+ if torch.cuda.is_available():
366
+ torch.cuda.empty_cache()
367
+
368
+ # Conditional Compilation
369
+ if args.do_compile:
370
+ print('Compiling...')
371
+ if hasattr(model, 'backbone'):
372
+ model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
373
+
374
+ # Compile synapses only for CTM
375
+ if args.model == 'ctm':
376
+ model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
377
+
378
+ # Training
379
+ iterator = iter(trainloader)
380
+
381
+
382
+ with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
383
+ for bi in range(start_iter, args.training_iterations):
384
+ current_lr = optimizer.param_groups[-1]['lr']
385
+
386
+ try:
387
+ inputs, targets = next(iterator)
388
+ except StopIteration:
389
+ iterator = iter(trainloader)
390
+ inputs, targets = next(iterator)
391
+
392
+ inputs = inputs.to(device)
393
+ targets = targets.to(device)
394
+
395
+ loss = None
396
+ accuracy = None
397
+ # Model-specific forward and loss calculation
398
+ with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
399
+ if args.do_compile: # CUDAGraph marking for clean compile
400
+ torch.compiler.cudagraph_mark_step_begin()
401
+
402
+ if args.model == 'ctm':
403
+ predictions, certainties, synchronisation = model(inputs)
404
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
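+ # where_most_certain gives, per sample, the internal tick judged most certain; the line below gathers each sample's prediction at that tick to compute accuracy.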
405
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
406
+ pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
407
+
408
+ elif args.model == 'lstm':
409
+ predictions, certainties, synchronisation = model(inputs)
410
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
411
+ # Note: use_most_certain=True is also passed for the LSTM; if LSTM training proves unstable it can be set to False, in which case where_most_certain is simply -1 (the final tick).
412
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
413
+ pbar_desc = f'LSTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
414
+
415
+ elif args.model == 'ff':
416
+ predictions = model(inputs)
417
+ loss = nn.CrossEntropyLoss()(predictions, targets)
418
+ accuracy = (predictions.argmax(1) == targets).float().mean().item()
419
+ pbar_desc = f'FF Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
420
+
421
+ scaler.scale(loss).backward()
422
+
423
+ if args.gradient_clipping!=-1:
424
+ scaler.unscale_(optimizer)
425
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
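+ # Gradients are unscaled above so the clipping threshold applies to true (unscaled) gradient norms under AMP.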
426
+
427
+ scaler.step(optimizer)
428
+ scaler.update()
429
+ optimizer.zero_grad(set_to_none=True)
430
+ scheduler.step()
431
+
432
+ pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
433
+
434
+
435
+ # Metrics tracking and plotting (conditional logic needed)
436
+ if (bi % args.track_every == 0 or bi == args.warmup_steps) and (bi != 0 or args.reload_model_only):
437
+
438
+ iters.append(bi)
439
+ current_train_losses = []
440
+ current_test_losses = []
441
+ current_train_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
442
+ current_test_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
443
+ current_train_accuracies_most_certain = [] # Only for CTM/LSTM
444
+ current_test_accuracies_most_certain = [] # Only for CTM/LSTM
445
+
446
+
447
+ # Reset BN stats using train mode
448
+ pbar.set_description('Resetting BN')
449
+ model.train()
450
+ for module in model.modules():
451
+ if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
452
+ module.reset_running_stats()
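+ # With stats reset and the model kept in train mode, the forward passes below re-estimate BatchNorm running statistics from the evaluation batches.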
453
+
454
+ pbar.set_description('Tracking: Computing TRAIN metrics')
455
+ with torch.no_grad(): # Should use inference_mode? CTM/LSTM scripts used no_grad
456
+ loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
457
+ all_targets_list = []
458
+ all_predictions_list = [] # List to store raw predictions (B, C, T) or (B, C)
459
+ all_predictions_most_certain_list = [] # Only for CTM/LSTM
460
+ all_losses = []
461
+
462
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
463
+ for inferi, (inputs, targets) in enumerate(loader):
464
+ inputs = inputs.to(device)
465
+ targets = targets.to(device)
466
+ all_targets_list.append(targets.detach().cpu().numpy())
467
+
468
+ # Model-specific forward and loss for evaluation
469
+ if args.model == 'ctm':
470
+ these_predictions, certainties, _ = model(inputs)
471
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
472
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
473
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
474
+
475
+ elif args.model == 'lstm':
476
+ these_predictions, certainties, _ = model(inputs)
477
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
478
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
479
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
480
+
481
+ elif args.model == 'ff':
482
+ these_predictions = model(inputs)
483
+ loss = nn.CrossEntropyLoss()(these_predictions, targets)
484
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B,)
485
+
486
+ all_losses.append(loss.item())
487
+
488
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break # Check condition >= N-1
489
+ pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
490
+ pbar_inner.update(1)
491
+
492
+ all_targets = np.concatenate(all_targets_list)
493
+ all_predictions = np.concatenate(all_predictions_list) # Shape (N, T) or (N,)
494
+ train_losses.append(np.mean(all_losses))
495
+
496
+ if args.model in ['ctm', 'lstm']:
497
+ # Accuracies per tick for CTM/LSTM
498
+ current_train_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0) # Mean over batch dim -> Shape (T,)
499
+ train_accuracies.append(current_train_accuracies)
500
+ # Most certain accuracy
501
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
502
+ current_train_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
503
+ train_accuracies_most_certain.append(current_train_accuracies_most_certain)
504
+ else: # FF
505
+ current_train_accuracies = (all_targets == all_predictions).mean() # Shape scalar
506
+ train_accuracies.append(current_train_accuracies)
507
+
508
+ del these_predictions
509
+
510
+
511
+ # Switch to eval mode for test metrics (fixed BN stats)
512
+ model.eval()
513
+ pbar.set_description('Tracking: Computing TEST metrics')
514
+ with torch.inference_mode(): # Use inference_mode for test eval
515
+ loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
516
+ all_targets_list = []
517
+ all_predictions_list = []
518
+ all_predictions_most_certain_list = [] # Only for CTM/LSTM
519
+ all_losses = []
520
+
521
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
522
+ for inferi, (inputs, targets) in enumerate(loader):
523
+ inputs = inputs.to(device)
524
+ targets = targets.to(device)
525
+ all_targets_list.append(targets.detach().cpu().numpy())
526
+
527
+ # Model-specific forward and loss for evaluation
528
+ if args.model == 'ctm':
529
+ these_predictions, certainties, _ = model(inputs)
530
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
531
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
532
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
533
+
534
+ elif args.model == 'lstm':
535
+ these_predictions, certainties, _ = model(inputs)
536
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
537
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
538
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
539
+
540
+ elif args.model == 'ff':
541
+ these_predictions = model(inputs)
542
+ loss = nn.CrossEntropyLoss()(these_predictions, targets)
543
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
544
+
545
+ all_losses.append(loss.item())
546
+
547
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
548
+ pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
549
+ pbar_inner.update(1)
550
+
551
+ all_targets = np.concatenate(all_targets_list)
552
+ all_predictions = np.concatenate(all_predictions_list)
553
+ test_losses.append(np.mean(all_losses))
554
+
555
+ if args.model in ['ctm', 'lstm']:
556
+ current_test_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0)
557
+ test_accuracies.append(current_test_accuracies)
558
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
559
+ current_test_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
560
+ test_accuracies_most_certain.append(current_test_accuracies_most_certain)
561
+ else: # FF
562
+ current_test_accuracies = (all_targets == all_predictions).mean()
563
+ test_accuracies.append(current_test_accuracies)
564
+
565
+ # Plotting (conditional)
566
+ figacc = plt.figure(figsize=(10, 10))
567
+ axacc_train = figacc.add_subplot(211)
568
+ axacc_test = figacc.add_subplot(212)
569
+ cm = sns.color_palette("viridis", as_cmap=True)
570
+
571
+ if args.model in ['ctm', 'lstm']:
572
+ # Plot per-tick accuracy for CTM/LSTM
573
+ train_acc_arr = np.array(train_accuracies) # Shape (N_iters, T)
574
+ test_acc_arr = np.array(test_accuracies) # Shape (N_iters, T)
575
+ num_ticks = train_acc_arr.shape[1]
576
+ for ti in range(num_ticks):
577
+ axacc_train.plot(iters, train_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
578
+ axacc_test.plot(iters, test_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
579
+ # Plot most certain accuracy
580
+ axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
581
+ axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
582
+ else: # FF
583
+ axacc_train.plot(iters, train_accuracies, 'k-', alpha=0.7, label='Accuracy') # Simple line
584
+ axacc_test.plot(iters, test_accuracies, 'k-', alpha=0.7, label='Accuracy')
585
+
586
+ axacc_train.set_title('Train Accuracy')
587
+ axacc_test.set_title('Test Accuracy')
588
+ axacc_train.legend(loc='lower right')
589
+ axacc_test.legend(loc='lower right')
590
+ axacc_train.set_xlim([0, args.training_iterations])
591
+ axacc_test.set_xlim([0, args.training_iterations])
592
+ if args.dataset=='cifar10':
593
+ axacc_train.set_ylim([0.75, 1])
594
+ axacc_test.set_ylim([0.75, 1])
595
+
596
+
597
+
598
+ figacc.tight_layout()
599
+ figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
600
+ plt.close(figacc)
601
+
602
+ figloss = plt.figure(figsize=(10, 5))
603
+ axloss = figloss.add_subplot(111)
604
+ axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
605
+ axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
606
+ axloss.legend(loc='upper right')
607
+ axloss.set_xlim([0, args.training_iterations])
608
+ axloss.set_ylim(bottom=0)
609
+
610
+ figloss.tight_layout()
611
+ figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
612
+ plt.close(figloss)
613
+
614
+ # Conditional Visualization (Only for CTM/LSTM)
615
+ if args.model in ['ctm', 'lstm']:
616
+ try: # For safety
617
+ inputs_viz, targets_viz = next(iter(testloader)) # Get a fresh batch
618
+ inputs_viz = inputs_viz.to(device)
619
+ targets_viz = targets_viz.to(device)
620
+
621
+ pbar.set_description('Tracking: Processing test data for viz')
622
+ predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
623
+
624
+ att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
625
+ attention_tracking_viz = attention_tracking_viz.reshape(
626
+ attention_tracking_viz.shape[0],
627
+ attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
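+ # The tracked attention weights come back flattened over the spatial dimension; this reshape restores the backbone's (H, W) feature grid per head for visualisation.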
628
+
629
+ pbar.set_description('Tracking: Neural dynamics plot')
630
+ plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
631
+
632
+ imgi = 0 # Visualize the first image in the batch
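+ # Undo the dataset normalisation and move channels last so the image can be rendered in the gif.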
633
+ img_to_gif = np.moveaxis(np.clip(inputs_viz[imgi].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
634
+
635
+ pbar.set_description('Tracking: Producing attention gif')
636
+ make_classification_gif(img_to_gif,
637
+ targets_viz[imgi].item(),
638
+ predictions_viz[imgi].detach().cpu().numpy(),
639
+ certainties_viz[imgi].detach().cpu().numpy(),
640
+ post_activations_viz[:,imgi],
641
+ attention_tracking_viz[:,imgi],
642
+ class_labels,
643
+ f'{args.log_dir}/{imgi}_attention.gif',
644
+ )
645
+ del predictions_viz, certainties_viz, pre_activations_viz, post_activations_viz, attention_tracking_viz
646
+ except Exception as e:
647
+ print(f"Visualization failed for model {args.model}: {e}")
648
+
649
+
650
+
651
+ gc.collect()
652
+ if torch.cuda.is_available():
653
+ torch.cuda.empty_cache()
654
+ model.train() # Switch back to train mode
655
+
656
+
657
+ # Save model checkpoint (conditional metrics)
658
+ if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
659
+ pbar.set_description('Saving model checkpoint...')
660
+ checkpoint_data = {
661
+ 'model_state_dict': model.state_dict(),
662
+ 'optimizer_state_dict': optimizer.state_dict(),
663
+ 'scheduler_state_dict': scheduler.state_dict(),
664
+ 'scaler_state_dict': scaler.state_dict(),
665
+ 'iteration': bi,
666
+ # Always save these
667
+ 'train_losses': train_losses,
668
+ 'test_losses': test_losses,
669
+ 'train_accuracies': train_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
670
+ 'test_accuracies': test_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
671
+ 'iters': iters,
672
+ 'args': args, # Save args used for this run
673
+ # RNG states
674
+ 'torch_rng_state': torch.get_rng_state(),
675
+ 'numpy_rng_state': np.random.get_state(),
676
+ 'random_rng_state': random.getstate(),
677
+ }
678
+ # Conditionally add metrics specific to CTM/LSTM
679
+ if args.model in ['ctm', 'lstm']:
680
+ checkpoint_data['train_accuracies_most_certain'] = train_accuracies_most_certain
681
+ checkpoint_data['test_accuracies_most_certain'] = test_accuracies_most_certain
682
+
683
+ torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
684
+
685
+ pbar.update(1)
tasks/image_classification/train_distributed.py ADDED
@@ -0,0 +1,799 @@
 
1
+ import argparse
2
+ import os
3
+ import random
4
+ import time
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import seaborn as sns
9
+ sns.set_style('darkgrid')
10
+ import torch
11
+ if torch.cuda.is_available():
12
+ # For faster float32 matmuls (TF32) on supported GPUs
13
+ torch.set_float32_matmul_precision('high')
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ from torch.nn.parallel import DistributedDataParallel as DDP
17
+ from torch.utils.data.distributed import DistributedSampler
18
+ from utils.samplers import FastRandomDistributedSampler
19
+ from tqdm.auto import tqdm
20
+
21
+ from tasks.image_classification.train import get_dataset # Use shared get_dataset
22
+
23
+ # Model Imports
24
+ from models.ctm import ContinuousThoughtMachine
25
+ from models.lstm import LSTMBaseline
26
+ from models.ff import FFBaseline
27
+
28
+ # Plotting/Utils Imports
29
+ from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
30
+ from utils.housekeeping import set_seed, zip_python_code
31
+ from utils.losses import image_classification_loss # For CTM, LSTM
32
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
33
+
34
+ import torchvision
35
+ torchvision.disable_beta_transforms_warning()
36
+
37
+ import warnings
38
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
39
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
40
+ warnings.filterwarnings("ignore", message="UserWarning: Metadata Warning, tag 274 had too many entries: 4, expected 1")
41
+ warnings.filterwarnings(
42
+ "ignore",
43
+ "Corrupt EXIF data",
44
+ UserWarning,
45
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
46
+ )
47
+ warnings.filterwarnings(
48
+ "ignore",
49
+ "UserWarning: Metadata Warning",
50
+ UserWarning,
51
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
52
+ )
53
+ warnings.filterwarnings(
54
+ "ignore",
55
+ "UserWarning: Truncated File Read",
56
+ UserWarning,
57
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
58
+ )
59
+
60
+
61
+ def parse_args():
62
+ parser = argparse.ArgumentParser()
63
+
64
+ # Model Selection
65
+ parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
66
+
67
+ # Model Architecture
68
+ # Common
69
+ parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
70
+ parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
71
+ parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featuriser.')
72
+ # CTM / LSTM specific
73
+ parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
74
+ parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
75
+ parser.add_argument('--iterations', type=int, default=50, help='Number of internal ticks (CTM, LSTM).')
76
+ parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
77
+ choices=['none',
78
+ 'learnable-fourier',
79
+ 'multi-learnable-fourier',
80
+ 'custom-rotational'])
81
+ # CTM specific
82
+ parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
83
+ parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).')
84
+ parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).')
85
+ parser.add_argument('--neuron_select_type', type=str, default='first-last', help='Protocol for selecting neuron subset (CTM only).')
86
+ parser.add_argument('--n_random_pairing_self', type=int, default=256, help='Number of neurons paired self-to-self for synch (CTM only).')
87
+ parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
88
+ parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
89
+ parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
90
+ parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
91
+ parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
92
+ # LSTM specific
93
+ parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
94
+
95
+ # Training
96
+ parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training (per GPU).')
97
+ parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing (per GPU).')
98
+ parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
99
+ parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
100
+ parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
101
+ parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
102
+ parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
103
+ parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
104
+ parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
105
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
106
+ parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
107
+ parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient norm clipping value (-1 to disable).')
108
+ parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')
109
+ parser.add_argument('--use_custom_sampler', action=argparse.BooleanOptionalAction, default=False, help='Use custom fast sampler to avoid reshuffling.')
110
+ parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
111
+
112
+ # Housekeeping
113
+ parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
114
+ parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
115
+ parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
116
+ parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
117
+ parser.add_argument('--seed', type=int, default=412, help='Random seed.')
118
+ parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
119
+ parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
120
+ parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.')
121
+ parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading?')
122
+
123
+ # Tracking
124
+ parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
125
+ parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
126
+ parser.add_argument('--plot_indices', type=int, default=[0], nargs='+', help='Which indices in test data to plot?') # Defaulted to 0
127
+
128
+ # Precision
129
+ parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
130
+ args = parser.parse_args()
131
+ return args
132
+
133
+ # --- DDP Setup Functions ---
134
+ def setup_ddp():
135
+ if 'RANK' not in os.environ:
136
+ # Basic setup for non-distributed run
137
+ os.environ['RANK'] = '0'
138
+ os.environ['WORLD_SIZE'] = '1'
139
+ os.environ['MASTER_ADDR'] = 'localhost'
140
+ os.environ['MASTER_PORT'] = '12355' # Ensure this port is free
141
+ os.environ['LOCAL_RANK'] = '0'
142
+ print("Running in non-distributed mode (simulated DDP setup).")
143
+ # Need to manually init if only 1 process desired for non-GPU testing
144
+ if not torch.cuda.is_available() or int(os.environ['WORLD_SIZE']) == 1:
145
+ dist.init_process_group(backend='gloo') # Gloo backend for CPU
146
+ print("Initialized process group with Gloo backend for single/CPU process.")
147
+ rank = int(os.environ['RANK'])
148
+ world_size = int(os.environ['WORLD_SIZE'])
149
+ local_rank = int(os.environ['LOCAL_RANK'])
150
+ return rank, world_size, local_rank
151
+
152
+
153
+ # Standard DDP setup
154
+ dist.init_process_group(backend='nccl') # 'nccl' for NVIDIA GPUs
155
+ rank = int(os.environ['RANK'])
156
+ world_size = int(os.environ['WORLD_SIZE'])
157
+ local_rank = int(os.environ['LOCAL_RANK'])
158
+ if torch.cuda.is_available():
159
+ torch.cuda.set_device(local_rank)
160
+ print(f"Rank {rank} setup on GPU {local_rank}")
161
+ else:
162
+ print(f"Rank {rank} setup on CPU (GPU not available or requested)")
163
+ return rank, world_size, local_rank
164
+
165
+ def cleanup_ddp():
166
+ if dist.is_initialized():
167
+ dist.destroy_process_group()
168
+ print("DDP cleanup complete.")
169
+
170
+ def is_main_process(rank):
171
+ return rank == 0
172
+ # --- End DDP Setup ---
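+ # Typical launch (assumed invocation; adjust GPU count and flags to your setup):
+ #   torchrun --nproc_per_node=<NUM_GPUS> tasks/image_classification/train_distributed.py --model ctm --dataset imagenet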
173
+
174
+
175
+ if __name__=='__main__':
176
+
177
+ args = parse_args()
178
+
179
+ rank, world_size, local_rank = setup_ddp()
180
+
181
+ set_seed(args.seed + rank, False) # Add rank for different seeds per process
182
+
183
+ # Rank 0 handles directory creation and initial logging
184
+ if is_main_process(rank):
185
+ if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
186
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
187
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
188
+ print(args, file=f)
189
+ if world_size > 1: dist.barrier() # Sync after rank 0 setup
190
+
191
+
192
+ assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
193
+
194
+ # Data Loading
195
+ train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
196
+
197
+ # Setup Samplers
198
+ # This custom sampler helps when using large batch sizes on CIFAR; the standard DistributedSampler would otherwise trigger a full reshuffle at the end of every (short) epoch.
199
+ train_sampler = (FastRandomDistributedSampler(train_data, num_replicas=world_size, rank=rank, seed=args.seed, epoch_steps=int(10e10))
200
+ if args.use_custom_sampler else
201
+ DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed))
202
+ test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed) # No shuffle needed for test; consistent
203
+
204
+ # Setup DataLoaders
205
+ trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler,
206
+ num_workers=args.num_workers_train, pin_memory=True, drop_last=True) # drop_last=True often used in DDP
207
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, sampler=test_sampler,
208
+ num_workers=1, pin_memory=True, drop_last=False)
209
+
210
+
211
+ prediction_reshaper = [-1] # Task specific
212
+ args.out_dims = len(class_labels)
213
+
214
+ # Setup Device
215
+ if torch.cuda.is_available():
216
+ device = torch.device(f'cuda:{local_rank}')
217
+ else:
218
+ device = torch.device('cpu')
219
+ if world_size > 1:
220
+ warnings.warn("Running DDP on CPU is not recommended.")
221
+ if is_main_process(rank):
222
+ print(f'Main process (Rank {rank}): Using device {device}. World size: {world_size}. Model: {args.model}')
223
+
224
+ # --- Model Definition (Conditional) ---
225
+ model_base = None # Base model before DDP wrapping
226
+ if args.model == 'ctm':
227
+ model_base = ContinuousThoughtMachine(
228
+ iterations=args.iterations,
229
+ d_model=args.d_model,
230
+ d_input=args.d_input,
231
+ heads=args.heads,
232
+ n_synch_out=args.n_synch_out,
233
+ n_synch_action=args.n_synch_action,
234
+ synapse_depth=args.synapse_depth,
235
+ memory_length=args.memory_length,
236
+ deep_nlms=args.deep_memory,
237
+ memory_hidden_dims=args.memory_hidden_dims,
238
+ do_layernorm_nlm=args.do_normalisation,
239
+ backbone_type=args.backbone_type,
240
+ positional_embedding_type=args.positional_embedding_type,
241
+ out_dims=args.out_dims,
242
+ prediction_reshaper=prediction_reshaper,
243
+ dropout=args.dropout,
244
+ dropout_nlm=args.dropout_nlm,
245
+ neuron_select_type=args.neuron_select_type,
246
+ n_random_pairing_self=args.n_random_pairing_self,
247
+ ).to(device)
248
+ elif args.model == 'lstm':
249
+ model_base = LSTMBaseline(
250
+ num_layers=args.num_layers,
251
+ iterations=args.iterations,
252
+ d_model=args.d_model,
253
+ d_input=args.d_input,
254
+ heads=args.heads,
255
+ backbone_type=args.backbone_type,
256
+ positional_embedding_type=args.positional_embedding_type,
257
+ out_dims=args.out_dims,
258
+ prediction_reshaper=prediction_reshaper,
259
+ dropout=args.dropout,
260
261
+ ).to(device)
262
+ elif args.model == 'ff':
263
+ model_base = FFBaseline(
264
+ d_model=args.d_model,
265
+ backbone_type=args.backbone_type,
266
+ out_dims=args.out_dims,
267
+ dropout=args.dropout,
268
+ ).to(device)
269
+ else:
270
+ raise ValueError(f"Unknown model type: {args.model}")
271
+
272
+ # Initialize lazy modules if any
273
+ try:
274
+ pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
275
+ model_base(pseudo_inputs)
276
+ except Exception as e:
277
+ print(f"Warning: Pseudo forward pass failed: {e}")
278
+
279
+ # Wrap model with DDP
280
+ if device.type == 'cuda' and world_size > 1:
281
+ model = DDP(model_base, device_ids=[local_rank], output_device=local_rank)
282
+ elif device.type == 'cpu' and world_size > 1:
283
+ model = DDP(model_base) # No device_ids for CPU
284
+ else: # Single process run
285
+ model = model_base # No DDP wrapping needed
286
+
287
+ if is_main_process(rank):
288
+ # Access underlying model for param count
289
+ param_count = sum(p.numel() for p in model.module.parameters() if p.requires_grad) if world_size > 1 else sum(p.numel() for p in model.parameters() if p.requires_grad)
290
+ print(f'Total trainable params: {param_count}')
291
+ # --- End Model Definition ---
292
+
293
+
294
+ # Optimizer and scheduler
295
+ # Use model.parameters() directly, DDP handles it
296
+ decay_params = []
297
+ no_decay_params = []
298
+ no_decay_names = []
299
+ for name, param in model.named_parameters():
300
+ if not param.requires_grad:
301
+ continue # Skip parameters that don't require gradients
302
+ if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
303
+ no_decay_params.append(param)
304
+ no_decay_names.append(name)
305
+ else:
306
+ decay_params.append(param)
307
+ if len(no_decay_names) and is_main_process(rank):
308
+ print(f'WARNING, excluding: {no_decay_names}')
309
+
310
+ # Optimizer and scheduler (Common setup)
311
+ if len(no_decay_names) and args.weight_decay!=0:
312
+ optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
313
+ {'params': no_decay_params, 'weight_decay':0}],
314
+ lr=args.lr,
315
+ eps=1e-8 if not args.use_amp else 1e-6)
316
+ else:
317
+ optimizer = torch.optim.AdamW(model.parameters(),
318
+ lr=args.lr,
319
+ eps=1e-8 if not args.use_amp else 1e-6,
320
+ weight_decay=args.weight_decay)
321
+
322
+ warmup_schedule = warmup(args.warmup_steps)
323
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
324
+ if args.use_scheduler:
325
+ if args.scheduler_type == 'multistep':
326
+ scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
327
+ elif args.scheduler_type == 'cosine':
328
+ scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
329
+ else:
330
+ raise NotImplementedError
331
+
332
+
333
+ # Metrics tracking (on Rank 0)
334
+ start_iter = 0
335
+ train_losses = []
336
+ test_losses = []
337
+ train_accuracies = [] # Placeholder for potential detailed accuracy
338
+ test_accuracies = [] # Placeholder for potential detailed accuracy
339
+ # Conditional metrics
340
+ train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None # Scalar accuracy list
341
+ test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None # Scalar accuracy list
342
+ train_accuracies_standard = [] if args.model == 'ff' else None # Standard accuracy list for FF
343
+ test_accuracies_standard = [] if args.model == 'ff' else None # Standard accuracy list for FF
344
+ iters = []
345
+
346
+ scaler = torch.amp.GradScaler("cuda" if device.type == 'cuda' else "cpu", enabled=args.use_amp)
347
+ # Reloading Logic
348
+ if args.reload:
349
+ map_location = device # Load directly onto the process's device
350
+ chkpt_path = f'{args.log_dir}/checkpoint.pt'
351
+ if os.path.isfile(chkpt_path):
352
+ print(f'Rank {rank}: Reloading from: {chkpt_path}')
353
+ checkpoint = torch.load(chkpt_path, map_location=map_location, weights_only=False)
354
+
355
+ # Determine underlying model based on whether DDP wrapping occurred
356
+ model_to_load = model.module if isinstance(model, DDP) else model
357
+
358
+ # Handle potential 'module.' prefix in saved state_dict
359
+ state_dict = checkpoint['model_state_dict']
360
+ has_module_prefix = all(k.startswith('module.') for k in state_dict)
361
+ is_wrapped = isinstance(model, DDP)
362
+
363
+ if has_module_prefix and not is_wrapped:
364
+ # Saved with DDP, loading into non-DDP model -> remove prefix
365
+ state_dict = {k.partition('module.')[2]: v for k,v in state_dict.items()}
366
+ elif not has_module_prefix and is_wrapped:
367
+ load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
368
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
369
+ state_dict = None # Prevent loading again
370
+
371
+ if state_dict is not None:
372
+ load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
373
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
374
+
375
+
376
+ if not args.reload_model_only:
377
+ print(f'Rank {rank}: Reloading optimizer, scheduler, scaler, iteration.')
378
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
379
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
380
+ scaler_state_dict = checkpoint['scaler_state_dict']
381
+ if scaler.is_enabled():
382
+ print("Loading non-empty GradScaler state dict.")
383
+ try:
384
+ scaler.load_state_dict(scaler_state_dict)
385
+ except Exception as e:
386
+ print(f"Error loading GradScaler state dict: {e}")
387
+ print("Continuing with a fresh GradScaler state.")
388
+
389
+ start_iter = checkpoint['iteration']
390
+ # Only rank 0 loads metric history
391
+ if is_main_process(rank) and not args.ignore_metrics_when_reloading:
392
+ print(f'Rank {rank}: Reloading metrics history.')
393
+ iters = checkpoint['iters']
394
+ train_losses = checkpoint['train_losses']
395
+ test_losses = checkpoint['test_losses']
396
+ train_accuracies = checkpoint['train_accuracies']
397
+ test_accuracies = checkpoint['test_accuracies']
398
+ if args.model in ['ctm', 'lstm']:
399
+ train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
400
+ test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
401
+ elif args.model == 'ff':
402
+ train_accuracies_standard = checkpoint['train_accuracies_standard']
403
+ test_accuracies_standard = checkpoint['test_accuracies_standard']
404
+ elif is_main_process(rank) and args.ignore_metrics_when_reloading:
405
+ print(f'Rank {rank}: Ignoring metrics history upon reload.')
406
+
407
+ else:
408
+ print(f'Rank {rank}: Only reloading model weights!')
409
+
410
+ # Load RNG states
411
+ if is_main_process(rank) and 'torch_rng_state' in checkpoint and not args.reload_model_only:
412
+ print(f'Rank {rank}: Loading RNG states (may need DDP adaptation for full reproducibility).')
413
+ torch.set_rng_state(checkpoint['torch_rng_state'].cpu()) # Load CPU state
414
+ # Add CUDA state loading if needed, ensuring correct device handling
415
+ np.random.set_state(checkpoint['numpy_rng_state'])
416
+ random.setstate(checkpoint['random_rng_state'])
417
+
418
+ del checkpoint
419
+ if torch.cuda.is_available(): torch.cuda.empty_cache()
420
+ print(f"Rank {rank}: Reload finished, starting from iteration {start_iter}")
421
+ else:
422
+ print(f"Rank {rank}: Checkpoint not found at {chkpt_path}, starting from scratch.")
423
+ if world_size > 1: dist.barrier() # Sync after loading
424
+
425
+
426
+ # Conditional Compilation
427
+ if args.do_compile:
428
+ if is_main_process(rank): print('Compiling model components...')
429
+ # Compile on the underlying model if wrapped
430
+ model_to_compile = model.module if isinstance(model, DDP) else model
431
+ if hasattr(model_to_compile, 'backbone'):
432
+ model_to_compile.backbone = torch.compile(model_to_compile.backbone, mode='reduce-overhead', fullgraph=True)
433
+ if args.model == 'ctm':
434
+ if hasattr(model_to_compile, 'synapses'):
435
+ model_to_compile.synapses = torch.compile(model_to_compile.synapses, mode='reduce-overhead', fullgraph=True)
436
+ if world_size > 1: dist.barrier() # Sync after compilation
437
+ if is_main_process(rank): print('Compilation finished.')
438
+
439
+
440
+ # --- Training Loop ---
441
+ model.train() # Ensure model is in train mode
442
+ pbar = tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True, disable=not is_main_process(rank))
443
+
444
+ iterator = iter(trainloader)
445
+
446
+ for bi in range(start_iter, args.training_iterations):
447
+
448
+ # Set sampler epoch (important for shuffling in DistributedSampler)
449
+ if not args.use_custom_sampler and hasattr(train_sampler, 'set_epoch'):
450
+ train_sampler.set_epoch(bi)
451
+
452
+ current_lr = optimizer.param_groups[-1]['lr']
453
+
454
+ time_start_data = time.time()
455
+ try:
456
+ inputs, targets = next(iterator)
457
+ except StopIteration:
458
+ # Reset iterator - set_epoch handles shuffling if needed
459
+ iterator = iter(trainloader)
460
+ inputs, targets = next(iterator)
461
+
462
+
463
+ inputs = inputs.to(device, non_blocking=True)
464
+ targets = targets.to(device, non_blocking=True)
465
+ time_end_data = time.time()
466
+
467
+ loss = None
468
+ # Model-specific forward and loss calculation
469
+ time_start_forward = time.time()
470
+ with torch.autocast(device_type="cuda" if device.type == 'cuda' else "cpu", dtype=torch.float16, enabled=args.use_amp):
471
+ if args.do_compile:
472
+ torch.compiler.cudagraph_mark_step_begin()
473
+
474
+ if args.model == 'ctm':
475
+ predictions, certainties, synchronisation = model(inputs)
476
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
477
+ elif args.model == 'lstm':
478
+ predictions, certainties, synchronisation = model(inputs)
479
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
480
+ elif args.model == 'ff':
481
+ predictions = model(inputs) # FF returns only predictions
482
+ loss = nn.CrossEntropyLoss()(predictions, targets)
483
+ where_most_certain = None # Not applicable for FF standard loss
484
+ time_end_forward = time.time()
485
+ time_start_backward = time.time()
486
+
487
+ scaler.scale(loss).backward() # DDP handles gradient synchronization
488
+ time_end_backward = time.time()
489
+
490
+ if args.gradient_clipping!=-1:
491
+ scaler.unscale_(optimizer)
492
+ # Clip gradients across all parameters controlled by the optimizer
493
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
494
+
495
+ scaler.step(optimizer)
496
+ scaler.update()
497
+ optimizer.zero_grad(set_to_none=True)
498
+ scheduler.step()
499
+
500
+ # --- Aggregation and Logging (Rank 0) ---
501
+ # Aggregate loss for logging
502
+ loss_log = loss.detach() # Use detached loss for aggregation
503
+ if world_size > 1: dist.all_reduce(loss_log, op=dist.ReduceOp.AVG)
504
+
505
+ if is_main_process(rank):
506
+ # Calculate accuracy locally on rank 0 for description (approximate)
507
+ # Note: This uses rank 0's batch, not aggregated accuracy
508
+ accuracy_local = 0.0
509
+ if args.model in ['ctm', 'lstm']:
510
+ accuracy_local = (predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain] == targets).float().mean().item()
511
+ where_certain_tensor = where_most_certain.float() # Use rank 0's tensor for stats
512
+ pbar_desc = f'Timing; d={(time_end_data-time_start_data):0.3f}, f={(time_end_forward-time_start_forward):0.3f}, b={(time_end_backward-time_start_backward):0.3f}. Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_local:.3f} LR={current_lr:.6f} WhereCert(loc)={where_certain_tensor.mean().item():.2f}'
513
+ elif args.model == 'ff':
514
+ accuracy_local = (predictions.argmax(1) == targets).float().mean().item()
515
+ pbar_desc = f'Timing; d={(time_end_data-time_start_data):0.3f}, f={(time_end_forward-time_start_forward):0.3f}, b={(time_end_backward-time_start_backward):0.3f}. Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_local:.3f} LR={current_lr:.6f}'
516
+
517
+ pbar.set_description(f'{args.model.upper()} {pbar_desc}')
518
+ # --- End Aggregation and Logging ---
519
+
520
+
521
+ # --- Evaluation and Plotting (Rank 0 + Aggregation) ---
522
+ if bi % args.track_every == 0 and (bi != 0 or args.reload_model_only):
523
+
524
+ model.eval()
525
+ with torch.inference_mode():
526
+
527
+
528
+ # --- Distributed Evaluation ---
529
+ iters.append(bi)
530
+
531
+ # TRAIN METRICS
532
+ total_train_loss = torch.tensor(0.0, device=device)
533
+ total_train_correct_certain = torch.tensor(0.0, device=device) # CTM/LSTM
534
+ total_train_correct_standard = torch.tensor(0.0, device=device) # FF
535
+ total_train_samples = torch.tensor(0.0, device=device)
536
+
537
+ # Use a sampler for evaluation to ensure non-overlapping data if needed
538
+ train_eval_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=False)
539
+ train_eval_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, sampler=train_eval_sampler, num_workers=1, pin_memory=True)
540
+
541
+ pbar_inner_desc = 'Eval Train (Rank 0)' if is_main_process(rank) else None
542
+ with tqdm(total=len(train_eval_loader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
543
+ for inferi, (inputs, targets) in enumerate(train_eval_loader):
544
+ inputs = inputs.to(device, non_blocking=True)
545
+ targets = targets.to(device, non_blocking=True)
546
+
547
+ loss_eval = None
548
+ if args.model == 'ctm':
549
+ predictions, certainties, _ = model(inputs)
550
+ loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
551
+ preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
552
+ total_train_correct_certain += (preds_eval == targets).sum()
553
+ elif args.model == 'lstm':
554
+ predictions, certainties, _ = model(inputs)
555
+ loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
556
+ preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
557
+ total_train_correct_certain += (preds_eval == targets).sum()
558
+ elif args.model == 'ff':
559
+ predictions = model(inputs)
560
+ loss_eval = nn.CrossEntropyLoss()(predictions, targets)
561
+ preds_eval = predictions.argmax(1)
562
+ total_train_correct_standard += (preds_eval == targets).sum()
563
+
564
+ total_train_loss += loss_eval * inputs.size(0)
565
+ total_train_samples += inputs.size(0)
566
+
567
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
568
+ pbar_inner.update(1)
569
+
570
+ # Aggregate Train Metrics
571
+ if world_size > 1:
572
+ dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
573
+ dist.all_reduce(total_train_correct_certain, op=dist.ReduceOp.SUM)
574
+ dist.all_reduce(total_train_correct_standard, op=dist.ReduceOp.SUM)
575
+ dist.all_reduce(total_train_samples, op=dist.ReduceOp.SUM)
576
+
577
+ # Calculate final Train metrics on Rank 0
578
+ if is_main_process(rank) and total_train_samples > 0:
579
+ avg_train_loss = total_train_loss.item() / total_train_samples.item()
580
+ train_losses.append(avg_train_loss)
581
+ if args.model in ['ctm', 'lstm']:
582
+ avg_train_acc_certain = total_train_correct_certain.item() / total_train_samples.item()
583
+ train_accuracies_most_certain.append(avg_train_acc_certain)
584
+ elif args.model == 'ff':
585
+ avg_train_acc_standard = total_train_correct_standard.item() / total_train_samples.item()
586
+ train_accuracies_standard.append(avg_train_acc_standard)
587
+ print(f"Iter {bi} Train Metrics (Agg): Loss={avg_train_loss:.4f}")
588
+
589
+ # TEST METRICS
590
+ total_test_loss = torch.tensor(0.0, device=device)
591
+ total_test_correct_certain = torch.tensor(0.0, device=device) # CTM/LSTM
592
+ total_test_correct_standard = torch.tensor(0.0, device=device) # FF
593
+ total_test_samples = torch.tensor(0.0, device=device)
594
+
595
+ pbar_inner_desc = 'Eval Test (Rank 0)' if is_main_process(rank) else None
596
+ with tqdm(total=len(testloader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
597
+ for inferi, (inputs, targets) in enumerate(testloader): # Testloader already uses sampler
598
+ inputs = inputs.to(device, non_blocking=True)
599
+ targets = targets.to(device, non_blocking=True)
600
+
601
+ loss_eval = None
602
+ if args.model == 'ctm':
603
+ predictions, certainties, _ = model(inputs)
604
+ loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
605
+ preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
606
+ total_test_correct_certain += (preds_eval == targets).sum()
607
+ elif args.model == 'lstm':
608
+ predictions, certainties, _ = model(inputs)
609
+ loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
610
+ preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
611
+ total_test_correct_certain += (preds_eval == targets).sum()
612
+ elif args.model == 'ff':
613
+ predictions = model(inputs)
614
+ loss_eval = nn.CrossEntropyLoss()(predictions, targets)
615
+ preds_eval = predictions.argmax(1)
616
+ total_test_correct_standard += (preds_eval == targets).sum()
617
+
618
+ total_test_loss += loss_eval * inputs.size(0)
619
+ total_test_samples += inputs.size(0)
620
+
621
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
622
+ pbar_inner.update(1)
623
+
624
+ # Aggregate Test Metrics
625
+ if world_size > 1:
626
+ dist.all_reduce(total_test_loss, op=dist.ReduceOp.SUM)
627
+ dist.all_reduce(total_test_correct_certain, op=dist.ReduceOp.SUM)
628
+ dist.all_reduce(total_test_correct_standard, op=dist.ReduceOp.SUM)
629
+ dist.all_reduce(total_test_samples, op=dist.ReduceOp.SUM)
630
+
631
+ # Calculate and Plot final Test metrics on Rank 0
632
+ if is_main_process(rank) and total_test_samples > 0:
633
+ avg_test_loss = total_test_loss.item() / total_test_samples.item()
634
+ test_losses.append(avg_test_loss)
635
+ acc_label = ''
636
+ acc_val = 0.0
637
+ if args.model in ['ctm', 'lstm']:
638
+ avg_test_acc_certain = total_test_correct_certain.item() / total_test_samples.item()
639
+ test_accuracies_most_certain.append(avg_test_acc_certain)
640
+ acc_label = f'Most certain ({avg_test_acc_certain:.3f})'
641
+ acc_val = avg_test_acc_certain
642
+ elif args.model == 'ff':
643
+ avg_test_acc_standard = total_test_correct_standard.item() / total_test_samples.item()
644
+ test_accuracies_standard.append(avg_test_acc_standard)
645
+ acc_label = f'Standard Acc ({avg_test_acc_standard:.3f})'
646
+ acc_val = avg_test_acc_standard
647
+ print(f"Iter {bi} Test Metrics (Agg): Loss={avg_test_loss:.4f}, Acc={acc_val:.4f}\n")
648
+
649
+
650
+ # --- Plotting ---
651
+ figacc = plt.figure(figsize=(10, 10))
652
+ axacc_train = figacc.add_subplot(211)
653
+ axacc_test = figacc.add_subplot(212)
654
+
655
+ if args.model in ['ctm', 'lstm']:
656
+ axacc_train.plot(iters, train_accuracies_most_certain, 'k-', alpha=0.9, label=f'Most certain ({train_accuracies_most_certain[-1]:.3f})')
657
+ axacc_test.plot(iters, test_accuracies_most_certain, 'k-', alpha=0.9, label=acc_label)
658
+ elif args.model == 'ff':
659
+ axacc_train.plot(iters, train_accuracies_standard, 'k-', alpha=0.9, label=f'Standard Acc ({train_accuracies_standard[-1]:.3f})')
660
+ axacc_test.plot(iters, test_accuracies_standard, 'k-', alpha=0.9, label=acc_label)
661
+
662
+ axacc_train.set_title('Train Accuracy (Aggregated)')
663
+ axacc_test.set_title('Test Accuracy (Aggregated)')
664
+ axacc_train.legend(loc='lower right')
665
+ axacc_test.legend(loc='lower right')
666
+ axacc_train.set_xlim([0, args.training_iterations])
667
+ axacc_test.set_xlim([0, args.training_iterations])
668
+
669
+ # Keep dataset specific ylim adjustments if needed
670
+ if args.dataset == 'imagenet':
671
+ # For easy comparison when training
672
+ train_ylim_set = False
673
+ if args.model in ['ctm', 'lstm'] and len(train_accuracies_most_certain)>0 and np.any(np.array(train_accuracies_most_certain)>0.4): train_ylim_set=True; axacc_train.set_ylim([0.4, 1])
674
+ if args.model == 'ff' and len(train_accuracies_standard)>0 and np.any(np.array(train_accuracies_standard)>0.4): train_ylim_set=True; axacc_train.set_ylim([0.4, 1])
675
+
676
+ test_ylim_set = False
677
+ if args.model in ['ctm', 'lstm'] and len(test_accuracies_most_certain)>0 and np.any(np.array(test_accuracies_most_certain)>0.3): test_ylim_set=True; axacc_test.set_ylim([0.3, 0.8])
678
+ if args.model == 'ff' and len(test_accuracies_standard)>0 and np.any(np.array(test_accuracies_standard)>0.3): test_ylim_set=True; axacc_test.set_ylim([0.3, 0.8])
679
+
680
+
681
+ figacc.tight_layout()
682
+ figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
683
+ plt.close(figacc)
684
+
685
+ # Loss Plot
686
+ figloss = plt.figure(figsize=(10, 5))
687
+ axloss = figloss.add_subplot(111)
688
+ axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train (Aggregated): {train_losses[-1]:.4f}')
689
+ axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test (Aggregated): {test_losses[-1]:.4f}')
690
+ axloss.legend(loc='upper right')
691
+ axloss.set_xlabel("Iteration")
692
+ axloss.set_ylabel("Loss")
693
+ axloss.set_xlim([0, args.training_iterations])
694
+ axloss.set_ylim(bottom=0)
695
+ figloss.tight_layout()
696
+ figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
697
+ plt.close(figloss)
698
+ # --- End Plotting ---
699
+
700
+ # Visualization on Rank 0
701
+ if is_main_process(rank) and args.model in ['ctm', 'lstm']:
702
+ try:
703
+ model_module = model.module if isinstance(model, DDP) else model # Get underlying model
704
+ # Simplified viz: use first batch from testloader
705
+ inputs_viz, targets_viz = next(iter(testloader))
706
+ inputs_viz = inputs_viz.to(device)
707
+ targets_viz = targets_viz.to(device)
708
+
709
+ pbar.set_description('Tracking (Rank 0): Viz Fwd Pass')
710
+ predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model_module(inputs_viz, track=True)
711
+
712
+ att_shape = (model_module.kv_features.shape[2], model_module.kv_features.shape[3])
713
+ attention_tracking_viz = attention_tracking_viz.reshape(
714
+ attention_tracking_viz.shape[0],
715
+ attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
716
+
717
+
718
+ pbar.set_description('Tracking (Rank 0): Dynamics Plot')
719
+ plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
720
+
721
+ # Plot specific indices from test_data directly
722
+ pbar.set_description('Tracking (Rank 0): GIF Generation')
723
+ for plot_idx in args.plot_indices:
724
+ try:
725
+ if plot_idx < len(test_data):
726
+ inputs_plot, target_plot = test_data.__getitem__(plot_idx)
727
+ inputs_plot = inputs_plot.unsqueeze(0).to(device)
728
+
729
+ preds_plot, certs_plot, _, _, posts_plot, atts_plot = model_module(inputs_plot, track=True)
730
+ atts_plot = atts_plot.reshape(atts_plot.shape[0], atts_plot.shape[1], -1, att_shape[0], att_shape[1])
731
+
732
+
733
+ img_gif = np.moveaxis(np.clip(inputs_plot[0].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
734
+
735
+ make_classification_gif(img_gif, target_plot, preds_plot[0].detach().cpu().numpy(), certs_plot[0].detach().cpu().numpy(),
736
+ posts_plot[:,0], atts_plot[:,0] if atts_plot is not None else None, class_labels,
737
+ f'{args.log_dir}/idx{plot_idx}_attention.gif')
738
+ else:
739
+ print(f"Warning: Plot index {plot_idx} out of range for test dataset size {len(test_data)}.")
740
+ except Exception as e_gif:
741
+ print(f"Rank 0 GIF generation failed for index {plot_idx}: {e_gif}")
742
+
743
+ except Exception as e_viz:
744
+ print(f"Rank 0 visualization failed: {e_viz}")
745
+
746
+
747
+
748
+ if world_size > 1: dist.barrier() # Sync after evaluation block
749
+ model.train() # Set back to train mode
750
+ # --- End Evaluation Block ---
751
+
752
+
753
+ # --- Checkpointing (Rank 0) ---
754
+ if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter and is_main_process(rank):
755
+ pbar.set_description('Rank 0: Saving checkpoint...')
756
+ save_path = f'{args.log_dir}/checkpoint.pt'
757
+ # Access underlying model state dict if DDP is used
758
+ model_state_to_save = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
759
+
760
+ save_dict = {
761
+ 'model_state_dict': model_state_to_save,
762
+ 'optimizer_state_dict': optimizer.state_dict(),
763
+ 'scheduler_state_dict': scheduler.state_dict(),
764
+ 'scaler_state_dict':scaler.state_dict(),
765
+ 'iteration': bi,
766
+ 'train_losses': train_losses,
767
+ 'test_losses': test_losses,
768
+ 'iters': iters,
769
+ 'args': args,
770
+ 'torch_rng_state': torch.get_rng_state(), # CPU state
771
+ 'numpy_rng_state': np.random.get_state(),
772
+ 'random_rng_state': random.getstate(),
773
+ # Include conditional metrics
774
+ 'train_accuracies': train_accuracies, # Placeholder
775
+ 'test_accuracies': test_accuracies, # Placeholder
776
+ }
777
+ if args.model in ['ctm', 'lstm']:
778
+ save_dict['train_accuracies_most_certain'] = train_accuracies_most_certain
779
+ save_dict['test_accuracies_most_certain'] = test_accuracies_most_certain
780
+ elif args.model == 'ff':
781
+ save_dict['train_accuracies_standard'] = train_accuracies_standard
782
+ save_dict['test_accuracies_standard'] = test_accuracies_standard
783
+
784
+ torch.save(save_dict , save_path)
785
+ pbar.set_description(f"Rank 0: Checkpoint saved to {save_path}")
786
+ # --- End Checkpointing ---
787
+
788
+
789
+ if world_size > 1: dist.barrier() # Sync before next iteration
790
+
791
+ # Update pbar on Rank 0
792
+ if is_main_process(rank):
793
+ pbar.update(1)
794
+ # --- End Training Loop ---
795
+
796
+ if is_main_process(rank):
797
+ pbar.close()
798
+
799
+ cleanup_ddp() # Cleanup DDP resources
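As an aside, the evaluation block above follows a standard distributed pattern: each rank accumulates local sums, the sums are all-reduced, and the final metric is computed from the totals. A minimal sketch of that pattern, assuming `torch.distributed` is already initialised; the function and tensor names are illustrative and not part of this repository:
```
import torch
import torch.distributed as dist

def aggregate_mean(total: torch.Tensor, count: torch.Tensor) -> float:
    """Sum per-rank partial sums across ranks, then divide once (illustrative names)."""
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        dist.all_reduce(total, op=dist.ReduceOp.SUM)
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return (total / count).item()

# e.g. accuracy = aggregate_mean(correct_predictions_sum, num_samples)
```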
tasks/mazes/README.md ADDED
@@ -0,0 +1,10 @@
1
+ # Mazes
2
+
3
+ This folder contains code for training and analysing the 2D maze-solving experiments.
4
+
5
+
6
+ ## Training
7
+ To run the maze training that we used for the paper, run the following command from the parent directory:
8
+ ```
9
+ python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50
10
+ ```
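For reference, the output head size implied by this command follows from the route length and the five per-step actions used by the maze task (see `tasks/mazes/train.py`). A purely illustrative sketch of that arithmetic:
```
# Each route step is one of 5 actions (up, down, left, right, wait), so a route
# of length L needs L * 5 logits, later reshaped to (L, 5) for the loss.
maze_route_length = 100                       # default, also set in train_ctm.sh
prediction_reshaper = [maze_route_length, 5]  # reshape target: (route length, actions)
out_dims = maze_route_length * 5              # 500 logits per maze
```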
tasks/mazes/analysis/README.md ADDED
@@ -0,0 +1,10 @@
1
+ # Analysis
2
+
3
+ This folder contains analysis code for the 2D maze experiments.
4
+
5
+ To run the maze analysis, run the following command from the parent directory:
6
+ ```
7
+ python -m tasks.mazes.analysis.run --actions viz --checkpoint checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt
8
+ ```
9
+
10
+ You will need to download the checkpoint from here: https://drive.google.com/file/d/1vGiMaQCxzKVT68SipxDCW0W5n5jjEQnC/view?usp=drive_link . Extract it to the appropriate directory: `checkpoints/mazes/...`. Otherwise, use your own checkpoint after training.
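For orientation, the analysis script below decodes the CTM's maze predictions by picking the most-certain internal tick and taking a per-step argmax over the five actions. A minimal sketch of that decoding, with assumed tensor shapes and an illustrative function name (not part of the repo):
```
# predictions: (batch, route_len * 5, internal_ticks); certainties: (batch, 2, internal_ticks)
import torch

def decode_route(predictions: torch.Tensor, certainties: torch.Tensor, idx: int = 0) -> torch.Tensor:
    most_certain_tick = certainties[idx, 1].argmax().item()
    step_logits = predictions[idx, :, most_certain_tick].reshape(-1, 5)
    return step_logits.argmax(-1)  # per-step actions: 0=up, 1=down, 2=left, 3=right, 4=stay
```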
tasks/mazes/analysis/run.py ADDED
@@ -0,0 +1,407 @@
1
+ import torch
2
+ import numpy as np
3
+ np.seterr(divide='ignore', invalid='warn') # Keep specific numpy error settings
4
+ import matplotlib as mpl
5
+ mpl.use('Agg') # Use Agg backend for matplotlib (important to set before importing pyplot)
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ sns.set_style('darkgrid') # Keep seaborn style
9
+ import os
10
+ import argparse
11
+ import cv2
12
+ import imageio # Used for saving GIFs in viz
13
+
14
+ # Local imports
15
+ from data.custom_datasets import MazeImageFolder
16
+ from models.ctm import ContinuousThoughtMachine
17
+ from tasks.mazes.plotting import draw_path
18
+ from tasks.image_classification.plotting import save_frames_to_mp4
19
+
20
+ def has_solved_checker(x_maze, route, valid_only=True, fault_tolerance=1, exclusions=[]):
21
+ """Checks if a route solves a maze."""
22
+ maze = np.copy(x_maze)
23
+ H, W, _ = maze.shape
24
+ start_coords = np.argwhere((maze == [1, 0, 0]).all(axis=2))
25
+ end_coords = np.argwhere((maze == [0, 1, 0]).all(axis=2))
26
+
27
+ if len(start_coords) == 0:
28
+ return False, (-1, -1), 0 # Cannot start
29
+
30
+ current_pos = tuple(start_coords[0])
31
+ target_pos = tuple(end_coords[0]) if len(end_coords) > 0 else None
32
+
33
+ mistakes_made = 0
34
+ final_pos = current_pos
35
+ path_taken_len = 0
36
+
37
+ for step in route:
38
+ if mistakes_made > fault_tolerance:
39
+ break
40
+
41
+ next_pos_candidate = list(current_pos) # Use a list for mutable coordinate calculation
42
+ if step == 0: next_pos_candidate[0] -= 1
43
+ elif step == 1: next_pos_candidate[0] += 1
44
+ elif step == 2: next_pos_candidate[1] -= 1
45
+ elif step == 3: next_pos_candidate[1] += 1
46
+ elif step == 4: pass # Stay in place
47
+ else: continue # Invalid step action
48
+ next_pos = tuple(next_pos_candidate)
49
+
50
+
51
+ is_invalid_step = False
52
+ # Check bounds first, then maze content if in bounds
53
+ if not (0 <= next_pos[0] < H and 0 <= next_pos[1] < W):
54
+ is_invalid_step = True
55
+ elif np.all(maze[next_pos] == [0, 0, 0]): # Wall
56
+ is_invalid_step = True
57
+
58
+ if is_invalid_step:
59
+ mistakes_made += 1
60
+ if valid_only:
61
+ continue
62
+
63
+ current_pos = next_pos
64
+ path_taken_len += 1
65
+
66
+ if target_pos and current_pos == target_pos:
67
+ if mistakes_made <= fault_tolerance:
68
+ return True, current_pos, path_taken_len
69
+
70
+ if mistakes_made <= fault_tolerance:
71
+ # Assuming exclusions is a list of tuples (as populated in the 'gen' action)
72
+ if current_pos not in exclusions:
73
+ final_pos = current_pos
74
+
75
+ if target_pos and final_pos == target_pos and mistakes_made <= fault_tolerance: # Added mistakes_made check here
76
+ return True, final_pos, path_taken_len
77
+ return False, final_pos, path_taken_len
78
+
79
+
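# Editorial usage sketch, not part of the original run.py: exercising
# has_solved_checker on a toy 1x3 corridor. Red [1, 0, 0] marks the start,
# green [0, 1, 0] the goal, white [1, 1, 1] open path; route steps are
# 0=up, 1=down, 2=left, 3=right, 4=stay.
toy_maze = np.array([[[1, 0, 0], [1, 1, 1], [0, 1, 0]]], dtype=float)  # (H=1, W=3, 3)
solved, final_pos, steps = has_solved_checker(toy_maze, route=[3, 3])
assert solved and tuple(final_pos) == (0, 2) and steps == 2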
80
+ def parse_args():
81
+ """Parses command-line arguments for maze analysis."""
82
+ parser = argparse.ArgumentParser(description="Analyze the Continuous Thought Machine on Maze Tasks")
83
+ parser.add_argument('--actions', type=str, nargs='+', default=['gen'], help="Actions: 'viz', 'gen'")
84
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
85
+ parser.add_argument('--checkpoint', type=str, default='checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt', help="Path to CTM checkpoint")
86
+ parser.add_argument('--output_dir', type=str, default='tasks/mazes/analysis/outputs', help="Directory for analysis outputs")
87
+ parser.add_argument('--dataset_for_viz', type=str, default='large', help="Dataset for 'viz' action")
88
+ parser.add_argument('--dataset_for_gen', type=str, default='extralarge', help="Dataset for 'gen' action")
89
+ parser.add_argument('--batch_size_test', type=int, default=32, help="Batch size for loading test data for 'viz'")
90
+ parser.add_argument('--max_reapplications', type=int, default=20, help="When testing generalisation to extra large mazes")
91
+ parser.add_argument('--legacy_scaling', action=argparse.BooleanOptionalAction, default=True, help='Legacy checkpoints scale between 0 and 1, new ones can scale -1 to 1.')
92
+ return parser.parse_args()
93
+
94
+ def _load_ctm_model(checkpoint_path, device):
95
+ """Loads the ContinuousThoughtMachine model from a checkpoint."""
96
+ print(f"Loading checkpoint: {checkpoint_path}")
97
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
98
+ model_args = checkpoint['args']
99
+
100
+ # Handle legacy arguments for model_args
101
+ if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
102
+ model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
103
+
104
+ # Ensure prediction_reshaper is derived correctly
105
+ # Assuming out_dims exists and is used for this
106
+ prediction_reshaper = [model_args.out_dims // 5, 5] if hasattr(model_args, 'out_dims') else None
107
+
108
+
109
+ if not hasattr(model_args, 'neuron_select_type'):
110
+ model_args.neuron_select_type = 'first-last'
111
+ if not hasattr(model_args, 'n_random_pairing_self'):
112
+ model_args.n_random_pairing_self = 0
113
+
114
+ print("Instantiating CTM model...")
115
+ model = ContinuousThoughtMachine(
116
+ iterations=model_args.iterations,
117
+ d_model=model_args.d_model,
118
+ d_input=model_args.d_input,
119
+ heads=model_args.heads,
120
+ n_synch_out=model_args.n_synch_out,
121
+ n_synch_action=model_args.n_synch_action,
122
+ synapse_depth=model_args.synapse_depth,
123
+ memory_length=model_args.memory_length,
124
+ deep_nlms=model_args.deep_memory, # Mapping from model_args.deep_memory
125
+ memory_hidden_dims=model_args.memory_hidden_dims,
126
+ do_layernorm_nlm=model_args.do_normalisation, # Mapping from model_args.do_normalisation
127
+ backbone_type=model_args.backbone_type,
128
+ positional_embedding_type=model_args.positional_embedding_type,
129
+ out_dims=model_args.out_dims,
130
+ prediction_reshaper=prediction_reshaper,
131
+ dropout=0, # Explicitly setting dropout to 0 as in original
132
+ neuron_select_type=model_args.neuron_select_type,
133
+ n_random_pairing_self=model_args.n_random_pairing_self,
134
+ ).to(device)
135
+
136
+ load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
137
+ print(f"Loaded state_dict. Missing keys: {load_result.missing_keys}, Unexpected keys: {load_result.unexpected_keys}")
138
+ model.eval()
139
+ return model
140
+
141
+ # --- Main Execution Block ---
142
+ if __name__=='__main__':
143
+ args = parse_args()
144
+
145
+ if args.device[0] != -1 and torch.cuda.is_available():
146
+ device = f'cuda:{args.device[0]}'
147
+ else:
148
+ device = 'cpu'
149
+ print(f"Using device: {device}")
150
+
151
+ palette = sns.color_palette("husl", 8)
152
+ cmap = plt.get_cmap('gist_rainbow')
153
+
154
+ # --- Generalisation Action ('gen') ---
155
+ if 'gen' in args.actions:
156
+ model = _load_ctm_model(args.checkpoint, device)
157
+
158
+ print(f"\n--- Running Generalisation Analysis ('gen'): {args.dataset_for_gen} ---")
159
+ target_dataset_name = f'{args.dataset_for_gen}'
160
+ data_root = f'data/mazes/{target_dataset_name}/test'
161
+ max_target_route_len = 50 # Specific to 'gen' action
162
+
163
+ test_data = MazeImageFolder(
164
+ root=data_root, which_set='test',
165
+ maze_route_length=max_target_route_len,
166
+ expand_range=not args.legacy_scaling, # Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
167
+ trunc=True
168
+ )
169
+ # Load a single large batch for 'gen'
170
+ testloader = torch.utils.data.DataLoader(
171
+ test_data, batch_size=min(len(test_data), 2000),
172
+ shuffle=False, num_workers=1
173
+ )
174
+ inputs, targets = next(iter(testloader))
175
+
176
+ actual_lengths = (targets != 4).sum(dim=-1)
177
+ sorted_indices = torch.argsort(actual_lengths, descending=True)
178
+ inputs, targets, actual_lengths = inputs[sorted_indices], targets[sorted_indices], actual_lengths[sorted_indices]
179
+
180
+ test_how_many = min(1000, len(inputs))
181
+ print(f"Processing {test_how_many} mazes sorted by length...")
182
+
183
+ results = {}
184
+ fault_tolerance = 2 # Specific to 'gen' analysis
185
+ output_gen_dir = os.path.join(args.output_dir, 'gen', args.dataset_for_gen)
186
+ os.makedirs(output_gen_dir, exist_ok=True)
187
+
188
+ for n_tested in range(test_how_many):
189
+ maze_actual_length = actual_lengths[n_tested].item()
190
+ maze_idx_display = n_tested + 1
191
+ print(f"Testing maze {maze_idx_display}/{test_how_many} (Len: {maze_actual_length})...")
192
+
193
+ initial_input_maze = inputs[n_tested:n_tested+1].clone().to(device)
194
+ maze_output_dir = os.path.join(output_gen_dir, f"maze_{maze_idx_display}")
195
+
196
+ re_applications = 0
197
+ has_solved = False
198
+ current_input_maze = initial_input_maze
199
+ exclusions = []
200
+ long_frames = []
201
+ ongoing_solution_img = None
202
+
203
+ while not has_solved and re_applications < args.max_reapplications:
204
+ re_applications += 1
205
+ with torch.no_grad():
206
+ predictions, certainties, _, _, _, attention_tracking = model(current_input_maze, track=True)
207
+
208
+ h_feat, w_feat = model.kv_features.shape[-2:]
209
+ attention_tracking = attention_tracking.reshape(attention_tracking.shape[0], -1, h_feat, w_feat)
210
+
211
+ n_steps_viz = predictions.shape[-1] # Use a different name to avoid conflict if n_steps is used elsewhere
212
+ step_linspace = np.linspace(0, 1, n_steps_viz)
213
+ current_maze_np = current_input_maze[0].permute(1,2,0).detach().cpu().numpy()
214
+
215
+ for stepi in range(n_steps_viz):
216
+ pred_route = predictions[0, :, stepi].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
217
+ frame = draw_path(current_maze_np, pred_route)
218
+ if attention_tracking is not None and stepi < attention_tracking.shape[0]:
219
+ try:
220
+ attn = attention_tracking[stepi].mean(0)
221
+ attn_resized = cv2.resize(attn, (current_maze_np.shape[1], current_maze_np.shape[0]), interpolation=cv2.INTER_LINEAR)
222
+ if attn_resized.max() > attn_resized.min():
223
+ attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min())
224
+ attn_norm[attn_norm < np.percentile(attn_norm, 80)] = 0.0
225
+ frame = np.clip((np.copy(frame)*(1-attn_norm[:,:,np.newaxis])*1 + (attn_norm[:,:,np.newaxis]*0.8 * np.reshape(np.array(cmap(step_linspace[stepi]))[:3], (1, 1, 3)))), 0, 1)
226
+ except Exception: # Keep broad except for visualization robustness
227
+ pass
228
+ frame_resized = cv2.resize(frame, (int(current_maze_np.shape[1]*4), int(current_maze_np.shape[0]*4)), interpolation=cv2.INTER_NEAREST) # cv2.resize expects (width, height)
229
+ long_frames.append((np.clip(frame_resized, 0, 1) * 255).astype(np.uint8))
230
+
231
+ where_most_certain = certainties[0, 1].argmax().item()
232
+ chosen_pred_route = predictions[0, :, where_most_certain].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
233
+ current_start_loc_list = np.argwhere((current_maze_np == [1, 0, 0]).all(axis=2)).tolist()
234
+
235
+ # Ensure current_start_loc_list is not empty before trying to access its elements
236
+ if not current_start_loc_list:
237
+ print(f"Warning: Could not find start location in maze {maze_idx_display} during reapplication {re_applications}. Stopping reapplication.")
238
+ break # Cannot proceed without a start location
239
+
240
+ solved_now, final_pos, _ = has_solved_checker(current_maze_np, chosen_pred_route, True, fault_tolerance, exclusions)
241
+
242
+ path_img = draw_path(current_maze_np, chosen_pred_route, cmap=cmap, valid_only=True)
243
+ if ongoing_solution_img is None:
244
+ ongoing_solution_img = path_img
245
+ else:
246
+ mask = (np.any(ongoing_solution_img!=path_img, -1))&(~np.all(path_img==[1,1,1], -1))&(~np.all(ongoing_solution_img==[1,0,0], -1))
247
+ ongoing_solution_img[mask] = path_img[mask]
248
+
249
+ if solved_now:
250
+ has_solved = True
251
+ break
252
+
253
+ if tuple(current_start_loc_list[0]) == final_pos:
254
+ exclusions.append(tuple(current_start_loc_list[0]))
255
+
256
+ next_input = current_input_maze.clone()
257
+ old_start_idx = tuple(current_start_loc_list[0])
258
+ next_input[0, :, old_start_idx[0], old_start_idx[1]] = 1.0 # Reset old start to path
259
+
260
+ if 0 <= final_pos[0] < next_input.shape[2] and 0 <= final_pos[1] < next_input.shape[3]:
261
+ next_input[0, :, final_pos[0], final_pos[1]] = torch.tensor([1,0,0], device=device, dtype=next_input.dtype) # New start
262
+ else:
263
+ print(f"Warning: final_pos {final_pos} out of bounds for maze {maze_idx_display}. Stopping reapplication.")
264
+ break
265
+ current_input_maze = next_input
266
+
267
+ if has_solved:
268
+ print(f'Solved maze of length {maze_actual_length}! Saving...')
269
+ os.makedirs(maze_output_dir, exist_ok=True)
270
+ if ongoing_solution_img is not None:
271
+ cv2.imwrite(os.path.join(maze_output_dir, 'ongoing_solution.png'), (ongoing_solution_img * 255).astype(np.uint8)[:,:,::-1])
272
+ if long_frames:
273
+ save_frames_to_mp4([fm[:,:,::-1] for fm in long_frames], os.path.join(maze_output_dir, f'combined_process.mp4'), fps=45, gop_size=10, preset='veryslow', crf=20)
274
+ else:
275
+ print(f'Failed maze of length {maze_actual_length} after {re_applications} reapplications. Not saving visuals for this maze.')
276
+
277
+ if maze_actual_length not in results: results[maze_actual_length] = []
278
+ results[maze_actual_length].append((has_solved, re_applications))
279
+
280
+ fig_success, ax_success = plt.subplots()
281
+ fig_reapp, ax_reapp = plt.subplots()
282
+ sorted_lengths = sorted(results.keys())
283
+ if sorted_lengths:
284
+ success_rates = [np.mean([r[0] for r in results[l]]) * 100 for l in sorted_lengths]
285
+ reapps_mean = [np.mean([r[1] for r in results[l] if r[0]]) if any(r[0] for r in results[l]) else np.nan for l in sorted_lengths]
286
+ ax_success.plot(sorted_lengths, success_rates, linestyle='-', color=palette[0])
287
+ ax_reapp.plot(sorted_lengths, reapps_mean, linestyle='-', color=palette[5])
288
+ ax_success.set_xlabel('Route Length'); ax_success.set_ylabel('Success (%)')
289
+ ax_reapp.set_xlabel('Route Length'); ax_reapp.set_ylabel('Re-applications (Avg on Success)')
290
+ fig_success.tight_layout(pad=0.1); fig_reapp.tight_layout(pad=0.1)
291
+ fig_success.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-success_rate.png'), dpi=200)
292
+ fig_success.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-success_rate.pdf'), dpi=200)
293
+ fig_reapp.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-re-applications.png'), dpi=200)
294
+ fig_reapp.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-re-applications.pdf'), dpi=200)
295
+ plt.close(fig_success); plt.close(fig_reapp)
296
+ np.savez(os.path.join(output_gen_dir, f'{args.dataset_for_gen}_results.npz'), results=results)
297
+
298
+ print("\n--- Generalisation Analysis ('gen') Complete ---")
299
+
300
+ # --- Visualization Action ('viz') ---
301
+ if 'viz' in args.actions:
302
+ model = _load_ctm_model(args.checkpoint, device)
303
+
304
+ print(f"\n--- Running Visualization ('viz'): {args.dataset_for_viz} ---")
305
+ output_viz_dir = os.path.join(args.output_dir, 'viz')
306
+ os.makedirs(output_viz_dir, exist_ok=True)
307
+
308
+ target_dataset_name = f'{args.dataset_for_viz}'
309
+ data_root = f'data/mazes/{target_dataset_name}/test'
310
+ test_data = MazeImageFolder(
311
+ root=data_root, which_set='test',
312
+ maze_route_length=100, # Max route length for viz data
313
+ expand_range=not args.legacy_scaling, # Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
314
+ trunc=True
315
+ )
316
+ testloader = torch.utils.data.DataLoader(
317
+ test_data, batch_size=args.batch_size_test,
318
+ shuffle=False, num_workers=1
319
+ )
320
+
321
+ all_inputs, all_targets, all_lengths = [], [], []
322
+ for b_in, b_tgt in testloader:
323
+ all_inputs.append(b_in)
324
+ all_targets.append(b_tgt)
325
+ all_lengths.append((b_tgt != 4).sum(dim=-1))
326
+
327
+ if not all_inputs:
328
+ print("Error: No data in visualization loader. Exiting 'viz' action.")
329
+ exit()
330
+
331
+ all_inputs, all_targets, all_lengths = torch.cat(all_inputs), torch.cat(all_targets), torch.cat(all_lengths)
332
+
333
+ num_viz_mazes = 10
334
+ num_viz_mazes = min(num_viz_mazes, len(all_lengths))
335
+
336
+ if num_viz_mazes == 0:
337
+ print("Error: No mazes found to visualize. Exiting 'viz' action.")
338
+ exit()
339
+
340
+ top_indices = torch.argsort(all_lengths, descending=True)[:num_viz_mazes]
341
+ inputs_viz, targets_viz = all_inputs[top_indices].to(device), all_targets[top_indices]
342
+
343
+ print(f"Visualizing {len(inputs_viz)} longest mazes...")
344
+
345
+ with torch.no_grad():
346
+ predictions, _, _, _, _, attention_tracking = model(inputs_viz, track=True)
347
+
348
+ # Reshape attention: (Steps, Batch, Heads, H_feat, W_feat) assuming model.kv_features has H_feat, W_feat
349
+ # The original reshape was slightly different, this tries to match the likely intended dimensions for per-step, per-batch item attention
350
+ if attention_tracking is not None and hasattr(model, 'kv_features') and model.kv_features is not None:
351
+ attention_tracking = attention_tracking.reshape(
352
+ attention_tracking.shape[0], # Iterations/Steps
353
+ inputs_viz.size(0), # Batch size (num_viz_mazes)
354
+ -1, # Heads (inferred)
355
+ model.kv_features.shape[-2], # H_feat
356
+ model.kv_features.shape[-1] # W_feat
357
+ )
358
+ else:
359
+ attention_tracking = None # Ensure it's None if it can't be reshaped
360
+ print("Warning: Could not reshape attention_tracking. Visualizations may not include attention overlays.")
361
+
362
+
363
+ for maze_i in range(inputs_viz.size(0)):
364
+ maze_idx_display = maze_i + 1
365
+ maze_output_dir = os.path.join(output_viz_dir, f"maze_{maze_idx_display}")
366
+ os.makedirs(maze_output_dir, exist_ok=True)
367
+
368
+ current_input_np_original = inputs_viz[maze_i].permute(1,2,0).detach().cpu().numpy()
369
+ # Apply scaling for visualization based on legacy_scaling: Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
370
+ current_input_np_display = (current_input_np_original + 1) / 2 if not args.legacy_scaling else current_input_np_original
371
+
372
+ current_target_route = targets_viz[maze_i].detach().cpu().numpy()
373
+ print(f"Generating viz for maze {maze_idx_display}...")
374
+
375
+ try:
376
+ solution_maze_img = draw_path(current_input_np_display, current_target_route, gt=True)
377
+ cv2.imwrite(os.path.join(maze_output_dir, 'solution_ground_truth.png'), (solution_maze_img * 255).astype(np.uint8)[:,:,::-1])
378
+ except Exception: # Keep broad except for visualization robustness
379
+ print(f"Could not save ground truth solution for maze {maze_idx_display}")
380
+ pass
381
+
382
+ frames = []
383
+ n_steps_viz = predictions.shape[-1] # Use a different name
384
+ step_linspace = np.linspace(0, 1, n_steps_viz)
385
+
386
+ for stepi in range(n_steps_viz):
387
+ pred_route = predictions[maze_i, :, stepi].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
388
+ frame = draw_path(current_input_np_display, pred_route)
389
+
390
+ if attention_tracking is not None and stepi < attention_tracking.shape[0] and maze_i < attention_tracking.shape[1]:
391
+
392
+ # Attention for current step (stepi) and current maze in batch (maze_i), average over heads
393
+ attn = attention_tracking[stepi, maze_i].mean(0)
394
+ attn_resized = cv2.resize(attn, (current_input_np_display.shape[1], current_input_np_display.shape[0]), interpolation=cv2.INTER_LINEAR)
395
+ if attn_resized.max() > attn_resized.min():
396
+ attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min())
397
+ attn_norm[attn_norm < np.percentile(attn_norm, 80)] = 0.0
398
+ frame = np.clip((np.copy(frame)*(1-attn_norm[:,:,np.newaxis])*0.9 + (attn_norm[:,:,np.newaxis]*1.2 * np.reshape(np.array(cmap(step_linspace[stepi]))[:3], (1, 1, 3)))), 0, 1)
399
+
400
+
401
+ frame_resized = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_NEAREST)
402
+ frames.append((np.clip(frame_resized, 0, 1) * 255).astype(np.uint8))
403
+
404
+ if frames:
405
+ imageio.mimsave(os.path.join(maze_output_dir, 'attention_overlay.gif'), frames, fps=15, loop=0)
406
+
407
+ print("\n--- Visualization Action ('viz') Complete ---")
tasks/mazes/plotting.py ADDED
@@ -0,0 +1,198 @@
1
+
2
+ import numpy as np
3
+ import cv2
4
+ import torch
5
+ import os
6
+ import matplotlib.pyplot as plt
7
+ import imageio
8
+
9
+ from tqdm.auto import tqdm
10
+
11
+ def find_center_of_mass(array_2d):
12
+ """
13
+ Computes the center of mass of a 2D array using np.average and meshgrid.
14
+ Values are treated as weights; if the total mass is zero, (nan, nan) is returned.
15
+
16
+ Args:
17
+ array_2d: A 2D numpy array of values between 0 and 1.
18
+
19
+ Returns:
20
+ A tuple (y, x) representing the coordinates of the center of mass.
21
+ """
22
+ total_mass = np.sum(array_2d)
23
+ if total_mass == 0:
24
+ return (np.nan, np.nan)
25
+
26
+ y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
27
+ x_center = np.average(x_coords, weights=array_2d)
28
+ y_center = np.average(y_coords, weights=array_2d)
29
+ return (round(y_center, 4), round(x_center, 4))
30
+
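# Editorial usage sketch, not part of the original plotting.py: a worked example of
# find_center_of_mass. All of the mass sits at row 0, column 1, so the centre of
# mass is returned as (y, x) = (0.0, 1.0).
example_mass = np.array([[0.0, 1.0],
                         [0.0, 0.0]])
assert find_center_of_mass(example_mass) == (0.0, 1.0)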
31
+ def draw_path(x, route, valid_only=False, gt=False, cmap=None):
32
+ """
33
+ Draws a path on a maze image based on a given route.
34
+
35
+ Args:
36
+ x: A numpy array (H, W, 3) representing the maze image, with the start in red and the goal in green.
37
+ route: A list of integers representing the route, where 0 is up, 1 is down, 2 is left, 3 is right, and 4 is stay in place.
38
+ valid_only: A boolean indicating whether to only draw valid steps (i.e., steps that don't go into walls).
39
+
40
+ Returns:
41
+ A numpy array representing the maze image with the path drawn using the given colormap.
42
+ """
43
+ x = np.copy(x)
44
+ start = np.argwhere((x == [1, 0, 0]).all(axis=2))
45
+ end = np.argwhere((x == [0, 1, 0]).all(axis=2))
46
+ if cmap is None:
47
+ cmap = plt.get_cmap('winter') if not valid_only else plt.get_cmap('summer')
48
+
49
+ # Initialize the current position
50
+ current_pos = start[0]
51
+
52
+ # Draw the path
53
+ colors = cmap(np.linspace(0, 1, len(route)))
54
+ si = 0
55
+ for step in route:
56
+ new_pos = current_pos
57
+ if step == 0: # Up
58
+ new_pos = (current_pos[0] - 1, current_pos[1])
59
+ elif step == 1: # Down
60
+ new_pos = (current_pos[0] + 1, current_pos[1])
61
+ elif step == 2: # Left
62
+ new_pos = (current_pos[0], current_pos[1] - 1)
63
+ elif step == 3: # Right
64
+ new_pos = (current_pos[0], current_pos[1] + 1)
65
+ elif step == 4: # Do nothing
66
+ pass
67
+ else:
68
+ raise ValueError("Invalid step: {}".format(step))
69
+
70
+ # Check if the new position is valid
71
+ if valid_only:
72
+ try:
73
+ if np.all(x[new_pos] == [0,0,0]): # Check if it's a wall
74
+ continue # Skip this step if it's invalid
75
+ except IndexError:
76
+ continue # Skip this step if it's out of bounds
77
+
78
+ # Draw the step
79
+ if new_pos[0] >= 0 and new_pos[0] < x.shape[0] and new_pos[1] >= 0 and new_pos[1] < x.shape[1]:
80
+ if not ((x[new_pos] == [1,0,0]).all() or (x[new_pos] == [0,1,0]).all()):
81
+ colour = colors[si][:3]
82
+ si += 1
83
+ x[new_pos] = x[new_pos]*0.5 + colour*0.5
84
+
85
+ # Update the current position
86
+ current_pos = new_pos
87
+ # cv2.imwrite('maze2.png', x[:,:,::-1]*255)
88
+
89
+ return x
90
+
91
+ def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location):
92
+ """
93
+ Expect inputs, predictions, targets as numpy arrays
94
+ """
95
+ route_steps = []
96
+ route_colours = []
97
+ solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets)
98
+
99
+ # cv2.imwrite(f'{save_location}/ground_truth.png', solution_maze[:,:,::-1]*255)
100
+ mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
101
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
102
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
103
+ ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
104
+ ['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
105
+ ['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'],
106
+ ]
107
+ img_aspect = 1
108
+ figscale = 1
109
+ aspect_ratio = (8 * figscale, 6 * figscale * img_aspect) # W, H
110
+
111
+ route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point
112
+ frames = []
113
+ cmap = plt.get_cmap('gist_rainbow')
114
+ cmap_viridis = plt.get_cmap('viridis')
115
+ step_linspace = np.linspace(0, 1, predictions.shape[-1]) # For sampling colours
116
+ with tqdm(total=predictions.shape[-1], initial=0, leave=True, position=1, dynamic_ncols=True) as pbar:
117
+ pbar.set_description('Processing frames for maze plotting')
118
+ for stepi in np.arange(0, predictions.shape[-1], 1):
119
+ fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
120
+ for ax in axes.values():
121
+ ax.axis('off')
122
+ guess_maze = draw_path(np.moveaxis(inputs, 0, -1), predictions.argmax(1)[:,stepi], cmap=cmap)
123
+ attention_now = attention_tracking[stepi]
124
+ for hi in range(min((attention_tracking.shape[1], 16))):
125
+ ax = axes[f'head_{hi}']
126
+ attn = attention_tracking[stepi, hi]
127
+ attn = (attn - attn.min())/(np.ptp(attn))
128
+ ax.imshow(attn, cmap=cmap_viridis)
129
+ # Upsample attention just for visualisation
130
+ aggregated_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy()
131
+
132
+ # Get approximate center of mass
133
+ com_attn = np.copy(aggregated_attention)
134
+ com_attn[com_attn < np.percentile(com_attn, 96)] = 0.0
135
+ aggregated_attention[aggregated_attention < np.percentile(aggregated_attention, 80)] = 0.0
136
+ route_steps.append(find_center_of_mass(com_attn))
137
+
138
+
139
+ colour = list(cmap(step_linspace[stepi]))
140
+ route_colours.append(colour)
141
+
142
+ mapped_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy()
143
+ mapped_attention = (mapped_attention - mapped_attention.min())/np.ptp(mapped_attention)
144
+ # np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.5) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.3, 0, 1)
145
+ overlay_img = np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.6) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.1, 0, 1)#np.clip((np.copy(guess_maze)*(1-aggregated_attention[:,:,np.newaxis])*0.7 + (aggregated_attention[:,:,np.newaxis]*3 * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1)
146
+ axes['overlay'].imshow(overlay_img)
147
+
148
+ y_coords, x_coords = zip(*route_steps)
149
+ y_coords = inputs.shape[-1] - np.array(list(y_coords))-1
150
+
151
+
152
+ axes['route'].imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower')
153
+ # ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
154
+ arrow_scale = 2
155
+ for i in range(len(route_steps)-1):
156
+ dx = x_coords[i+1] - x_coords[i]
157
+ dy = y_coords[i+1] - y_coords[i]
158
+ axes['route'].arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
159
+
160
+ fig.tight_layout(pad=0.1) # Adjust spacing
161
+
162
+ # Render the plot to a numpy array
163
+ canvas = fig.canvas
164
+ canvas.draw()
165
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
166
+ image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
167
+
168
+ frames.append(image_numpy) # Add to list for GIF
169
+
170
+ # fig.savefig(f'{save_location}/frame.png', dpi=200)
171
+
172
+ plt.close(fig)
173
+
174
+ # # frame = np.clip((np.copy(guess_maze)*0.5 + (aggregated_attention[:,:,np.newaxis] * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1)
175
+ # frame = torch.nn.functional.interpolate(torch.from_numpy(frame).permute(2,0,1).unsqueeze(0), 256)[0].permute(1,2,0).detach().cpu().numpy()
176
+ # frames.append((frame*255).astype(np.uint8))
177
+ pbar.update(1)
178
+
179
+
180
+ y_coords, x_coords = zip(*route_steps)
181
+ y_coords = inputs.shape[-1] - np.array(list(y_coords))-1
182
+
183
+ fig = plt.figure(figsize=(5,5))
184
+ ax = fig.add_subplot(111)
185
+
186
+ ax.imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower')
187
+ # ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
188
+ arrow_scale = 2
189
+ for i in range(len(route_steps)-1):
190
+ dx = x_coords[i+1] - x_coords[i]
191
+ dy = y_coords[i+1] - y_coords[i]
192
+ plt.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
193
+
194
+ ax.axis('off')
195
+ fig.tight_layout(pad=0)
196
+ fig.savefig(f'{save_location}/route_approximation.png', dpi=200)
197
+ imageio.mimsave(f'{save_location}/prediction.gif', frames, fps=15, loop=100)
198
+ plt.close(fig)
tasks/mazes/scripts/train_ctm.sh ADDED
@@ -0,0 +1,35 @@
1
+ python -m tasks.mazes.train \
2
+ --model ctm \
3
+ --log_dir logs/mazes/ctm/d=2048--i=512--heads=16--sd=8--nlm=32--synch=64-32-h=32-first-last--iters=75x25--backbone=34-2 \
4
+ --neuron_select_type first-last \
5
+ --dataset mazes-large \
6
+ --synapse_depth 8 \
7
+ --heads 16 \
8
+ --iterations 75 \
9
+ --memory_length 25 \
10
+ --d_model 2048 \
11
+ --d_input 512 \
12
+ --backbone_type resnet34-2 \
13
+ --n_synch_out 64 \
14
+ --n_synch_action 32 \
15
+ --memory_hidden_dims 32 \
16
+ --deep_memory \
17
+ --weight_decay 0.000 \
18
+ --batch_size 64 \
19
+ --batch_size_test 128 \
20
+ --n_test_batches 20 \
21
+ --gradient_clipping -1 \
22
+ --use_scheduler \
23
+ --scheduler_type cosine \
24
+ --warmup_steps 10000 \
25
+ --training_iterations 1000001 \
26
+ --no-do_normalisation \
27
+ --track_every 1000 \
28
+ --lr 1e-4 \
29
+ --no-reload \
30
+ --dropout 0.1 \
31
+ --positional_embedding_type none \
32
+ --maze_route_length 100 \
33
+ --cirriculum_lookahead 5 \
34
+ --device 0 \
35
+ --no-expand_range
tasks/mazes/train.py ADDED
@@ -0,0 +1,698 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import seaborn as sns
8
+ sns.set_style('darkgrid')
9
+ import torch
10
+ if torch.cuda.is_available():
11
+ # Allow TF32 matmuls for faster training on supported GPUs
12
+ torch.set_float32_matmul_precision('high')
13
+ from tqdm.auto import tqdm
14
+
15
+ from data.custom_datasets import MazeImageFolder
16
+ from models.ctm import ContinuousThoughtMachine
17
+ from models.lstm import LSTMBaseline
18
+ from models.ff import FFBaseline
19
+ from tasks.mazes.plotting import make_maze_gif
20
+ from tasks.image_classification.plotting import plot_neural_dynamics
21
+ from utils.housekeeping import set_seed, zip_python_code
22
+ from utils.losses import maze_loss
23
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
24
+
25
+ import torchvision
26
+ torchvision.disable_beta_transforms_warning()
27
+
28
+ import warnings
29
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
30
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
31
+ warnings.filterwarnings(
32
+ "ignore",
33
+ "Corrupt EXIF data",
34
+ UserWarning,
35
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
36
+ )
37
+ warnings.filterwarnings(
38
+ "ignore",
39
+ "UserWarning: Metadata Warning",
40
+ UserWarning,
41
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
42
+ )
43
+ warnings.filterwarnings(
44
+ "ignore",
45
+ "UserWarning: Truncated File Read",
46
+ UserWarning,
47
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
48
+ )
49
+
50
+
51
+ def parse_args():
52
+ parser = argparse.ArgumentParser()
53
+
54
+ # Model Selection
55
+ parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
56
+
57
+ # Model Architecture
58
+ # Common across all or most
59
+ parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
60
+ parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
61
+ parser.add_argument('--backbone_type', type=str, default='resnet34-2', help='Type of backbone featuriser.') # Default changed from original script
62
+ # CTM / LSTM specific
63
+ parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
64
+ parser.add_argument('--heads', type=int, default=8, help='Number of attention heads (CTM, LSTM).') # Default changed
65
+ parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
66
+ parser.add_argument('--positional_embedding_type', type=str, default='none',
67
+ help='Type of positional embedding (CTM, LSTM).', choices=['none',
68
+ 'learnable-fourier',
69
+ 'multi-learnable-fourier',
70
+ 'custom-rotational'])
71
+
72
+ # CTM specific
73
+ parser.add_argument('--synapse_depth', type=int, default=8, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).') # Default changed
74
+ parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).') # Default changed
75
+ parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).') # Default changed
76
+ parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
77
+ parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
78
+ parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
79
+ parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True,
80
+ help='Use deep memory (CTM only).')
81
+ parser.add_argument('--memory_hidden_dims', type=int, default=32, help='Hidden dimensions of the memory if using deep memory (CTM only).') # Default changed
82
+ parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
83
+ parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
84
+ # LSTM specific
85
+ parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).') # Added LSTM arg
86
+
87
+ # Task Specific Args (Common to all models for this task)
88
+ parser.add_argument('--maze_route_length', type=int, default=100, help='Length to truncate targets.')
89
+ parser.add_argument('--cirriculum_lookahead', type=int, default=5, help='How far to look ahead for the curriculum.')
90
+
91
+
92
+ # Training
93
+ parser.add_argument('--expand_range', action=argparse.BooleanOptionalAction, default=True, help='Mazes between 0 and 1 = False. Between -1 and 1 = True. Legacy checkpoints use 0 and 1.')
94
+ parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training.') # Default changed
95
+ parser.add_argument('--batch_size_test', type=int, default=64, help='Batch size for testing.') # Default changed
96
+ parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate for the model.') # Default changed
97
+ parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
98
+ parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
99
+ parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
100
+ parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
101
+ parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
102
+ parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
103
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
104
+ parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
105
+ parser.add_argument('--num_workers_train', type=int, default=0, help='Num workers training.') # Renamed from num_workers, kept default
106
+ parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
107
+ parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
108
+
109
+ # Logging and Saving
110
+ parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
111
+ parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large'])
112
+ parser.add_argument('--data_root', type=str, default='data/mazes', help='Data root.')
113
+
114
+ parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
115
+ parser.add_argument('--seed', type=int, default=412, help='Random seed.')
116
+ parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
117
+ parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
118
+ parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
119
+ parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading (for debugging)?') # Added back
120
+
121
+ # Tracking
122
+ parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
123
+ parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval') # Default changed
124
+
125
+ # Device
126
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
127
+ parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
128
+
129
+
130
+ args = parser.parse_args()
131
+ return args
132
+
133
+
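+ # Example invocation (hypothetical; see tasks/mazes/scripts/train_ctm.sh for a maintained configuration):
+ #   python -m tasks.mazes.train --model ctm --dataset mazes-medium --log_dir logs/mazes/ctm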
134
+ if __name__=='__main__':
135
+
136
+ # Housekeeping
137
+ args = parse_args()
138
+
139
+ set_seed(args.seed, False)
140
+ if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
141
+
142
+ assert args.dataset in ['mazes-medium', 'mazes-large']
143
+
144
+
145
+
146
+ prediction_reshaper = [args.maze_route_length, 5] # Problem specific
147
+ args.out_dims = args.maze_route_length * 5 # Output dimension before reshaping
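+ # Each route step is predicted over 5 classes; the visualisation code below treats class 4 as padding/end, so this is presumably 4 movement directions plus a terminator.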
148
+
149
+ # For total reproducibility
150
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
151
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
152
+ print(args, file=f)
153
+
154
+ # Configure device string
155
+ device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
156
+ print(f'Running model {args.model} on {device} for dataset {args.dataset}')
157
+
158
+ # Build model conditionally
159
+ model = None
160
+ if args.model == 'ctm':
161
+ model = ContinuousThoughtMachine(
162
+ iterations=args.iterations,
163
+ d_model=args.d_model,
164
+ d_input=args.d_input,
165
+ heads=args.heads,
166
+ n_synch_out=args.n_synch_out,
167
+ n_synch_action=args.n_synch_action,
168
+ synapse_depth=args.synapse_depth,
169
+ memory_length=args.memory_length,
170
+ deep_nlms=args.deep_memory,
171
+ memory_hidden_dims=args.memory_hidden_dims,
172
+ do_layernorm_nlm=args.do_normalisation,
173
+ backbone_type=args.backbone_type,
174
+ positional_embedding_type=args.positional_embedding_type,
175
+ out_dims=args.out_dims,
176
+ prediction_reshaper=prediction_reshaper,
177
+ dropout=args.dropout,
178
+ dropout_nlm=args.dropout_nlm,
179
+ neuron_select_type=args.neuron_select_type,
180
+ n_random_pairing_self=args.n_random_pairing_self,
181
+ ).to(device)
182
+ elif args.model == 'lstm':
183
+ model = LSTMBaseline(
184
+ num_layers=args.num_layers,
185
+ iterations=args.iterations,
186
+ d_model=args.d_model,
187
+ d_input=args.d_input,
188
+ heads=args.heads,
189
+ backbone_type=args.backbone_type,
190
+ positional_embedding_type=args.positional_embedding_type,
191
+ out_dims=args.out_dims,
192
+ prediction_reshaper=prediction_reshaper,
193
+ dropout=args.dropout,
194
+ ).to(device)
195
+ elif args.model == 'ff':
196
+ model = FFBaseline(
197
+ d_model=args.d_model,
198
+ backbone_type=args.backbone_type,
199
+ out_dims=args.out_dims,
200
+ dropout=args.dropout,
201
+ ).to(device)
202
+ else:
203
+ raise ValueError(f"Unknown model type: {args.model}")
204
+
205
+ try:
206
+ # Determine pseudo input shape based on dataset
207
+ h_w = 39 if args.dataset in ['mazes-small', 'mazes-medium'] else 99 # Example dimensions
208
+ pseudo_inputs = torch.zeros((1, 3, h_w, h_w), device=device).float()
209
+ model(pseudo_inputs)
210
+ except Exception as e:
211
+ print(f"Warning: Pseudo forward pass failed: {e}")
212
+
213
+ print(f'Total params: {sum(p.numel() for p in model.parameters())}')
214
+
215
+ # Data
216
+ dataset_mean = [0,0,0] # For plotting later
217
+ dataset_std = [1,1,1]
218
+
219
+ which_maze = args.dataset.split('-')[-1]
220
+ data_root = f'{args.data_root}/{which_maze}'
221
+
222
+ train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=args.maze_route_length, expand_range=args.expand_range)
223
+ test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=args.maze_route_length, expand_range=args.expand_range)
224
+
225
+ num_workers_test = 1 # Defaulting to 1, can be changed
226
+ trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train, drop_last=True)
227
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
228
+
229
+ # (Lazy modules were already initialised by the dummy forward pass above, so the parameter count printed earlier is accurate.)
230
+
231
+
232
+ model.train()
233
+
234
+ # Optimizer and scheduler
235
+ decay_params = []
236
+ no_decay_params = []
237
+ no_decay_names = []
238
+ for name, param in model.named_parameters():
239
+ if not param.requires_grad:
240
+ continue # Skip parameters that don't require gradients
241
+ if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
242
+ no_decay_params.append(param)
243
+ no_decay_names.append(name)
244
+ else:
245
+ decay_params.append(param)
246
+ if len(no_decay_names):
247
+ print(f'WARNING, excluding: {no_decay_names}')
248
+
249
+ # Optimizer and scheduler (Common setup)
250
+ if len(no_decay_names) and args.weight_decay!=0:
251
+ optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
252
+ {'params': no_decay_params, 'weight_decay':0}],
253
+ lr=args.lr,
254
+ eps=1e-8 if not args.use_amp else 1e-6)
255
+ else:
256
+ optimizer = torch.optim.AdamW(model.parameters(),
257
+ lr=args.lr,
258
+ eps=1e-8 if not args.use_amp else 1e-6,
259
+ weight_decay=args.weight_decay)
260
+
261
+ warmup_schedule = warmup(args.warmup_steps)
262
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
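+ # The warmup-only LambdaLR above is a fallback; it is replaced by the chosen scheduler below unless --no-use_scheduler is passed.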
263
+ if args.use_scheduler:
264
+ if args.scheduler_type == 'multistep':
265
+ scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
266
+ elif args.scheduler_type == 'cosine':
267
+ scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
268
+ else:
269
+ raise NotImplementedError
270
+
271
+
272
+ # Metrics tracking
273
+ start_iter = 0
274
+ train_losses = []
275
+ test_losses = []
276
+ train_accuracies = [] # Per tick/step accuracy list
277
+ test_accuracies = []
278
+ train_accuracies_most_certain = [] # Accuracy, fine-grained
279
+ test_accuracies_most_certain = []
280
+ train_accuracies_most_certain_permaze = [] # Full maze accuracy
281
+ test_accuracies_most_certain_permaze = []
282
+ iters = []
283
+
284
+ scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
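+ # GradScaler is a pass-through when enabled=False, so scale/step/update can be called unconditionally below.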
285
+ if args.reload:
286
+ checkpoint_path = f'{args.log_dir}/checkpoint.pt'
287
+ if os.path.isfile(checkpoint_path):
288
+ print(f'Reloading from: {checkpoint_path}')
289
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
290
+ if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
291
+ load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
292
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
293
+
294
+ if not args.reload_model_only:
295
+ print('Reloading optimizer etc.')
296
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
297
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
298
+ scaler.load_state_dict(checkpoint['scaler_state_dict']) # Load scaler state
299
+ start_iter = checkpoint['iteration']
300
+
301
+ if not args.ignore_metrics_when_reloading:
302
+ train_losses = checkpoint['train_losses']
303
+ test_losses = checkpoint['test_losses']
304
+ train_accuracies = checkpoint['train_accuracies']
305
+ test_accuracies = checkpoint['test_accuracies']
306
+ iters = checkpoint['iters']
307
+ train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
308
+ test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
309
+ train_accuracies_most_certain_permaze = checkpoint['train_accuracies_most_certain_permaze']
310
+ test_accuracies_most_certain_permaze = checkpoint['test_accuracies_most_certain_permaze']
311
+ else:
312
+ print("Ignoring metrics history upon reload.")
313
+
314
+ else:
315
+ print('Only reloading model!')
316
+
317
+ if 'torch_rng_state' in checkpoint:
318
+ # Reset seeds
319
+ torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
320
+ np.random.set_state(checkpoint['numpy_rng_state'])
321
+ random.setstate(checkpoint['random_rng_state'])
322
+
323
+ del checkpoint
324
+ import gc
325
+ gc.collect()
326
+ if torch.cuda.is_available():
327
+ torch.cuda.empty_cache()
328
+
329
+ if args.do_compile:
330
+ print('Compiling...')
331
+ if hasattr(model, 'backbone'):
332
+ model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
333
+ # Compile synapses only for CTM
334
+ if args.model == 'ctm':
335
+ model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
336
+
337
+ # Training
338
+ iterator = iter(trainloader)
339
+ with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
340
+ for bi in range(start_iter, args.training_iterations):
341
+ current_lr = optimizer.param_groups[-1]['lr']
342
+
343
+ try:
344
+ inputs, targets = next(iterator)
345
+ except StopIteration:
346
+ iterator = iter(trainloader)
347
+ inputs, targets = next(iterator)
348
+
349
+ inputs = inputs.to(device)
350
+ targets = targets.to(device) # Shape (B, SeqLength)
351
+
352
+ # All for nice metric printing:
353
+ loss = None
354
+ accuracy_finegrained = None # Per-step accuracy at chosen tick
355
+ where_most_certain_val = -1.0 # Default value
356
+ where_most_certain_std = 0.0
357
+ where_most_certain_min = -1
358
+ where_most_certain_max = -1
359
+ upto_where_mean = -1.0
360
+ upto_where_std = 0.0
361
+ upto_where_min = -1
362
+ upto_where_max = -1
363
+
364
+
365
+ # Model-specific forward, reshape, and loss calculation
366
+ with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
367
+ if args.do_compile: # CUDAGraph marking applied if compiling any model
368
+ torch.compiler.cudagraph_mark_step_begin()
369
+
370
+ if args.model == 'ctm':
371
+ # CTM output: (B, SeqLength*5, Ticks), Certainties: (B, Ticks)
372
+ predictions_raw, certainties, synchronisation = model(inputs)
373
+ # Reshape predictions: (B, SeqLength, 5, Ticks)
374
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))
375
+ loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=True)
376
+ # Accuracy uses predictions[B, S, C, T] indexed at where_most_certain[B] -> gives (B, S, C) -> argmax(2) -> (B,S)
377
+ accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()
378
+
379
+ elif args.model == 'lstm':
380
+ # LSTM output: (B, SeqLength*5, Ticks), Certainties: (B, Ticks)
381
+ predictions_raw, certainties, synchronisation = model(inputs)
382
+ # Reshape predictions: (B, SeqLength, 5, Ticks)
383
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))
384
+ loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False)
385
+ # where_most_certain should be -1 (last tick) here. Accuracy calc follows same logic.
386
+ accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()
387
+
388
+ elif args.model == 'ff':
389
+ # Assume FF output: (B, SeqLength*5)
390
+ predictions_raw = model(inputs)
391
+ # Reshape predictions: (B, SeqLength, 5)
392
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5)
393
+ # FF has no certainties, pass None. maze_loss must handle this.
394
+ # Unsqueeze predictions for compatibility with the maze loss calculation
395
+ loss, where_most_certain, upto_where = maze_loss(predictions.unsqueeze(-1), None, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False)
396
+ # where_most_certain should be -1 here. Accuracy uses 3D prediction tensor.
397
+ accuracy_finegrained = (predictions.argmax(2) == targets).float().mean().item()
398
+
399
+
400
+ # Extract stats from loss outputs if they are tensors
401
+ if torch.is_tensor(where_most_certain):
402
+ where_most_certain_val = where_most_certain.float().mean().item()
403
+ where_most_certain_std = where_most_certain.float().std().item()
404
+ where_most_certain_min = where_most_certain.min().item()
405
+ where_most_certain_max = where_most_certain.max().item()
406
+ elif isinstance(where_most_certain, int): # Handle case where it might return -1 directly
407
+ where_most_certain_val = float(where_most_certain)
408
+ where_most_certain_min = where_most_certain
409
+ where_most_certain_max = where_most_certain
410
+
411
+ if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0: # Check if it's a list/array
412
+ upto_where_mean = np.mean(upto_where)
413
+ upto_where_std = np.std(upto_where)
414
+ upto_where_min = np.min(upto_where)
415
+ upto_where_max = np.max(upto_where)
416
+
417
+
418
+ scaler.scale(loss).backward()
419
+
420
+ if args.gradient_clipping!=-1:
421
+ scaler.unscale_(optimizer)
422
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
423
+
424
+ scaler.step(optimizer)
425
+ scaler.update()
426
+ optimizer.zero_grad(set_to_none=True)
427
+ scheduler.step()
428
+
429
+ # Conditional Tqdm Description
430
+ pbar_desc = f'Loss={loss.item():0.3f}. Acc(step)={accuracy_finegrained:0.3f}. LR={current_lr:0.6f}.'
431
+ if args.model in ['ctm', 'lstm'] or torch.is_tensor(where_most_certain): # Show stats if available
432
+ pbar_desc += f' Where_certain={where_most_certain_val:0.2f}+-{where_most_certain_std:0.2f} ({where_most_certain_min:d}<->{where_most_certain_max:d}).'
433
+ if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
434
+ pbar_desc += f' Path pred stats: {upto_where_mean:0.2f}+-{upto_where_std:0.2f} ({upto_where_min:d} --> {upto_where_max:d})'
435
+
436
+ pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
437
+
438
+
439
+ # Metrics tracking and plotting
440
+ if bi%args.track_every==0 and (bi != 0 or args.reload_model_only):
441
+ model.eval() # Use eval mode for consistency during tracking
442
+ with torch.inference_mode(): # Use inference mode for tracking
443
+
444
+
445
+
446
+
447
+ # --- Quantitative Metrics ---
448
+ iters.append(bi)
449
+ # Re-initialize metric lists for this evaluation step
450
+ current_train_losses_eval = []
451
+ current_test_losses_eval = []
452
+ current_train_accuracies_eval = []
453
+ current_test_accuracies_eval = []
454
+ current_train_accuracies_most_certain_eval = []
455
+ current_test_accuracies_most_certain_eval = []
456
+ current_train_accuracies_most_certain_permaze_eval = []
457
+ current_test_accuracies_most_certain_permaze_eval = []
458
+
459
+ # TRAIN METRICS
460
+ pbar.set_description('Tracking: Computing TRAIN metrics')
461
+ loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test) # Use consistent num_workers
462
+ all_targets_list = []
463
+ all_predictions_list = [] # Per step/tick predictions argmax (N, S, T) or (N, S)
464
+ all_predictions_most_certain_list = [] # Predictions at chosen step/tick argmax (N, S)
465
+ all_losses = []
466
+
467
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
468
+ for inferi, (inputs, targets) in enumerate(loader):
469
+ inputs = inputs.to(device)
470
+ targets = targets.to(device)
471
+ all_targets_list.append(targets.detach().cpu().numpy()) # N x S
472
+
473
+ # Model-specific forward, reshape, loss for evaluation
474
+ if args.model == 'ctm':
475
+ predictions_raw, certainties, _ = model(inputs)
476
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
477
+ loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
478
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,C,T -> argmax class -> B,S,T
479
+ pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S
480
+ all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
481
+
482
+ elif args.model == 'lstm':
483
+ predictions_raw, certainties, _ = model(inputs)
484
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
485
+ loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
486
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,C,T
487
+ pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S (at last tick)
488
+ all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
489
+
490
+ elif args.model == 'ff':
491
+ predictions_raw = model(inputs) # B, S*C
492
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
493
+ loss, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
494
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S
495
+ all_predictions_most_certain_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S (same as above for FF)
496
+
497
+
498
+ all_losses.append(loss.item())
499
+
500
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break
501
+ pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
502
+ pbar_inner.update(1)
503
+
504
+ all_targets = np.concatenate(all_targets_list) # N, S
505
+ all_predictions = np.concatenate(all_predictions_list) # N, S, T or N, S
506
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list) # N, S
507
+
508
+ train_losses.append(np.mean(all_losses))
509
+ # Calculate per step/tick accuracy averaged over batches
510
+ if args.model in ['ctm', 'lstm']:
511
+ # all_predictions shape (N, S, T), all_targets shape (N, S) -> compare targets to each tick prediction
512
+ train_accuracies.append(np.mean(all_predictions == all_targets[:,:,np.newaxis], axis=0)) # Mean over N -> (S, T)
513
+ else: # FF
514
+ # all_predictions shape (N, S), all_targets shape (N, S)
515
+ train_accuracies.append(np.mean(all_predictions == all_targets, axis=0)) # Mean over N -> (S,)
516
+
517
+ # Calculate accuracy at chosen step/tick ("most certain") averaged over all steps and batches
518
+ train_accuracies_most_certain.append((all_targets == all_predictions_most_certain).mean()) # Scalar
519
+ # Calculate full maze accuracy at chosen step/tick averaged over batches
520
+ train_accuracies_most_certain_permaze.append((all_targets == all_predictions_most_certain).reshape(all_targets.shape[0], -1).all(-1).mean()) # Scalar
521
+
522
+
523
+ # TEST METRICS
524
+ pbar.set_description('Tracking: Computing TEST metrics')
525
+ loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
526
+ all_targets_list = []
527
+ all_predictions_list = []
528
+ all_predictions_most_certain_list = []
529
+ all_losses = []
530
+
531
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
532
+ for inferi, (inputs, targets) in enumerate(loader):
533
+ inputs = inputs.to(device)
534
+ targets = targets.to(device)
535
+ all_targets_list.append(targets.detach().cpu().numpy())
536
+
537
+ # Model-specific forward, reshape, loss for evaluation
538
+ if args.model == 'ctm':
539
+ predictions_raw, certainties, _ = model(inputs)
540
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
541
+ loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
542
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,T
543
+ pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S
544
+ all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
545
+
546
+ elif args.model == 'lstm':
547
+ predictions_raw, certainties, _ = model(inputs)
548
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
549
+ loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
550
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,T
551
+ pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S (at last tick)
552
+ all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
553
+
554
+ elif args.model == 'ff':
555
+ predictions_raw = model(inputs) # B, S*C
556
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
557
+ loss, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
558
+ all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S
559
+ all_predictions_most_certain_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S (same as above for FF)
560
+
561
+
562
+ all_losses.append(loss.item())
563
+
564
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
565
+ pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
566
+ pbar_inner.update(1)
567
+
568
+ all_targets = np.concatenate(all_targets_list)
569
+ all_predictions = np.concatenate(all_predictions_list)
570
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
571
+
572
+ test_losses.append(np.mean(all_losses))
573
+ # Calculate per step/tick accuracy
574
+ if args.model in ['ctm', 'lstm']:
575
+ test_accuracies.append(np.mean(all_predictions == all_targets[:,:,np.newaxis], axis=0)) # -> (S, T)
576
+ else: # FF
577
+ test_accuracies.append(np.mean(all_predictions == all_targets, axis=0)) # -> (S,)
578
+
579
+ # Calculate "most certain" accuracy
580
+ test_accuracies_most_certain.append((all_targets == all_predictions_most_certain).mean()) # Scalar
581
+ # Calculate full maze accuracy
582
+ test_accuracies_most_certain_permaze.append((all_targets == all_predictions_most_certain).reshape(all_targets.shape[0], -1).all(-1).mean()) # Scalar
583
+
584
+
585
+ # --- Plotting ---
586
+ # Accuracy Plot (Handling different dimensions)
587
+ figacc = plt.figure(figsize=(10, 10))
588
+ axacc_train = figacc.add_subplot(211)
589
+ axacc_test = figacc.add_subplot(212)
590
+ cm = sns.color_palette("viridis", as_cmap=True)
591
+
592
+ # Plot per step/tick accuracy
593
+ # train_accuracies is List[(S, T)] or List[(S,)]
594
+ # We need to average over S dimension for plotting
595
+ train_acc_plot = [np.mean(acc_s) for acc_s in train_accuracies] # (S, T) or (S,) arrays -> List[Scalar] after mean
596
+ test_acc_plot = [np.mean(acc_s) for acc_s in test_accuracies] # (S, T) or (S,) arrays -> List[Scalar] after mean
597
+
598
+ axacc_train.plot(iters, train_acc_plot, 'g-', alpha=0.5, label='Avg Step Acc')
599
+ axacc_test.plot(iters, test_acc_plot, 'g-', alpha=0.5, label='Avg Step Acc')
600
+
601
+
602
+ # Plot most certain accuracy
603
+ axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most Certain (Avg Step)')
604
+ axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most Certain (Avg Step)')
605
+ # Plot full maze accuracy
606
+ axacc_train.plot(iters, train_accuracies_most_certain_permaze, 'r-', alpha=0.6, label='Full Maze')
607
+ axacc_test.plot(iters, test_accuracies_most_certain_permaze, 'r-', alpha=0.6, label='Full Maze')
608
+
609
+ axacc_train.set_title('Train Accuracy')
610
+ axacc_test.set_title('Test Accuracy')
611
+ axacc_train.legend(loc='lower right')
612
+ axacc_test.legend(loc='lower right')
613
+ axacc_train.set_xlim([0, args.training_iterations])
614
+ axacc_test.set_xlim([0, args.training_iterations])
615
+ axacc_train.set_ylim([0, 1]) # Set Ylim for accuracy
616
+ axacc_test.set_ylim([0, 1])
617
+
618
+ figacc.tight_layout()
619
+ figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
620
+ plt.close(figacc)
621
+
622
+ # Loss Plot
623
+ figloss = plt.figure(figsize=(10, 5))
624
+ axloss = figloss.add_subplot(111)
625
+ axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
626
+ axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
627
+ axloss.legend(loc='upper right')
628
+ axloss.set_xlim([0, args.training_iterations])
629
+ axloss.set_ylim(bottom=0)
630
+
631
+ figloss.tight_layout()
632
+ figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
633
+ plt.close(figloss)
634
+
635
+ # --- Visualization Section (Conditional) ---
636
+ if args.model in ['ctm', 'lstm']:
637
+ # try:
638
+ inputs_viz, targets_viz = next(iter(testloader))
639
+ inputs_viz = inputs_viz.to(device)
640
+ targets_viz = targets_viz.to(device)
641
+ # Find longest path in batch for potentially better visualization
642
+ longest_index = (targets_viz!=4).sum(-1).argmax() # Action 4 assumed padding/end
643
+
644
+ # Track internal states
645
+ predictions_viz_raw, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
646
+
647
+ # Reshape predictions (assuming raw is B, D, T)
648
+ predictions_viz = predictions_viz_raw.reshape(predictions_viz_raw.size(0), -1, 5, predictions_viz_raw.size(-1)) # B, S, C, T
649
+
650
+ att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
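+ # Reshape the flattened attention weights back onto the backbone's spatial grid, presumably (ticks, batch, heads, H, W), for the maze GIF below.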
651
+ attention_tracking_viz = attention_tracking_viz.reshape(
652
+ attention_tracking_viz.shape[0],
653
+ attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
654
+
655
+ # Plot dynamics (common plotting function)
656
+ plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
657
+
658
+ # Create maze GIF (task-specific plotting)
659
+ make_maze_gif((inputs_viz[longest_index].detach().cpu().numpy()+1)/2,
660
+ predictions_viz[longest_index].detach().cpu().numpy(), # Pass reshaped B,S,C,T -> S,C,T
661
+ targets_viz[longest_index].detach().cpu().numpy(), # S
662
+ attention_tracking_viz[:, longest_index], # Pass T, (H), H, W
663
+ args.log_dir)
664
+ # except Exception as e:
665
+ # print(f"Visualization failed for model {args.model}: {e}")
666
+ # --- End Visualization ---
667
+
668
+ model.train() # Switch back to train mode
669
+
670
+
671
+ # Save model checkpoint
672
+ if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
673
+ pbar.set_description('Saving model checkpoint...')
674
+ checkpoint_data = {
675
+ 'model_state_dict': model.state_dict(),
676
+ 'optimizer_state_dict': optimizer.state_dict(),
677
+ 'scheduler_state_dict': scheduler.state_dict(),
678
+ 'scaler_state_dict': scaler.state_dict(), # Save scaler state
679
+ 'iteration': bi,
680
+ # Save all tracked metrics
681
+ 'train_losses': train_losses,
682
+ 'test_losses': test_losses,
683
+ 'train_accuracies': train_accuracies, # List of (S, T) or (S,) arrays
684
+ 'test_accuracies': test_accuracies, # List of (S, T) or (S,) arrays
685
+ 'train_accuracies_most_certain': train_accuracies_most_certain, # List of scalars
686
+ 'test_accuracies_most_certain': test_accuracies_most_certain, # List of scalars
687
+ 'train_accuracies_most_certain_permaze': train_accuracies_most_certain_permaze, # List of scalars
688
+ 'test_accuracies_most_certain_permaze': test_accuracies_most_certain_permaze, # List of scalars
689
+ 'iters': iters,
690
+ 'args': args, # Save args used for this run
691
+ # RNG states
692
+ 'torch_rng_state': torch.get_rng_state(),
693
+ 'numpy_rng_state': np.random.get_state(),
694
+ 'random_rng_state': random.getstate(),
695
+ }
696
+ torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
697
+
698
+ pbar.update(1)
tasks/mazes/train_distributed.py ADDED
@@ -0,0 +1,782 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+ import gc
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import seaborn as sns
9
+ sns.set_style('darkgrid')
10
+ import torch
11
+ if torch.cuda.is_available():
12
+ # Allow TF32 matmuls for faster training on supported GPUs
13
+ torch.set_float32_matmul_precision('high')
14
+ import torch.distributed as dist
15
+ from torch.nn.parallel import DistributedDataParallel as DDP
16
+ from torch.utils.data.distributed import DistributedSampler
17
+ from utils.samplers import FastRandomDistributedSampler
18
+ from tqdm.auto import tqdm
19
+
20
+ # Data/Task Specific Imports
21
+ from data.custom_datasets import MazeImageFolder
22
+
23
+ # Model Imports
24
+ from models.ctm import ContinuousThoughtMachine
25
+ from models.lstm import LSTMBaseline
26
+ from models.ff import FFBaseline
27
+
28
+ # Plotting/Utils Imports
29
+ from tasks.mazes.plotting import make_maze_gif
30
+ from tasks.image_classification.plotting import plot_neural_dynamics
31
+ from utils.housekeeping import set_seed, zip_python_code
32
+ from utils.losses import maze_loss
33
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
34
+
35
+ import torchvision
36
+ torchvision.disable_beta_transforms_warning()
37
+
38
+ import warnings
39
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
40
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
41
+ warnings.filterwarnings(
42
+ "ignore",
43
+ "Corrupt EXIF data",
44
+ UserWarning,
45
+ r"^PIL\.TiffImagePlugin$"
46
+ )
47
+ warnings.filterwarnings(
48
+ "ignore",
49
+ "UserWarning: Metadata Warning",
50
+ UserWarning,
51
+ r"^PIL\.TiffImagePlugin$"
52
+ )
53
+ warnings.filterwarnings(
54
+ "ignore",
55
+ "UserWarning: Truncated File Read",
56
+ UserWarning,
57
+ r"^PIL\.TiffImagePlugin$"
58
+ )
59
+
60
+
61
+ def parse_args():
62
+ parser = argparse.ArgumentParser()
63
+
64
+ # Model Selection
65
+ parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
66
+
67
+ # Model Architecture
68
+ parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
69
+ parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
70
+ parser.add_argument('--backbone_type', type=str, default='resnet34-2', help='Type of backbone featuriser.')
71
+ # CTM / LSTM specific
72
+ parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
73
+ parser.add_argument('--heads', type=int, default=8, help='Number of attention heads (CTM, LSTM).')
74
+ parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
75
+ parser.add_argument('--positional_embedding_type', type=str, default='none',
76
+ help='Type of positional embedding (CTM, LSTM).', choices=['none',
77
+ 'learnable-fourier',
78
+ 'multi-learnable-fourier',
79
+ 'custom-rotational'])
80
+ # CTM specific
81
+ parser.add_argument('--synapse_depth', type=int, default=8, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
82
+ parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).')
83
+ parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).')
84
+ parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
85
+ parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
86
+ parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
87
+ parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
88
+ parser.add_argument('--memory_hidden_dims', type=int, default=32, help='Hidden dimensions of the memory if using deep memory (CTM only).')
89
+ parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
90
+ parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
91
+ # LSTM specific
92
+ parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
93
+
94
+ # Task Specific Args
95
+ parser.add_argument('--maze_route_length', type=int, default=100, help='Length to truncate targets.')
96
+ parser.add_argument('--cirriculum_lookahead', type=int, default=5, help='How far to look ahead for the curriculum.')
97
+
98
+ # Training
99
+ parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training (per GPU).')
100
+ parser.add_argument('--batch_size_test', type=int, default=64, help='Batch size for testing (per GPU).')
101
+ parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate for the model.')
102
+ parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
103
+ parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
104
+ parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
105
+ parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
106
+ parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
107
+ parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
108
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
109
+ parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
110
+ parser.add_argument('--num_workers_train', type=int, default=0, help='Num workers training.')
111
+ parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient norm clipping value (-1 to disable).')
112
+ parser.add_argument('--use_custom_sampler', action=argparse.BooleanOptionalAction, default=False, help='Use custom fast sampler to avoid reshuffling.')
113
+ parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
114
+
115
+ # Logging and Saving
116
+ parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
117
+ parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large'])
118
+ parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
119
+ parser.add_argument('--seed', type=int, default=412, help='Random seed.')
120
+ parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
121
+ parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?') # Default False based on user edit
122
+ parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=False, help='Should use strict reload for model weights.')
123
+ parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading (for debugging)?')
124
+
125
+ # Tracking
126
+ parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
127
+ parser.add_argument('--n_test_batches', type=int, default=2, help='How many minibatches to approx metrics. Set to -1 for full eval')
128
+
129
+ # Precision
130
+ parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
131
+
132
+ args = parser.parse_args()
133
+ return args
134
+
135
+ # --- DDP Setup Functions ---
136
+ def setup_ddp():
137
+ if 'RANK' not in os.environ:
138
+ os.environ['RANK'] = '0'
139
+ os.environ['WORLD_SIZE'] = '1'
140
+ os.environ['MASTER_ADDR'] = 'localhost'
141
+ os.environ['MASTER_PORT'] = '12356' # Different port from image classification
142
+ os.environ['LOCAL_RANK'] = '0'
143
+ print("Running in non-distributed mode (simulated DDP setup).")
144
+ if not torch.cuda.is_available() or int(os.environ['WORLD_SIZE']) == 1:
145
+ dist.init_process_group(backend='gloo')
146
+ print("Initialized process group with Gloo backend for single/CPU process.")
147
+ rank = int(os.environ['RANK'])
148
+ world_size = int(os.environ['WORLD_SIZE'])
149
+ local_rank = int(os.environ['LOCAL_RANK'])
150
+ return rank, world_size, local_rank
151
+
152
+ dist.init_process_group(backend='nccl')
153
+ rank = int(os.environ['RANK'])
154
+ world_size = int(os.environ['WORLD_SIZE'])
155
+ local_rank = int(os.environ['LOCAL_RANK'])
156
+ if torch.cuda.is_available():
157
+ torch.cuda.set_device(local_rank)
158
+ print(f"Rank {rank} setup on GPU {local_rank}")
159
+ else:
160
+ print(f"Rank {rank} setup on CPU")
161
+ return rank, world_size, local_rank
162
+
163
+ def cleanup_ddp():
164
+ if dist.is_initialized():
165
+ dist.destroy_process_group()
166
+ print("DDP cleanup complete.")
167
+
168
+ def is_main_process(rank):
169
+ return rank == 0
170
+ # --- End DDP Setup ---
171
+
172
+
173
+ if __name__=='__main__':
174
+
175
+ args = parse_args()
176
+
177
+ rank, world_size, local_rank = setup_ddp()
178
+
179
+ set_seed(args.seed + rank, False)
180
+
181
+ # Rank 0 handles directory creation and initial logging
182
+ if is_main_process(rank):
183
+ if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
184
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
185
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
186
+ print(args, file=f)
187
+ if world_size > 1: dist.barrier()
188
+
189
+
190
+ assert args.dataset in ['mazes-medium', 'mazes-large']
191
+
192
+ # Setup Device
193
+ if torch.cuda.is_available():
194
+ device = torch.device(f'cuda:{local_rank}')
195
+ else:
196
+ device = torch.device('cpu')
197
+ if world_size > 1: warnings.warn("Running DDP on CPU is not recommended.")
198
+
199
+ if is_main_process(rank):
200
+ print(f'Main process (Rank {rank}): Using device {device}. World size: {world_size}. Model: {args.model}')
201
+
202
+
203
+ prediction_reshaper = [args.maze_route_length, 5]
204
+ args.out_dims = args.maze_route_length * 5
205
+
206
+ # --- Model Definition (Conditional) ---
207
+ model_base = None # Base model before DDP wrapping
208
+ if args.model == 'ctm':
209
+ model_base = ContinuousThoughtMachine(
210
+ iterations=args.iterations,
211
+ d_model=args.d_model,
212
+ d_input=args.d_input,
213
+ heads=args.heads,
214
+ n_synch_out=args.n_synch_out,
215
+ n_synch_action=args.n_synch_action,
216
+ synapse_depth=args.synapse_depth,
217
+ memory_length=args.memory_length,
218
+ deep_nlms=args.deep_memory,
219
+ memory_hidden_dims=args.memory_hidden_dims,
220
+ do_layernorm_nlm=args.do_normalisation,
221
+ backbone_type=args.backbone_type,
222
+ positional_embedding_type=args.positional_embedding_type,
223
+ out_dims=args.out_dims,
224
+ prediction_reshaper=prediction_reshaper,
225
+ dropout=args.dropout,
226
+ dropout_nlm=args.dropout_nlm,
227
+ neuron_select_type=args.neuron_select_type,
228
+ n_random_pairing_self=args.n_random_pairing_self,
229
+ ).to(device)
230
+ elif args.model == 'lstm':
231
+ model_base = LSTMBaseline(
232
+ num_layers=args.num_layers,
233
+ iterations=args.iterations,
234
+ d_model=args.d_model,
235
+ d_input=args.d_input,
236
+ heads=args.heads,
237
+ backbone_type=args.backbone_type,
238
+ positional_embedding_type=args.positional_embedding_type,
239
+ out_dims=args.out_dims,
240
+ prediction_reshaper=prediction_reshaper,
241
+ dropout=args.dropout,
242
+ ).to(device)
243
+ elif args.model == 'ff':
244
+ model_base = FFBaseline(
245
+ d_model=args.d_model,
246
+ backbone_type=args.backbone_type,
247
+ out_dims=args.out_dims,
248
+ dropout=args.dropout,
249
+ ).to(device)
250
+ else:
251
+ raise ValueError(f"Unknown model type: {args.model}")
252
+
253
+ # Use pseudo-input *before* DDP wrapping
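+ # (Assumption: lazy submodules need to be materialised before DDP wraps the model, otherwise parameter broadcasting would see uninitialised weights.)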
254
+ try:
255
+ # Determine pseudo input shape based on dataset
256
+ h_w = 39 if args.dataset in ['mazes-small', 'mazes-medium'] else 99 # Example dimensions
257
+ pseudo_inputs = torch.zeros((1, 3, h_w, h_w), device=device).float()
258
+ model_base(pseudo_inputs)
259
+ except Exception as e:
260
+ print(f"Warning: Pseudo forward pass failed: {e}")
261
+
262
+ if is_main_process(rank):
263
+ print(f'Total params: {sum(p.numel() for p in model_base.parameters() if p.requires_grad)}')
264
+
265
+ # Wrap model with DDP
266
+ if device.type == 'cuda' and world_size > 1:
267
+ model = DDP(model_base, device_ids=[local_rank], output_device=local_rank)
268
+ elif device.type == 'cpu' and world_size > 1:
269
+ model = DDP(model_base)
270
+ else:
271
+ model = model_base
272
+ # --- End Model Definition ---
273
+
274
+
275
+ # Data Loading (After model setup to allow pseudo pass first)
276
+ dataset_mean = [0,0,0]
277
+ dataset_std = [1,1,1]
278
+ which_maze = args.dataset.split('-')[-1]
279
+ data_root = f'data/mazes/{which_maze}'
280
+
281
+ train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=args.maze_route_length)
282
+ test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=args.maze_route_length)
283
+
284
+ train_sampler = (FastRandomDistributedSampler(train_data, num_replicas=world_size, rank=rank, seed=args.seed, epoch_steps=int(10e10))
285
+ if args.use_custom_sampler else
286
+ DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed))
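+ # epoch_steps is set very large so the custom sampler effectively never restarts an epoch, avoiding reshuffles (see --use_custom_sampler help).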
287
+ test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed)
288
+
289
+ num_workers_test = 1
290
+ trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler,
291
+ num_workers=args.num_workers_train, pin_memory=True, drop_last=True)
292
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, sampler=test_sampler,
293
+ num_workers=num_workers_test, pin_memory=True, drop_last=False)
294
+
295
+
296
+ # Optimizer and scheduler
297
+ decay_params = []
298
+ no_decay_params = []
299
+ no_decay_names = []
300
+ for name, param in model.named_parameters():
301
+ if not param.requires_grad:
302
+ continue # Skip parameters that don't require gradients
303
+ if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
304
+ no_decay_params.append(param)
305
+ no_decay_names.append(name)
306
+ else:
307
+ decay_params.append(param)
308
+ if len(no_decay_names) and is_main_process(rank):
309
+ print(f'WARNING, excluding: {no_decay_names}')
310
+
311
+ # Optimizer and scheduler (Common setup)
312
+ if len(no_decay_names) and args.weight_decay!=0:
313
+ optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
314
+ {'params': no_decay_params, 'weight_decay':0}],
315
+ lr=args.lr,
316
+ eps=1e-8 if not args.use_amp else 1e-6)
317
+ else:
318
+ optimizer = torch.optim.AdamW(model.parameters(),
319
+ lr=args.lr,
320
+ eps=1e-8 if not args.use_amp else 1e-6,
321
+ weight_decay=args.weight_decay)
322
+
323
+ warmup_schedule = warmup(args.warmup_steps)
324
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
325
+ if args.use_scheduler:
326
+ if args.scheduler_type == 'multistep':
327
+ scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
328
+ elif args.scheduler_type == 'cosine':
329
+ scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
330
+ else:
331
+ raise NotImplementedError
332
+
333
+
334
+ # Metrics tracking (Rank 0 stores history)
335
+ start_iter = 0
336
+ iters = []
337
+ train_losses, test_losses = [], []
338
+ train_accuracies, test_accuracies = [], [] # Avg Step Acc (scalar list)
339
+ train_accuracies_most_certain, test_accuracies_most_certain = [], [] # Avg Step Acc @ Certain tick (scalar list)
340
+ train_accuracies_most_certain_permaze, test_accuracies_most_certain_permaze = [], [] # Full Maze Acc @ Certain tick (scalar list)
341
+
342
+
343
+ scaler = torch.amp.GradScaler("cuda" if device.type == 'cuda' else "cpu", enabled=args.use_amp)
344
+
345
+ # Reloading Logic
346
+ if args.reload:
347
+ map_location = device
348
+ chkpt_path = f'{args.log_dir}/checkpoint.pt'
349
+ if os.path.isfile(chkpt_path):
350
+ print(f'Rank {rank}: Reloading from: {chkpt_path}')
351
+ if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
352
+
353
+ checkpoint = torch.load(chkpt_path, map_location=map_location, weights_only=False)
354
+
355
+ model_to_load = model.module if isinstance(model, DDP) else model
356
+ state_dict = checkpoint['model_state_dict']
357
+ has_module_prefix = all(k.startswith('module.') for k in state_dict)
358
+ is_wrapped = isinstance(model, DDP)
359
+
360
+ if has_module_prefix and not is_wrapped:
361
+ state_dict = {k.partition('module.')[2]: v for k,v in state_dict.items()}
362
+ elif not has_module_prefix and is_wrapped:
363
+ load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
364
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
365
+ state_dict = None # Prevent loading again
366
+
367
+ if state_dict is not None:
368
+ load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
369
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
370
+
371
+
372
+
373
+ if not args.reload_model_only:
374
+ print(f'Rank {rank}: Reloading optimizer, scheduler, scaler, iteration.')
375
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
376
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
377
+ scaler.load_state_dict(checkpoint['scaler_state_dict'])
378
+ start_iter = checkpoint['iteration']
379
+
380
+ if is_main_process(rank) and not args.ignore_metrics_when_reloading:
381
+ print(f'Rank {rank}: Reloading metrics history.')
382
+ iters = checkpoint['iters']
383
+ train_losses = checkpoint['train_losses']
384
+ test_losses = checkpoint['test_losses']
385
+ train_accuracies = checkpoint['train_accuracies'] # Reloading simplified avg step acc list
386
+ test_accuracies = checkpoint['test_accuracies']
387
+ train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
388
+ test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
389
+ train_accuracies_most_certain_permaze = checkpoint['train_accuracies_most_certain_permaze']
390
+ test_accuracies_most_certain_permaze = checkpoint['test_accuracies_most_certain_permaze']
391
+ elif is_main_process(rank) and args.ignore_metrics_when_reloading:
392
+ print(f'Rank {rank}: Ignoring metrics history upon reload.')
393
+ else:
394
+ print(f'Rank {rank}: Only reloading model weights!')
395
+
396
+ if is_main_process(rank) and 'torch_rng_state' in checkpoint and not args.reload_model_only:
397
+ print(f'Rank {rank}: Loading RNG states.')
398
+ torch.set_rng_state(checkpoint['torch_rng_state'].cpu())
399
+ np.random.set_state(checkpoint['numpy_rng_state'])
400
+ random.setstate(checkpoint['random_rng_state'])
401
+
402
+ del checkpoint
403
+ gc.collect()
404
+ if torch.cuda.is_available():
405
+ torch.cuda.empty_cache()
406
+ print(f"Rank {rank}: Reload finished, starting from iteration {start_iter}")
407
+ else:
408
+ print(f"Rank {rank}: Checkpoint not found at {chkpt_path}, starting from scratch.")
409
+
410
+
411
+ if world_size > 1: dist.barrier()
412
+
413
+
414
+ # Conditional Compilation
415
+ if args.do_compile:
416
+ if is_main_process(rank): print('Compiling model components...')
417
+ model_to_compile = model.module if isinstance(model, DDP) else model
418
+ if hasattr(model_to_compile, 'backbone'):
419
+ model_to_compile.backbone = torch.compile(model_to_compile.backbone, mode='reduce-overhead', fullgraph=True)
420
+ if args.model == 'ctm':
421
+ model_to_compile.synapses = torch.compile(model_to_compile.synapses, mode='reduce-overhead', fullgraph=True)
422
+ if world_size > 1: dist.barrier()
423
+ if is_main_process(rank): print('Compilation finished.')
424
+
425
+
426
+ # --- Training Loop ---
427
+ model.train()
428
+ pbar = tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True, disable=not is_main_process(rank))
429
+
430
+ iterator = iter(trainloader)
431
+
432
+ for bi in range(start_iter, args.training_iterations):
433
+
434
+ # --- Evaluation and Plotting (Rank 0 + Aggregation) ---
435
+ if bi % args.track_every == 0 and (bi != 0 or args.reload_model_only):
436
+ model.eval()
437
+ with torch.inference_mode():
438
+
439
+ # --- Distributed Evaluation ---
440
+ if is_main_process(rank): iters.append(bi) # Track iterations on rank 0
441
+
442
+ # Initialize accumulators on device
443
+ total_train_loss = torch.tensor(0.0, device=device)
444
+ total_train_correct_certain = torch.tensor(0.0, device=device) # Sum correct steps @ certain tick
445
+ total_train_mazes_solved = torch.tensor(0.0, device=device) # Sum solved mazes @ certain tick
446
+ total_train_steps = torch.tensor(0.0, device=device) # Total steps evaluated (B * S)
447
+ total_train_mazes = torch.tensor(0.0, device=device) # Total mazes evaluated (B)
448
+
449
+ # TRAIN METRICS
450
+ train_eval_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=False)
451
+ train_eval_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, sampler=train_eval_sampler, num_workers=num_workers_test, pin_memory=True)
452
+
453
+ pbar_inner_desc = 'Eval Train (Rank 0)' if is_main_process(rank) else None
454
+ with tqdm(total=len(train_eval_loader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
455
+ for inferi, (inputs, targets) in enumerate(train_eval_loader):
456
+ inputs = inputs.to(device, non_blocking=True)
457
+ targets = targets.to(device, non_blocking=True) # B, S
458
+ batch_size = inputs.size(0)
459
+ seq_len = targets.size(1)
460
+
461
+ loss_eval = None
462
+ pred_at_certain = None # Shape B, S
463
+ if args.model == 'ctm':
464
+ predictions_raw, certainties, _ = model(inputs)
465
+ predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1)) # B,S,C,T
466
+ loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
467
+ pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
468
+ elif args.model == 'lstm':
469
+ predictions_raw, certainties, _ = model(inputs)
470
+ predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1)) # B,S,C,T
471
+ loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
472
+ pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
473
+ elif args.model == 'ff':
474
+ predictions_raw = model(inputs) # B, S*C
475
+ predictions = predictions_raw.reshape(batch_size, -1, 5) # B,S,C
476
+ loss_eval, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
477
+ pred_at_certain = predictions.argmax(2)
478
+
479
+ # Accumulate metrics
480
+ total_train_loss += loss_eval * batch_size # Sum losses
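+ # (Assumes maze_loss returns a batch-mean loss; multiplying by batch_size converts it to a sum so the all_reduce below yields a true average.)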
481
+ correct_steps = (pred_at_certain == targets) # B, S boolean
482
+ total_train_correct_certain += correct_steps.sum() # Sum correct steps across batch
483
+ total_train_mazes_solved += correct_steps.all(dim=-1).sum() # Sum mazes where all steps are correct
484
+ total_train_steps += batch_size * seq_len
485
+ total_train_mazes += batch_size
486
+
487
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
488
+ pbar_inner.update(1)
489
+
490
+ # Aggregate Train Metrics
491
+ if world_size > 1:
492
+ dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
493
+ dist.all_reduce(total_train_correct_certain, op=dist.ReduceOp.SUM)
494
+ dist.all_reduce(total_train_mazes_solved, op=dist.ReduceOp.SUM)
495
+ dist.all_reduce(total_train_steps, op=dist.ReduceOp.SUM)
496
+ dist.all_reduce(total_train_mazes, op=dist.ReduceOp.SUM)
497
+
498
+ # Calculate final Train metrics on Rank 0
499
+ if is_main_process(rank) and total_train_mazes > 0:
500
+ avg_train_loss = total_train_loss.item() / total_train_mazes.item() # Avg loss per maze/sample
501
+ avg_train_acc_step = total_train_correct_certain.item() / total_train_steps.item() # Avg correct step %
502
+ avg_train_acc_maze = total_train_mazes_solved.item() / total_train_mazes.item() # Avg full maze solved %
503
+ train_losses.append(avg_train_loss)
504
+ train_accuracies_most_certain.append(avg_train_acc_step)
505
+ train_accuracies_most_certain_permaze.append(avg_train_acc_maze)
506
+ # train_accuracies list remains unused/placeholder for this simplified metric structure
507
+ print(f"Iter {bi} Train Metrics (Agg): Loss={avg_train_loss:.4f}, StepAcc={avg_train_acc_step:.4f}, MazeAcc={avg_train_acc_maze:.4f}")
508
+
509
+ # TEST METRICS
510
+ total_test_loss = torch.tensor(0.0, device=device)
511
+ total_test_correct_certain = torch.tensor(0.0, device=device)
512
+ total_test_mazes_solved = torch.tensor(0.0, device=device)
513
+ total_test_steps = torch.tensor(0.0, device=device)
514
+ total_test_mazes = torch.tensor(0.0, device=device)
515
+
516
+ pbar_inner_desc = 'Eval Test (Rank 0)' if is_main_process(rank) else None
517
+ with tqdm(total=len(testloader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
518
+ for inferi, (inputs, targets) in enumerate(testloader):
519
+ inputs = inputs.to(device, non_blocking=True)
520
+ targets = targets.to(device, non_blocking=True)
521
+ batch_size = inputs.size(0)
522
+ seq_len = targets.size(1)
523
+
524
+ loss_eval = None
525
+ pred_at_certain = None
526
+ if args.model == 'ctm':
527
+ predictions_raw, certainties, _ = model(inputs)
528
+ predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1))
529
+ loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
530
+ pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
531
+ elif args.model == 'lstm':
532
+ predictions_raw, certainties, _ = model(inputs)
533
+ predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1))
534
+ loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False)
535
+ pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
536
+ elif args.model == 'ff':
537
+ predictions_raw = model(inputs)
538
+ predictions = predictions_raw.reshape(batch_size, -1, 5)
539
+ loss_eval, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False)
540
+ pred_at_certain = predictions.argmax(2)
541
+
542
+ total_test_loss += loss_eval * batch_size
543
+ correct_steps = (pred_at_certain == targets)
544
+ total_test_correct_certain += correct_steps.sum()
545
+ total_test_mazes_solved += correct_steps.all(dim=-1).sum()
546
+ total_test_steps += batch_size * seq_len
547
+ total_test_mazes += batch_size
548
+
549
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
550
+ pbar_inner.update(1)
551
+
552
+ # Aggregate Test Metrics
553
+ if world_size > 1:
554
+ dist.all_reduce(total_test_loss, op=dist.ReduceOp.SUM)
555
+ dist.all_reduce(total_test_correct_certain, op=dist.ReduceOp.SUM)
556
+ dist.all_reduce(total_test_mazes_solved, op=dist.ReduceOp.SUM)
557
+ dist.all_reduce(total_test_steps, op=dist.ReduceOp.SUM)
558
+ dist.all_reduce(total_test_mazes, op=dist.ReduceOp.SUM)
559
+
560
+ # Calculate and Plot final Test metrics on Rank 0
561
+ if is_main_process(rank) and total_test_mazes > 0:
562
+ avg_test_loss = total_test_loss.item() / total_test_mazes.item()
563
+ avg_test_acc_step = total_test_correct_certain.item() / total_test_steps.item()
564
+ avg_test_acc_maze = total_test_mazes_solved.item() / total_test_mazes.item()
565
+ test_losses.append(avg_test_loss)
566
+ test_accuracies_most_certain.append(avg_test_acc_step)
567
+ test_accuracies_most_certain_permaze.append(avg_test_acc_maze)
568
+ print(f"Iter {bi} Test Metrics (Agg): Loss={avg_test_loss:.4f}, StepAcc={avg_test_acc_step:.4f}, MazeAcc={avg_test_acc_maze:.4f}\n")
569
+
570
+ # --- Plotting ---
571
+ figacc = plt.figure(figsize=(10, 10))
572
+ axacc_train = figacc.add_subplot(211)
573
+ axacc_test = figacc.add_subplot(212)
574
+
575
+ # Plot Avg Step Accuracy
576
+ axacc_train.plot(iters, train_accuracies_most_certain, 'k-', alpha=0.7, label=f'Avg Step Acc ({train_accuracies_most_certain[-1]:.3f})')
577
+ axacc_test.plot(iters, test_accuracies_most_certain, 'k-', alpha=0.7, label=f'Avg Step Acc ({test_accuracies_most_certain[-1]:.3f})')
578
+ # Plot Full Maze Accuracy
579
+ axacc_train.plot(iters, train_accuracies_most_certain_permaze, 'r-', alpha=0.6, label=f'Full Maze Acc ({train_accuracies_most_certain_permaze[-1]:.3f})')
580
+ axacc_test.plot(iters, test_accuracies_most_certain_permaze, 'r-', alpha=0.6, label=f'Full Maze Acc ({test_accuracies_most_certain_permaze[-1]:.3f})')
581
+
582
+ axacc_train.set_title('Train Accuracy (Aggregated)')
583
+ axacc_test.set_title('Test Accuracy (Aggregated)')
584
+ axacc_train.legend(loc='lower right')
585
+ axacc_test.legend(loc='lower right')
586
+ axacc_train.set_xlim([0, args.training_iterations])
587
+ axacc_test.set_xlim([0, args.training_iterations])
588
+ axacc_train.set_ylim([0, 1])
589
+ axacc_test.set_ylim([0, 1])
590
+
591
+ figacc.tight_layout()
592
+ figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
593
+ plt.close(figacc)
594
+
595
+ # Loss Plot
596
+ figloss = plt.figure(figsize=(10, 5))
597
+ axloss = figloss.add_subplot(111)
598
+ axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train (Agg): {train_losses[-1]:.4f}')
599
+ axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test (Agg): {test_losses[-1]:.4f}')
600
+ axloss.legend(loc='upper right')
601
+ axloss.set_xlabel("Iteration")
602
+ axloss.set_ylabel("Loss")
603
+ axloss.set_xlim([0, args.training_iterations])
604
+ axloss.set_ylim(bottom=0)
605
+ figloss.tight_layout()
606
+ figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
607
+ plt.close(figloss)
608
+ # --- End Plotting ---
609
+
610
+
611
+ # --- Visualization (Rank 0, Conditional) ---
612
+ if is_main_process(rank) and args.model in ['ctm', 'lstm']:
613
+ # try:
614
+ model_module = model.module if isinstance(model, DDP) else model
615
+ # Use a consistent batch for viz if possible, or just next batch
616
+ inputs_viz, targets_viz = next(iter(testloader))
617
+ inputs_viz = inputs_viz.to(device)
618
+ targets_viz = targets_viz.to(device)
619
+ longest_index = (targets_viz!=4).sum(-1).argmax() # 4 assumed padding
620
+
621
+ pbar.set_description('Tracking (Rank 0): Viz Fwd Pass')
622
+ predictions_viz_raw, _, _, _, post_activations_viz, attention_tracking_viz = model_module(inputs_viz, track=True)
623
+ predictions_viz = predictions_viz_raw.reshape(predictions_viz_raw.size(0), -1, 5, predictions_viz_raw.size(-1))
624
+
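+ # Reshape the flattened attention weights back onto the spatial grid of the backbone's key/value feature map for visualization.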
625
+ att_shape = (model_module.kv_features.shape[2], model_module.kv_features.shape[3])
626
+ attention_tracking_viz = attention_tracking_viz.reshape(
627
+ attention_tracking_viz.shape[0],
628
+ attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
629
+
630
+ pbar.set_description('Tracking (Rank 0): Dynamics Plot')
631
+ plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
632
+
633
+ pbar.set_description('Tracking (Rank 0): Maze GIF')
634
+ if attention_tracking_viz is not None:
635
+ make_maze_gif((inputs_viz[longest_index].detach().cpu().numpy()+1)/2,
636
+ predictions_viz[longest_index].detach().cpu().numpy(),
637
+ targets_viz[longest_index].detach().cpu().numpy(),
638
+ attention_tracking_viz[:, longest_index],
639
+ args.log_dir)
640
+ # else:
641
+ # print("Skipping maze GIF due to attention shape issue.")
642
+
643
+ # except Exception as e_viz:
644
+ # print(f"Rank 0 visualization failed: {e_viz}")
645
+ # --- End Visualization ---
646
+
647
+ gc.collect()
648
+ if torch.cuda.is_available():
649
+ torch.cuda.empty_cache()
650
+ if world_size > 1: dist.barrier()
651
+ model.train()
652
+ # --- End Evaluation Block ---
653
+
654
+
655
+
656
+
657
+ if hasattr(train_sampler, 'set_epoch'): # Check if sampler has set_epoch
658
+ train_sampler.set_epoch(bi)
659
+
660
+ current_lr = optimizer.param_groups[-1]['lr']
661
+
662
+ try:
663
+ inputs, targets = next(iterator)
664
+ except StopIteration:
665
+ iterator = iter(trainloader)
666
+ inputs, targets = next(iterator)
667
+
668
+ inputs = inputs.to(device, non_blocking=True)
669
+ targets = targets.to(device, non_blocking=True)
670
+
671
+ # Defaults for logging
672
+ loss = torch.tensor(0.0, device=device) # Need loss defined for logging scope
673
+ accuracy_finegrained = 0.0
674
+ where_most_certain_val = -1.0
675
+ where_most_certain_std = 0.0
676
+ where_most_certain_min = -1
677
+ where_most_certain_max = -1
678
+ upto_where_mean = -1.0
679
+ upto_where_std = 0.0
680
+ upto_where_min = -1
681
+ upto_where_max = -1
682
+
683
+ with torch.autocast(device_type="cuda" if device.type == 'cuda' else "cpu", dtype=torch.float16, enabled=args.use_amp):
684
+ if args.do_compile: torch.compiler.cudagraph_mark_step_begin()
685
+
686
+ if args.model == 'ctm':
687
+ predictions_raw, certainties, _ = model(inputs)
688
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
689
+ loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=True)
690
+ with torch.no_grad(): # Calculate local accuracy for logging
691
+ accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=device), :, where_most_certain] == targets).float().mean().item()
692
+ elif args.model == 'lstm':
693
+ predictions_raw, certainties, _ = model(inputs)
694
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
695
+ loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False) # where = -1
696
+ with torch.no_grad():
697
+ accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=device), :, where_most_certain] == targets).float().mean().item()
698
+ elif args.model == 'ff':
699
+ predictions_raw = model(inputs) # B, S*C
700
+ predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
701
+ loss, where_most_certain, upto_where = maze_loss(predictions.unsqueeze(-1), None, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False) # where = -1
702
+ with torch.no_grad():
703
+ accuracy_finegrained = (predictions.argmax(2) == targets).float().mean().item()
704
+
705
+ # Extract stats from loss outputs
706
+ if torch.is_tensor(where_most_certain):
707
+ where_most_certain_val = where_most_certain.float().mean().item()
708
+ where_most_certain_std = where_most_certain.float().std().item()
709
+ where_most_certain_min = where_most_certain.min().item()
710
+ where_most_certain_max = where_most_certain.max().item()
711
+ elif isinstance(where_most_certain, int):
712
+ where_most_certain_val = float(where_most_certain); where_most_certain_min = where_most_certain; where_most_certain_max = where_most_certain
713
+ if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
714
+ upto_where_mean = np.mean(upto_where); upto_where_std = np.std(upto_where); upto_where_min = np.min(upto_where); upto_where_max = np.max(upto_where)
715
+
716
+ # Backprop / Step
717
+ scaler.scale(loss).backward()
718
+ if args.gradient_clipping!=-1:
719
+ scaler.unscale_(optimizer)
720
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
721
+ scaler.step(optimizer)
722
+ scaler.update()
723
+ optimizer.zero_grad(set_to_none=True)
724
+ scheduler.step()
725
+
726
+ # --- Aggregation and Logging (Rank 0) ---
727
+ loss_log = loss.detach()
728
+ if world_size > 1: dist.all_reduce(loss_log, op=dist.ReduceOp.AVG)
729
+
730
+ if is_main_process(rank):
731
+ pbar_desc = f'Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_finegrained:.3f} LR={current_lr:.6f}'
732
+ if args.model in ['ctm', 'lstm'] or torch.is_tensor(where_most_certain):
733
+ pbar_desc += f' Cert={where_most_certain_val:.2f}'#+-{where_most_certain_std:.2f}' # Removed std for brevity
734
+ if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
735
+ pbar_desc += f' Path={upto_where_mean:.1f}'#+-{upto_where_std:.1f}'
736
+ pbar.set_description(f'{args.model.upper()} {pbar_desc}')
737
+ # --- End Aggregation and Logging ---
738
+
739
+
740
+
741
+
742
+
743
+ # --- Checkpointing (Rank 0) ---
744
+ if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter and is_main_process(rank):
745
+ pbar.set_description('Rank 0: Saving checkpoint...')
746
+ save_path = f'{args.log_dir}/checkpoint.pt'
747
+ model_state_to_save = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
748
+
749
+ checkpoint_data = {
750
+ 'model_state_dict': model_state_to_save,
751
+ 'optimizer_state_dict': optimizer.state_dict(),
752
+ 'scheduler_state_dict': scheduler.state_dict(),
753
+ 'scaler_state_dict': scaler.state_dict(),
754
+ 'iteration': bi,
755
+ 'train_losses': train_losses,
756
+ 'test_losses': test_losses,
757
+ 'train_accuracies': train_accuracies, # Saving simplified scalar list
758
+ 'test_accuracies': test_accuracies, # Saving simplified scalar list
759
+ 'train_accuracies_most_certain': train_accuracies_most_certain,
760
+ 'test_accuracies_most_certain': test_accuracies_most_certain,
761
+ 'train_accuracies_most_certain_permaze': train_accuracies_most_certain_permaze,
762
+ 'test_accuracies_most_certain_permaze': test_accuracies_most_certain_permaze,
763
+ 'iters': iters,
764
+ 'args': args,
765
+ 'torch_rng_state': torch.get_rng_state(),
766
+ 'numpy_rng_state': np.random.get_state(),
767
+ 'random_rng_state': random.getstate(),
768
+ }
769
+ torch.save(checkpoint_data, save_path)
770
+ # --- End Checkpointing ---
771
+
772
+
773
+ if world_size > 1: dist.barrier()
774
+
775
+ if is_main_process(rank):
776
+ pbar.update(1)
777
+ # --- End Training Loop ---
778
+
779
+ if is_main_process(rank):
780
+ pbar.close()
781
+
782
+ cleanup_ddp()
tasks/parity/README.md ADDED
@@ -0,0 +1,16 @@
1
+ # Parity
2
+
3
+ ## Training
4
+ To reproduce the parity training runs used in the paper, run the bash scripts from the root of the repository. For example, to train the 75-iteration, 25-memory-length CTM, run:
5
+
6
+ ```
7
+ bash tasks/parity/scripts/train_ctm_75_25.sh
8
+ ```
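+ 
+ The other scripts in `tasks/parity/scripts/` appear to follow the same `train_ctm_<iterations>_<memory_length>.sh` naming convention; for example, a 50-iteration, 25-memory-length run would (assuming that variant is present) be launched with:
+ 
+ ```
+ bash tasks/parity/scripts/train_ctm_50_25.sh
+ ```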
9
+
10
+
11
+ ## Analysis
12
+ To run the analysis, first make sure that checkpoints are saved in the log directory (specified by the `--log_dir` argument). Checkpoints can be obtained either by running the training code or by downloading them from [this link](https://drive.google.com/file/d/1itUS5_i9AyUo_7awllTx8X0PXYw9fnaG/view?usp=drive_link).
13
+
14
+ ```
15
+ python -m tasks.parity.analysis.run --log_dir <PATH_TO_LOG_DIR>
16
+ ```
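+ 
+ The analysis entry point loads checkpoints for you. To load a trained parity CTM yourself (e.g. in a notebook), a minimal sketch using the helper defined in `tasks/parity/analysis/run.py` is shown below; the checkpoint filename is illustrative.
+ 
+ ```
+ import torch
+ from tasks.parity.analysis.run import build_model_from_checkpoint_path
+ 
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ # Point this at a checkpoint file inside your log directory.
+ model, model_args = build_model_from_checkpoint_path(
+     "<PATH_TO_LOG_DIR>/checkpoint_200000.pt", "ctm", device=device)
+ model.eval()
+ ```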
tasks/parity/analysis/make_blog_gifs.py ADDED
@@ -0,0 +1,263 @@
1
+
2
+ import torch
3
+ import os
4
+ import math
5
+ import imageio
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from matplotlib.patches import FancyArrowPatch
9
+ from scipy.special import softmax
10
+ import matplotlib.cm as cm
11
+ from data.custom_datasets import ParityDataset
12
+ import umap
13
+ from tqdm import tqdm
14
+
15
+
16
+ from models.utils import reshape_predictions
17
+ from tasks.parity.utils import reshape_inputs
18
+ from tasks.parity.analysis.run import build_model_from_checkpoint_path
19
+
20
+ from tasks.image_classification.analysis.build_imagenet_viz_blog import save_frames_to_mp4
21
+
22
+
23
+ def make_parity_gif(
24
+ predictions,
25
+ targets,
26
+ post_activations,
27
+ attention_weights,
28
+ inputs_to_model,
29
+ save_path,
30
+ umap_positions,
31
+ umap_point_scaler=1.0,
32
+ ):
33
+ batch_index = 0
34
+ figscale = 0.32
35
+ n_steps, n_heads, seqLen = attention_weights.shape[:3]
36
+ grid_side = int(np.sqrt(seqLen))
37
+ frames = []
38
+
39
+ inputs_this_batch = inputs_to_model[:, batch_index]
40
+ preds_this_batch = predictions[batch_index]
41
+ targets_this_batch = targets[batch_index]
42
+ post_act_this_batch = post_activations[:, batch_index]
43
+
44
+ # build a flexible mosaic
45
+ mosaic = [
46
+ [f"att_0", f"in_0", "probs", "probs", "target", "target"],
47
+ [f"att_1", f"in_1", "probs", "probs", "target", "target"],
48
+ ]
49
+ for h in range(2, n_heads):
50
+ mosaic.append(
51
+ [f"att_{h}", f"in_{h}", "umap", "umap",
52
+ "umap", "umap"]
53
+ )
54
+
55
+ for t in range(n_steps):
56
+ rows = len(mosaic)
57
+ cell_size = figscale * 4
58
+ fig_h = rows * cell_size
59
+
60
+ fig, ax = plt.subplot_mosaic(
61
+ mosaic,
62
+ figsize=(6 * cell_size, fig_h),
63
+ constrained_layout=False,
64
+ gridspec_kw={'wspace': 0.05, 'hspace': 0.05}, # small gaps
65
+ )
66
+ # restore a little margin
67
+ fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)
68
+
69
+ # probabilities heatmap
70
+ logits_t = preds_this_batch[:, :, t]
71
+ probs_t = softmax(logits_t, axis=1)[:, 0].reshape(grid_side, grid_side)
72
+ ax["probs"].imshow(probs_t, cmap="gray", vmin=0, vmax=1)
73
+ ax["probs"].axis("off")
74
+
75
+ # target overlay
76
+ ax["target"].imshow(
77
+ targets_this_batch.reshape(grid_side, grid_side),
78
+ cmap="gray_r", vmin=0, vmax=1
79
+ )
80
+ ax["target"].axis("off")
81
+ ax["target"].grid(which="minor", color="black", linestyle="-", linewidth=0.5)
82
+
83
+ z = post_act_this_batch[t]
84
+ low, high = np.percentile(z, 5), np.percentile(z, 95)
85
+ z_norm = np.clip((z - low) / (high - low), 0, 1)
86
+ point_sizes = (np.abs(z_norm - 0.5) * 100 + 5) * umap_point_scaler
87
+ cmap = plt.get_cmap("Spectral")
88
+ ax["umap"].scatter(
89
+ umap_positions[:, 0],
90
+ umap_positions[:, 1],
91
+ s=point_sizes,
92
+ c=cmap(z_norm),
93
+ alpha=0.8
94
+ )
95
+ ax["umap"].axis("off")
96
+
97
+
98
+ # normalize attention
99
+ att_t = attention_weights[t, :, :]
100
+ a_min, a_max = att_t.min(), att_t.max()
101
+ if not np.isclose(a_min, a_max):
102
+ att_t = (att_t - a_min) / (a_max - a_min + 1e-8)
103
+ else:
104
+ att_t = np.zeros_like(att_t)
105
+
106
+ # input image for arrows
107
+ img_t = inputs_this_batch[t].transpose(1, 2, 0)
108
+
109
+ if t == 0:
110
+ route_history = [[] for _ in range(n_heads)]
111
+
112
+ img_h, img_w = img_t.shape[:2]
113
+ cell_h = img_h // grid_side
114
+ cell_w = img_w // grid_side
115
+
116
+ for h in range(n_heads):
117
+ head_map = att_t[h].reshape(grid_side, grid_side)
118
+ ax[f"att_{h}"].imshow(head_map, cmap="viridis", vmin=0, vmax=1)
119
+ ax[f"att_{h}"].axis("off")
120
+ ax[f"in_{h}"].imshow(img_t, cmap="gray", vmin=0, vmax=1)
121
+ ax[f"in_{h}"].axis("off")
122
+
123
+ # track argmax center
124
+ flat_idx = np.argmax(head_map)
125
+ gy, gx = divmod(flat_idx, grid_side)
126
+ cx = int((gx + 0.5) * cell_w)
127
+ cy = int((gy + 0.5) * cell_h)
128
+ route_history[h].append((cx, cy))
129
+
130
+ cmap_steps = plt.colormaps.get_cmap("Spectral")
131
+ colors = [cmap_steps(i / (n_steps - 1)) for i in range(n_steps)]
132
+ for i in range(len(route_history[h]) - 1):
133
+ x0, y0 = route_history[h][i]
134
+ x1, y1 = route_history[h][i + 1]
135
+ color = colors[i]
136
+ is_last = (i == len(route_history[h]) - 2)
137
+ style = '->' if is_last else '-'
138
+ lw = 2.0 if is_last else 1.6
139
+ alpha = 1.0 if is_last else 0.9
140
+ scale = 10 if is_last else 1
141
+
142
+ # draw arrow
143
+ arr = FancyArrowPatch(
144
+ (x0, y0), (x1, y1),
145
+ arrowstyle=style,
146
+ linewidth=lw,
147
+ mutation_scale=scale,
148
+ alpha=alpha,
149
+ facecolor=color,
150
+ edgecolor=color,
151
+ shrinkA=0, shrinkB=0,
152
+ capstyle='round', joinstyle='round',
153
+ zorder=3 if is_last else 2,
154
+ clip_on=False,
155
+ )
156
+ ax[f"in_{h}"].add_patch(arr)
157
+
158
+ ax[f"in_{h}"].scatter(
159
+ x1, y1,
160
+ marker='x',
161
+ s=40,
162
+ color=color,
163
+ linewidths=lw,
164
+ zorder=4
165
+ )
166
+
167
+ canvas = fig.canvas
168
+ canvas.draw()
169
+ frame = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
170
+ w, h = canvas.get_width_height()
171
+ frames.append(frame.reshape(h, w, 4)[..., :3])
172
+ plt.close(fig)
173
+
174
+ # save gif
175
+ imageio.mimsave(f"{save_path}/activation.gif", frames, fps=15, loop=0)
176
+
177
+ # save mp4
178
+ save_frames_to_mp4(
179
+ [fm[:, :, ::-1] for fm in frames], # RGB→BGR
180
+ f"{save_path}/activation.mp4",
181
+ fps=15,
182
+ gop_size=1,
183
+ preset="slow"
184
+ )
185
+
186
+ def run_umap(model, testloader):
187
+ all_post_activations = []
188
+ point_counts = 150
189
+ sampled = 0
190
+ with tqdm(total=point_counts, desc="Collecting UMAP data") as pbar:
191
+ for inputs, _ in testloader:
192
+ for i in range(inputs.size(0)):
193
+ if sampled >= point_counts:
194
+ break
195
+ input_i = inputs[i].unsqueeze(0).to(device)
196
+ _, _, _, _, post_activations, _ = model(input_i, track=True)
197
+ all_post_activations.append(post_activations)
198
+ sampled += 1
199
+ pbar.update(1)
200
+ if sampled >= point_counts:
201
+ break
202
+
203
+ stacked = np.stack(all_post_activations, 1)
204
+ umap_features = stacked.reshape(-1, stacked.shape[-1])
205
+ reducer = umap.UMAP(
206
+ n_components=2,
207
+ n_neighbors=20,
208
+ min_dist=1,
209
+ spread=1,
210
+ metric='cosine',
211
+ local_connectivity=1
212
+ )
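+ # Note: the transpose below means each UMAP point is one neuron, embedded by its activation trace over ticks and sampled inputs.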
213
+ positions = reducer.fit_transform(umap_features.T)
214
+ return positions
215
+
216
+
217
+ def run_model_and_make_gif(checkpoint_path, save_path, device):
218
+
219
+ parity_sequence_length = 64
220
+ iterations = 75
221
+
222
+ test_data = ParityDataset(sequence_length=parity_sequence_length, length=10000)
223
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=0, drop_last=False)
224
+
225
+
226
+ model, _ = build_model_from_checkpoint_path(checkpoint_path, "ctm", device=device)
227
+
228
+ input = torch.randint(0, 2, (64,), dtype=torch.float32, device=device) * 2 - 1
229
+ input = input.unsqueeze(0)
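+ # The target below is the cumulative parity: the count of -1 entries seen so far, modulo 2.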
230
+
231
+ target = torch.cumsum((input == -1).to(torch.long), dim=1) % 2
232
+ target = target.unsqueeze(0)
233
+
234
+ positions = run_umap(model, testloader)
235
+
236
+ model.eval()
237
+ with torch.inference_mode():
238
+ predictions, _, _, _, post_activations, attention = model(input, track=True)
239
+ predictions = reshape_predictions(predictions, prediction_reshaper=[parity_sequence_length, 2])
240
+ input_images = reshape_inputs(input, iterations, grid_size=int(math.sqrt(parity_sequence_length)))
241
+
242
+ make_parity_gif(
243
+ predictions=predictions.detach().cpu().numpy(),
244
+ targets=target.detach().cpu().numpy(),
245
+ post_activations=post_activations,
246
+ attention_weights=attention.squeeze(1).squeeze(2),
247
+ inputs_to_model=input_images,
248
+ save_path=save_path,
249
+ umap_positions=positions,
250
+ umap_point_scaler=1.0,
251
+ )
252
+
253
+
254
+
255
+ if __name__ == "__main__":
256
+
257
+ CHECKPOINT_PATH = "checkpoints/parity/run1/ctm_75_25/checkpoint_200000.pt"
258
+ SAVE_PATH = f"tasks/parity/analysis/outputs/blog_gifs/"
259
+ os.makedirs(SAVE_PATH, exist_ok=True)
260
+
261
+ device = "cuda" if torch.cuda.is_available() else "cpu"
262
+
263
+ run_model_and_make_gif(CHECKPOINT_PATH, SAVE_PATH, device)
tasks/parity/analysis/run.py ADDED
@@ -0,0 +1,269 @@
1
+ import torch
2
+ import numpy as np
3
+ import argparse
4
+ import multiprocessing
5
+ from tqdm import tqdm
6
+ import math
7
+ import os
8
+ import csv
9
+ from utils.housekeeping import set_seed
10
+ from data.custom_datasets import ParityDataset
11
+ from tasks.parity.utils import prepare_model, reshape_attention_weights, reshape_inputs, get_where_most_certain
12
+ from tasks.parity.plotting import plot_attention_trajectory, plot_input, plot_target, plot_probabilities, plot_prediction, plot_accuracy_training, create_attentions_heatmap_gif, create_accuracies_heatmap_gif, create_stacked_gif, plot_training_curve_all_runs, plot_accuracy_thinking_time, make_parity_gif, plot_lstm_last_and_certain_accuracy
13
+ from models.utils import compute_normalized_entropy, reshape_predictions, get_latest_checkpoint_file, get_checkpoint_files, load_checkpoint, get_model_args_from_checkpoint, get_all_log_dirs
14
+ from tasks.image_classification.plotting import plot_neural_dynamics
15
+
16
+ import seaborn as sns
17
+ sns.set_palette("hls")
18
+ sns.set_style('darkgrid')
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser(description='Parity Analysis')
22
+ parser.add_argument('--log_dir', type=str, default='checkpoints/parity', help='Directory to save logs.')
23
+ parser.add_argument('--batch_size_test', type=int, default=128, help='batch size for testing')
24
+ parser.add_argument('--scale_training_curve', type=float, default=0.6, help='Scaling factor for plots.')
25
+ parser.add_argument('--scale_heatmap', type=float, default=0.4, help='Scaling factor for heatmap plots.')
26
+ parser.add_argument('--scale_training_index_accuracy', type=float, default=0.4, help='Scaling factor for training index accuracy plots.')
27
+ parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility.')
28
+ parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
29
+ parser.add_argument('--model_type', type=str, choices=['ctm', 'lstm'], default='ctm', help='Type of model to analyze (ctm or lstm).')
30
+ return parser.parse_args()
31
+
32
+ def calculate_corrects(predictions, targets):
33
+ predicted_labels = predictions.argmax(2)
34
+ accuracy = (predicted_labels == targets.unsqueeze(-1))
35
+ return accuracy.detach().cpu().numpy()
36
+
37
+ def get_corrects_per_element_at_most_certain_time(predictions, certainty, targets):
38
+ where_most_certain = get_where_most_certain(certainty)
39
+ corrects = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device),:,where_most_certain] == targets).float()
40
+ return corrects.detach().cpu().numpy()
41
+
42
+ def calculate_entropy_average_over_batch(normalized_entropy_per_elements):
43
+ normalized_entropy_per_elements_avg_batch = normalized_entropy_per_elements.mean(axis=1)
44
+ return normalized_entropy_per_elements_avg_batch
45
+
46
+ def calculate_thinking_time_average_over_batch(normalized_entropy_per_elements):
47
+ first_occurrence = calculate_thinking_time(normalized_entropy_per_elements)
48
+ average_thinking_time = np.mean(first_occurrence, axis=0)
49
+ return average_thinking_time
50
+
51
+ def calculate_thinking_time(normalized_entropy_per_elements, finish_type="min", entropy_threshold=0.1):
52
+ if finish_type == "min":
53
+ min_entropy_time = np.argmin(normalized_entropy_per_elements, axis=0)
54
+ return min_entropy_time
55
+ elif finish_type == "threshold":
56
+ T, B, S = normalized_entropy_per_elements.shape
57
+ below_threshold = normalized_entropy_per_elements < entropy_threshold
58
+ first_occurrence = np.argmax(below_threshold, axis=0)
59
+ no_true = ~np.any(below_threshold, axis=0)
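+ # Elements whose entropy never drops below the threshold are assigned the maximum thinking time T.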
60
+ first_occurrence[no_true] = T
61
+ return first_occurrence
62
+
63
+ def test_handcrafted_examples(model, args, run_model_spefic_save_dir, device):
64
+ test_cases = []
65
+ all_even_input = torch.full((args.parity_sequence_length,), 1.0, dtype=torch.float32, device=device)
66
+ all_even_target = torch.zeros_like(all_even_input, dtype=torch.long)
67
+ test_cases.append((all_even_input, all_even_target))
68
+
69
+ all_odd_input = torch.full((args.parity_sequence_length,), -1.0, dtype=torch.float32, device=device)
70
+ all_odd_target = torch.cumsum((all_odd_input == -1).to(torch.long), dim=0) % 2
71
+ test_cases.append((all_odd_input, all_odd_target))
72
+
73
+ random_input = torch.randint(0, 2, (args.parity_sequence_length,), dtype=torch.float32, device=device) * 2 - 1
74
+ random_target = torch.cumsum((random_input == -1).to(torch.long), dim=0) % 2
75
+ test_cases.append((random_input, random_target))
76
+
77
+ for i, (inputs, targets) in enumerate(test_cases):
78
+ inputs = inputs.unsqueeze(0)
79
+ targets = targets.unsqueeze(0)
80
+ filename = f"eval_handcrafted_{i}"
81
+ extend_inference_time = False
82
+ handcraft_dir = f"{run_model_spefic_save_dir}/handcrafted_examples/{i}"
83
+ os.makedirs(handcraft_dir, exist_ok=True)
84
+
85
+ model.eval()
86
+ with torch.inference_mode():
87
+ if extend_inference_time:
88
+ model.iterations = model.iterations * 2
89
+ predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
90
+ predictions = reshape_predictions(predictions, prediction_reshaper=[args.parity_sequence_length, 2])
91
+ input_images = reshape_inputs(inputs, args.iterations, grid_size=int(math.sqrt(args.parity_sequence_length)))
92
+
93
+ plot_neural_dynamics(post_activations, 100, handcraft_dir, axis_snap=False)
94
+
95
+ process = multiprocessing.Process(
96
+ target=make_parity_gif,
97
+ args=(
98
+ predictions.detach().cpu().numpy(),
99
+ certainties.detach().cpu().numpy(),
100
+ targets.detach().cpu().numpy(),
101
+ pre_activations,
102
+ post_activations,
103
+ reshape_attention_weights(attention),
104
+ input_images,
105
+ f"{handcraft_dir}/eval_output_val_{0}_iter_{0}.gif",
106
+ ))
107
+ process.start()
108
+
109
+
110
+ input_images = input_images.squeeze(1).squeeze(1)
111
+ attention = attention.squeeze(1)
112
+
113
+ for h in range(args.heads):
114
+ plot_attention_trajectory(attention[:, h, :, :], certainties, input_images, handcraft_dir, filename + f"_head_{h}", args)
115
+
116
+ plot_attention_trajectory(attention.mean(1), certainties, input_images, handcraft_dir, filename, args)
117
+ plot_input(input_images, handcraft_dir, filename)
118
+ plot_target(targets, handcraft_dir, filename, args)
119
+ plot_probabilities(predictions, certainties, handcraft_dir, filename, args)
120
+ plot_prediction(predictions, certainties, handcraft_dir, filename, args)
121
+
122
+ if extend_inference_time:
123
+ model.iterations = model.iterations // 2
124
+ model.train()
125
+ pass
126
+
127
+ def build_model_from_checkpoint_path(checkpoint_path, model_type, device="cpu"):
128
+ checkpoint = load_checkpoint(checkpoint_path, device)
129
+ model_args = get_model_args_from_checkpoint(checkpoint)
130
+ model = prepare_model([model_args.parity_sequence_length, 2], model_args, device)
131
+ model.load_state_dict(checkpoint["model_state_dict"], strict=False)
132
+ return model, model_args
133
+
134
+ def analyze_trained_model(run_model_spefic_save_dir, args, device):
135
+ with torch.no_grad():
136
+
137
+ latest_checkpoint_path = get_latest_checkpoint_file(args.log_dir)
138
+ model, model_args = build_model_from_checkpoint_path(latest_checkpoint_path, args.model_type, device=device)
139
+ model.eval()
140
+ model_args.log_dir = args.log_dir
141
+ test_data = ParityDataset(sequence_length=model_args.parity_sequence_length, length=10000)
142
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=0, drop_last=False)
143
+
144
+ corrects, corrects_at_most_certain_times, entropys, attentions = [], [], [], []
145
+
146
+ for inputs, targets in testloader:
147
+ inputs = inputs.to(device)
148
+ targets = targets.to(device)
149
+ predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
150
+ predictions = reshape_predictions(predictions, prediction_reshaper=[model_args.parity_sequence_length, 2])
151
+ corrects_batch = calculate_corrects(predictions, targets)
152
+ corrects_at_most_certain_time_batch = get_corrects_per_element_at_most_certain_time(predictions, certainties, targets)
153
+ corrects.append(corrects_batch)
154
+ corrects_at_most_certain_times.append(corrects_at_most_certain_time_batch)
155
+ attentions.append(attention)
156
+
157
+ test_handcrafted_examples(model, model_args, run_model_spefic_save_dir, device)
158
+
159
+ overall_mean_accuracy = np.mean(np.vstack(corrects_at_most_certain_times))
160
+ overall_std_accuracy = np.std(np.mean(np.vstack(corrects_at_most_certain_times), axis=1))
161
+
162
+ return overall_mean_accuracy, overall_std_accuracy, model_args.iterations
163
+
164
+ def analyze_training(run_model_spefic_save_dir, args, device):
165
+ checkpoint_files = get_checkpoint_files(args.log_dir)
166
+ all_accuracies = []
167
+ all_accuracies_at_most_certain_time = []
168
+ all_average_thinking_times = []
169
+ all_std_thinking_times = []
170
+ all_attentions = []
171
+ for checkpoint_path in checkpoint_files:
172
+ model, model_args = build_model_from_checkpoint_path(checkpoint_path, args.model_type, device=device)
173
+ model_args.log_dir = run_model_spefic_save_dir
174
+ test_data = ParityDataset(sequence_length=model_args.parity_sequence_length, length=1000)
175
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=0, drop_last=False)
176
+ corrects = []
177
+ corrects_at_most_certain_times = []
178
+ thinking_times = []
179
+ attentions = []
180
+
181
+ for inputs, targets in testloader:
182
+ inputs = inputs.to(device)
183
+ targets = targets.to(device)
184
+ predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
185
+ predictions = reshape_predictions(predictions, prediction_reshaper=[model_args.parity_sequence_length, 2])
186
+ attention = reshape_attention_weights(attention)
187
+
188
+ corrects_batch = calculate_corrects(predictions, targets)
189
+ corrects_at_most_certain_time_batch = get_corrects_per_element_at_most_certain_time(predictions, certainties, targets)
190
+ entropy_per_element = compute_normalized_entropy(predictions.permute(0,3,1,2), reduction='none').detach().cpu().numpy()
191
+ thinking_times_batch = np.argmin(entropy_per_element, axis=1)
192
+
193
+ corrects.append(corrects_batch)
194
+ corrects_at_most_certain_times.append(corrects_at_most_certain_time_batch)
195
+ thinking_times.append(thinking_times_batch)
196
+ attentions.append(attention)
197
+
198
+ checkpoint_average_accuracies = np.mean(np.concatenate(corrects, axis=0), axis=0).transpose(1,0)
199
+ all_accuracies.append(checkpoint_average_accuracies)
200
+
201
+ stacked_corrects_at_most_certain_times = np.vstack(corrects_at_most_certain_times)
202
+ checkpoint_average_accuracy_at_most_certain_time = np.mean(stacked_corrects_at_most_certain_times, axis=0)
203
+ all_accuracies_at_most_certain_time.append(checkpoint_average_accuracy_at_most_certain_time)
204
+
205
+ checkpoint_thinking_times = np.concatenate(thinking_times, axis=0)
206
+ checkpoint_average_thinking_time = np.mean(checkpoint_thinking_times, axis=0)
207
+ checkpoint_std_thinking_time = np.std(checkpoint_thinking_times, axis=0)
208
+ all_average_thinking_times.append(checkpoint_average_thinking_time)
209
+ all_std_thinking_times.append(checkpoint_std_thinking_time)
210
+
211
+ checkpoint_average_attentions = np.mean(np.concatenate(attentions, axis=1), axis=1)
212
+ all_attentions.append(checkpoint_average_attentions)
213
+
214
+ plot_accuracy_training(all_accuracies_at_most_certain_time, args.scale_training_index_accuracy, run_model_spefic_save_dir, args=model_args)
215
+ create_attentions_heatmap_gif(all_attentions, args.scale_heatmap, run_model_spefic_save_dir, model_args)
216
+ create_accuracies_heatmap_gif(np.array(all_accuracies), all_average_thinking_times, all_std_thinking_times, args.scale_heatmap, run_model_spefic_save_dir, model_args)
217
+ create_stacked_gif(run_model_spefic_save_dir)
218
+
219
+ def get_accuracy_and_loss_from_checkpoint(checkpoint):
220
+ training_iteration = checkpoint.get('training_iteration', 0)
221
+ train_losses = checkpoint.get('train_losses', [])
222
+ test_losses = checkpoint.get('test_losses', [])
223
+ train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
224
+ test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
225
+ return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
226
+
227
+ if __name__ == "__main__":
228
+
229
+ args = parse_args()
230
+
231
+ device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
232
+
233
+ set_seed(args.seed)
234
+
235
+ save_dir = "tasks/parity/analysis/outputs"
236
+ os.makedirs(save_dir, exist_ok=True)
237
+
238
+ accuracy_csv_file_path = os.path.join(save_dir, "accuracy.csv")
239
+ if os.path.exists(accuracy_csv_file_path):
240
+ os.remove(accuracy_csv_file_path)
241
+
242
+ all_runs_log_dirs = get_all_log_dirs(args.log_dir)
243
+
244
+ plot_training_curve_all_runs(all_runs_log_dirs, save_dir, args.scale_training_curve, device, x_max=200_000)
245
+ plot_lstm_last_and_certain_accuracy(all_folders=all_runs_log_dirs, save_path=f"{save_dir}/lstm_final_vs_certain_accuracy.png", scale=args.scale_training_curve)
246
+
247
+ progress_bar = tqdm(all_runs_log_dirs, desc="Analyzing Runs", dynamic_ncols=True)
248
+ for folder in progress_bar:
249
+
250
+ run, model_name = folder.strip("/").split("/")[-2:]
251
+
252
+ run_model_spefic_save_dir = f"{save_dir}/{model_name}/{run}"
253
+ os.makedirs(run_model_spefic_save_dir, exist_ok=True)
254
+
255
+ args.log_dir = folder
256
+ progress_bar.set_description(f"Analyzing Trained Model at {folder}")
257
+
258
+ accuracy_mean, accuracy_std, num_iterations = analyze_trained_model(run_model_spefic_save_dir, args, device)
259
+
260
+ with open(accuracy_csv_file_path, mode='a', newline='') as file:
261
+ writer = csv.writer(file)
262
+ if file.tell() == 0:
263
+ writer.writerow(["Run", "Overall Mean Accuracy", "Overall Std Accuracy", "Num Iterations"])
264
+ writer.writerow([folder, accuracy_mean, accuracy_std, num_iterations])
265
+
266
+ progress_bar.set_description(f"Analyzing Training at {folder}")
267
+ analyze_training(run_model_spefic_save_dir, args, device)
268
+
269
+ plot_accuracy_thinking_time(accuracy_csv_file_path, scale=args.scale_training_curve, output_dir=save_dir)
tasks/parity/plotting.py ADDED
@@ -0,0 +1,896 @@
1
+ import os
2
+ import seaborn as sns
3
+ import numpy as np
4
+ import pandas as pd
5
+ from collections import defaultdict
6
+ from matplotlib.lines import Line2D
7
+ import matplotlib as mpl
8
+ import matplotlib.pyplot as plt
9
+ import matplotlib.patheffects as path_effects
10
+ from matplotlib.ticker import FuncFormatter
11
+ from scipy.special import softmax
12
+ import imageio.v2 as imageio
13
+ from PIL import Image
14
+ import math
15
+ import re
16
+ sns.set_style('darkgrid')
17
+ mpl.use('Agg')
18
+
19
+ from tasks.parity.utils import get_where_most_certain, parse_folder_name
20
+ from models.utils import get_latest_checkpoint_file, load_checkpoint, get_model_args_from_checkpoint, get_accuracy_and_loss_from_checkpoint
21
+ from tasks.image_classification.plotting import save_frames_to_mp4
22
+
23
+ def make_parity_gif(predictions, certainties, targets, pre_activations, post_activations, attention_weights, inputs_to_model, filename):
24
+
25
+ # Config
26
+ batch_index = 0
27
+ n_neurons_to_visualise = 16
28
+ figscale = 0.28
29
+ n_steps = len(pre_activations)
30
+ frames = []
31
+ heatmap_cmap = sns.color_palette("viridis", as_cmap=True)
32
+
33
+ these_pre_acts = pre_activations[:, batch_index, :] # Shape: (T, H)
34
+ these_post_acts = post_activations[:, batch_index, :] # Shape: (T, H)
35
+ these_inputs = inputs_to_model[:, batch_index, :, :, :] # Shape: (T, C, H, W)
36
+ these_predictions = predictions[batch_index, :, :, :] # Shape: (d, C, T)
37
+ these_certainties = certainties[batch_index, :, :] # Shape: (C, T)
38
+ these_attention_weights = attention_weights[:, batch_index, :, :]
39
+
40
+ # Create mosaic layout
41
+ mosaic = [['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'target', 'target'] for _ in range(2)] + \
42
+ [['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'target', 'target'] for _ in range(2)] + \
43
+ [['certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty']] + \
44
+ [[f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}'] for ti in range(n_neurons_to_visualise)]
45
+
46
+ for stepi in range(n_steps):
47
+ fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))
48
+
49
+ # Plot predictions
50
+ d = these_predictions.shape[0]
51
+ grid_side = int(np.sqrt(d))
52
+ logits = these_predictions[:, :, stepi]
53
+
54
+ probs = softmax(logits, axis=1)
55
+ probs_grid = probs[:, 0].reshape(grid_side, grid_side)
56
+ axes_gif["probs"].imshow(probs_grid, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
57
+ axes_gif["probs"].axis('off')
58
+ axes_gif["probs"].set_title('Probabilties')
59
+
60
+ # Create and show attention heatmap
61
+ this_input_gate = these_attention_weights[stepi]
62
+ gate_min, gate_max = np.nanmin(this_input_gate), np.nanmax(this_input_gate)
63
+ if not np.isclose(gate_min, gate_max):
64
+ normalized_gate = (this_input_gate - gate_min) / (gate_max - gate_min + 1e-8)
65
+ else:
66
+ normalized_gate = np.zeros_like(this_input_gate)
67
+ attention_weights_heatmap = heatmap_cmap(normalized_gate)[:,:,:3]
68
+
69
+ # Show heatmaps
70
+ axes_gif['attention'].imshow(attention_weights_heatmap, vmin=0, vmax=1)
71
+ axes_gif['attention'].axis('off')
72
+ axes_gif['attention'].set_title('Attention')
73
+
74
+
75
+ # Plot target
76
+ target_grid = targets[batch_index].reshape(grid_side, grid_side)
77
+ axes_gif["target"].imshow(target_grid, cmap='viridis_r', interpolation='nearest', vmin=0, vmax=1)
78
+ axes_gif["target"].axis('off')
79
+ axes_gif["target"].set_title('Target')
80
+
81
+ # Add certainty plot
82
+ axes_gif['certainty'].plot(np.arange(n_steps), these_certainties[1], 'k-', linewidth=2)
83
+ axes_gif['certainty'].set_xlim([0, n_steps-1])
84
+ axes_gif['certainty'].axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
85
+ axes_gif['certainty'].set_xticklabels([])
86
+ axes_gif['certainty'].set_yticklabels([])
87
+ axes_gif['certainty'].grid(False)
88
+
89
+ # Plot neuron traces
90
+ for neuroni in range(n_neurons_to_visualise):
91
+ ax = axes_gif[f'trace_{neuroni}']
92
+
93
+ pre_activation = these_pre_acts[:, neuroni]
94
+ post_activation = these_post_acts[:, neuroni]
95
+
96
+ ax_pre = ax.twinx()
97
+
98
+ pre_min, pre_max = np.min(pre_activation), np.max(pre_activation)
99
+ post_min, post_max = np.min(post_activation), np.max(post_activation)
100
+
101
+ ax_pre.plot(np.arange(n_steps), pre_activation,
102
+ color='grey',
103
+ linestyle='--',
104
+ linewidth=1,
105
+ alpha=0.4,
106
+ label='Pre-activation')
107
+
108
+ color = 'blue' if neuroni % 2 else 'red'
109
+ ax.plot(np.arange(n_steps), post_activation,
110
+ color=color,
111
+ linestyle='-',
112
+ linewidth=2,
113
+ alpha=1.0,
114
+ label='Post-activation')
115
+
116
+ ax.set_xlim([0, n_steps-1])
117
+ ax_pre.set_xlim([0, n_steps-1])
118
+
119
+ if pre_min != pre_max:
120
+ ax_pre.set_ylim([pre_min, pre_max])
121
+ if post_min != post_max:
122
+ ax.set_ylim([post_min, post_max])
123
+
124
+ ax.axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
125
+
126
+ ax.set_xticklabels([])
127
+ ax.set_yticklabels([])
128
+ ax.grid(False)
129
+
130
+ ax_pre.set_xticklabels([])
131
+ ax_pre.set_yticklabels([])
132
+ ax_pre.grid(False)
133
+
134
+ # Show input image
135
+ this_image = these_inputs[stepi].transpose(1, 2, 0)
136
+ axes_gif['img_data'].imshow(this_image, cmap='viridis', vmin=0, vmax=1)
137
+ axes_gif['img_data'].grid(False)
138
+ axes_gif['img_data'].set_xticks([])
139
+ axes_gif['img_data'].set_yticks([])
140
+ axes_gif['img_data'].set_title('Input')
141
+
142
+ # Save frames
143
+ fig_gif.tight_layout(pad=0.1)
144
+ if stepi == 0:
145
+ fig_gif.savefig(filename.split('.gif')[0]+'_frame0.png', dpi=100)
146
+ if stepi == 1:
147
+ fig_gif.savefig(filename.split('.gif')[0]+'_frame1.png', dpi=100)
148
+ if stepi == n_steps-1:
149
+ fig_gif.savefig(filename.split('.gif')[0]+'_frame-1.png', dpi=100)
150
+
151
+ # Convert to frame
152
+ canvas = fig_gif.canvas
153
+ canvas.draw()
154
+ image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
155
+ image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3]
156
+ frames.append(image_numpy)
157
+ plt.close(fig_gif)
158
+
159
+ imageio.mimsave(filename, frames, fps=15, loop=100)
160
+
161
+ pass
162
+
163
+
164
+ def plot_attention_trajectory(attention, certainties, input_images, save_dir, filename, args):
165
+ where_most_certain = get_where_most_certain(certainties)
166
+ grid_size = int(math.sqrt(args.parity_sequence_length))
167
+ trajectory = [np.unravel_index(np.argmax(attention[t]), (grid_size, grid_size)) for t in range(args.iterations)]
168
+ x_coords, y_coords = zip(*trajectory)
169
+
170
+ plt.figure(figsize=(5, 5))
171
+ plt.imshow(input_images[0], cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
172
+
173
+ ax = plt.gca()
174
+ nrows, ncols = input_images[0].shape
175
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
176
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
177
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
178
+ ax.tick_params(which="minor", size=0)
179
+ ax.set_axisbelow(False)
180
+ plt.xticks([])
181
+ plt.yticks([])
182
+
183
+ cmap = plt.get_cmap("plasma")
184
+ norm_time = np.linspace(0, 1, len(trajectory))
185
+
186
+ for i in range(len(trajectory) - 1):
187
+ x1, y1 = x_coords[i], y_coords[i]
188
+ x2, y2 = x_coords[i + 1], y_coords[i + 1]
189
+ color = cmap(norm_time[i])
190
+ line, = plt.plot([y1, y2], [x1, x2], color=color, linewidth=6, alpha=0.5, zorder=4)
191
+ line.set_path_effects([
192
+ path_effects.Stroke(linewidth=8, foreground='white'),
193
+ path_effects.Normal()
194
+ ])
195
+
196
+ for i, (x, y) in enumerate(trajectory):
197
+ plt.scatter(y, x, color=cmap(norm_time[i]), s=100, edgecolor='white', linewidth=1.5, zorder=5)
198
+
199
+ most_certain_point = trajectory[where_most_certain]
200
+
201
+ plt.plot(most_certain_point[1], most_certain_point[0],
202
+ marker='x', markersize=18, markeredgewidth=5,
203
+ color='white', linestyle='', zorder=6)
204
+ plt.plot(most_certain_point[1], most_certain_point[0],
205
+ marker='x', markersize=15, markeredgewidth=3,
206
+ color=cmap(norm_time[where_most_certain]), linestyle='', zorder=7)
207
+
208
+ plt.tight_layout()
209
+ plt.savefig(f"{save_dir}/{filename}_traj.png", dpi=300, bbox_inches='tight', pad_inches=0)
210
+ plt.savefig(f"{save_dir}/{filename}_traj.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
211
+ plt.show()
212
+ plt.close()
213
+
214
+ def plot_input(input_images, save_dir, filename):
215
+
216
+ plt.figure(figsize=(5, 5))
217
+ plt.imshow(input_images[0], cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
218
+
219
+ ax = plt.gca()
220
+ nrows, ncols = input_images[0].shape
221
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
222
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
223
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
224
+ ax.tick_params(which="minor", size=0)
225
+ ax.set_axisbelow(False)
226
+ plt.xticks([])
227
+ plt.yticks([])
228
+
229
+ plt.tight_layout()
230
+ plt.savefig(f"{save_dir}/{filename}_input.png", dpi=300, bbox_inches='tight', pad_inches=0)
231
+ plt.savefig(f"{save_dir}/{filename}_input.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
232
+ plt.show()
233
+ plt.close()
234
+
235
+ def plot_target(targets, save_dir, filename, args):
236
+ grid_size = int(math.sqrt(args.parity_sequence_length))
237
+ targets_grid = targets[0].reshape(grid_size, grid_size).detach().cpu().numpy()
238
+ plt.figure(figsize=(5, 5))
239
+ plt.imshow(targets_grid, cmap="gray_r", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
240
+ ax = plt.gca()
241
+ nrows, ncols = targets_grid.shape
242
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
243
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
244
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
245
+ ax.tick_params(which="minor", size=0)
246
+ ax.set_axisbelow(False)
247
+ plt.xticks([])
248
+ plt.yticks([])
249
+ plt.tight_layout()
250
+ plt.savefig(f"{save_dir}/{filename}_target.png", dpi=300, bbox_inches='tight', pad_inches=0)
251
+ plt.savefig(f"{save_dir}/{filename}_target.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
252
+ plt.show()
253
+ plt.close()
254
+
255
+ def plot_probabilities(predictions, certainties, save_dir, filename, args):
256
+ grid_size = int(math.sqrt(args.parity_sequence_length))
257
+ where_most_certain = get_where_most_certain(certainties)
258
+ predictions_most_certain = predictions[0, :, :, where_most_certain].detach().cpu().numpy()
259
+ probs = softmax(predictions_most_certain, axis=1)
260
+ probs_grid = probs[:, 0].reshape(grid_size, grid_size)
261
+ plt.figure(figsize=(5, 5))
262
+ plt.imshow(probs_grid, cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
263
+ ax = plt.gca()
264
+ nrows, ncols = probs_grid.shape
265
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
266
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
267
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
268
+ ax.tick_params(which="minor", size=0)
269
+ ax.set_axisbelow(False)
270
+ plt.xticks([])
271
+ plt.yticks([])
272
+ plt.tight_layout()
273
+ plt.savefig(f"{save_dir}/{filename}_probs.png", dpi=300, bbox_inches='tight', pad_inches=0)
274
+ plt.savefig(f"{save_dir}/{filename}_probs.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
275
+ plt.show()
276
+ plt.close()
277
+
278
+ def plot_prediction(predictions, certainties, save_dir, filename, args):
279
+ grid_size = int(math.sqrt(args.parity_sequence_length))
280
+ where_most_certain = get_where_most_certain(certainties)
281
+ predictions_most_certain = predictions[0, :, :, where_most_certain].detach().cpu().numpy()
282
+ class_grid = np.argmax(predictions_most_certain, axis=1).reshape(grid_size, grid_size)
283
+
284
+ plt.figure(figsize=(5, 5))
285
+ plt.imshow(class_grid, cmap="gray_r", origin="upper", vmin=0, vmax=1, interpolation='nearest')
286
+
287
+ ax = plt.gca()
288
+ nrows, ncols = class_grid.shape
289
+ ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
290
+ ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
291
+ ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
292
+ ax.tick_params(which="minor", size=0)
293
+ ax.set_axisbelow(False)
294
+ plt.xticks([])
295
+ plt.yticks([])
296
+
297
+ plt.tight_layout()
298
+ plt.savefig(f"{save_dir}/{filename}_prediction.png", dpi=300, bbox_inches='tight', pad_inches=0)
299
+ plt.savefig(f"{save_dir}/{filename}_prediction.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
300
+ plt.show()
301
+ plt.close()
302
+
303
+ def plot_accuracy_heatmap(overall_accuracies_avg, average_thinking_time, std_thinking_time, scale, save_path, args):
304
+ fig, ax = plt.subplots(figsize=(scale*10, scale*5))
305
+ im = ax.imshow(overall_accuracies_avg.T * 100, aspect='auto', cmap="viridis", origin='lower', extent=[0, args.iterations-1, 0, args.parity_sequence_length-1], vmin=50, vmax=100)
306
+ cbar = fig.colorbar(im, ax=ax, format="%.1f")
307
+ cbar.set_label("Accuracy (%)")
308
+ ax.errorbar(average_thinking_time, np.arange(args.parity_sequence_length), xerr=std_thinking_time, fmt='ko', markersize=2, capsize=2, elinewidth=1, label="Min. Entropy")
309
+ ax.set_xlabel("Time Step")
310
+ ax.set_ylabel("Sequence Index")
311
+ ax.set_xlim(0, args.iterations-1)
312
+ ax.set_ylim(0, args.parity_sequence_length-1)
313
+ ax.grid(False)
314
+ ax.legend(loc="upper left")
315
+ fig.tight_layout(pad=0.1)
316
+ fig.savefig(save_path, dpi=300, bbox_inches="tight")
317
+ fig.savefig(save_path.replace(".png", ".pdf"), format='pdf', bbox_inches="tight")
318
+ plt.close(fig)
319
+
320
+ def plot_attention_heatmap(overall_attentions_avg, scale, save_path, vmin=None, vmax=None):
321
+ overall_attentions_avg = overall_attentions_avg.reshape(overall_attentions_avg.shape[0], -1)
322
+ fig, ax = plt.subplots(figsize=(scale*10, scale*5))
323
+ im = ax.imshow(overall_attentions_avg.T, aspect='auto', cmap="viridis", origin='lower', extent=[0, overall_attentions_avg.shape[0]-1, 0, overall_attentions_avg.shape[1]-1], vmin=vmin, vmax=vmax)
324
+ cbar = fig.colorbar(im, ax=ax, format=FuncFormatter(lambda x, _: f"{x:05.2f}"))
325
+ cbar.set_label("Attention Weight")
326
+ ax.set_xlabel("Time Step")
327
+ ax.set_ylabel("Sequence Index")
328
+ ax.set_xlim(0, overall_attentions_avg.shape[0]-1)
329
+ ax.set_ylim(0, overall_attentions_avg.shape[1]-1)
330
+ ax.grid(False)
331
+ fig.tight_layout(pad=0.1)
332
+ fig.savefig(save_path, dpi=300, bbox_inches="tight")
333
+ fig.savefig(save_path.replace(".png", ".pdf"), format='pdf', bbox_inches="tight")
334
+ plt.close(fig)
335
+
336
+ def create_accuracies_heatmap_gif(all_accuracies, all_average_thinking_times, all_std_thinking_times, scale, save_dir, args):
337
+ heatmap_components_dir = os.path.join(save_dir, "accuracy_heatmaps")
338
+ os.makedirs(heatmap_components_dir, exist_ok=True)
339
+
340
+ image_paths = []
341
+
342
+ for i, (accuracies, avg_thinking_time, std_thinking_time) in enumerate(zip(all_accuracies, all_average_thinking_times, all_std_thinking_times)):
343
+ save_path = os.path.join(heatmap_components_dir, f"frame_{i:04d}.png")
344
+ plot_accuracy_heatmap(accuracies, avg_thinking_time, std_thinking_time, scale, save_path, args)
345
+ image_paths.append(save_path)
346
+
347
+ gif_path = os.path.join(save_dir, "accuracy_heatmap.gif")
348
+ with imageio.get_writer(gif_path, mode='I', duration=0.3) as writer:
349
+ for image_path in image_paths:
350
+ image = imageio.imread(image_path)
351
+ writer.append_data(image)
352
+
353
+ def create_attentions_heatmap_gif(all_attentions, scale, save_path, args):
354
+ heatmap_components_dir = os.path.join(args.log_dir, "attention_heatmaps")
355
+ os.makedirs(heatmap_components_dir, exist_ok=True)
356
+
357
+ global_min = min(attentions.min() for attentions in all_attentions)
358
+ global_max = max(attentions.max() for attentions in all_attentions)
359
+
360
+ image_paths = []
361
+
362
+ for i, attentions in enumerate(all_attentions):
363
+ save_path_component = os.path.join(heatmap_components_dir, f"frame_{i:04d}.png")
364
+ plot_attention_heatmap(attentions, scale, save_path_component, vmin=global_min, vmax=global_max)
365
+ image_paths.append(save_path_component)
366
+
367
+ gif_path = os.path.join(save_path, "attention_heatmap.gif")
368
+ with imageio.get_writer(gif_path, mode='I', duration=0.3) as writer:
369
+ for image_path in image_paths:
370
+ image = imageio.imread(image_path)
371
+ writer.append_data(image)
372
+
373
+ def create_stacked_gif(save_path, y_shift=200):
374
+ accuracy_gif_path = os.path.join(save_path, "accuracy_heatmap.gif")
375
+ attention_gif_path = os.path.join(save_path, "attention_heatmap.gif")
376
+ stacked_gif_path = os.path.join(save_path, "stacked_heatmap.gif")
377
+
378
+ accuracy_reader = imageio.get_reader(accuracy_gif_path)
379
+ attention_reader = imageio.get_reader(attention_gif_path)
380
+
381
+ accuracy_frames = [Image.fromarray(frame) for frame in accuracy_reader]
382
+ attention_frames = [Image.fromarray(frame) for frame in attention_reader]
383
+
384
+ assert len(accuracy_frames) == len(attention_frames), "Mismatch in frame counts between accuracy and attention GIFs"
385
+
386
+ stacked_frames = []
387
+ for acc_frame, att_frame in zip(accuracy_frames, attention_frames):
388
+ acc_width, acc_height = acc_frame.size
389
+ att_width, att_height = att_frame.size
390
+
391
+ # Create base canvas
392
+ stacked_height = acc_height + att_height - y_shift
393
+ stacked_width = max(acc_width, att_width)
394
+
395
+ stacked_frame = Image.new("RGB", (stacked_width, stacked_height), color=(255, 255, 255))
396
+
397
+ # Paste attention frame first, shifted up
398
+ stacked_frame.paste(att_frame, (0, 0)) # Paste at top
399
+ stacked_frame.paste(acc_frame, (0, att_height - y_shift)) # Shift accuracy up by overlap
400
+
401
+ stacked_frames.append(stacked_frame)
402
+
403
+ stacked_frames[0].save(
404
+ stacked_gif_path,
405
+ save_all=True,
406
+ append_images=stacked_frames[1:],
407
+ duration=300,
408
+ loop=0
409
+ )
410
+
411
+ save_frames_to_mp4(
412
+ [np.array(fm)[:, :, ::-1] for fm in stacked_frames],
413
+ f"{stacked_gif_path.replace('gif', 'mp4')}",
414
+ fps=15,
415
+ gop_size=1,
416
+ preset="slow"
417
+ )
418
+
419
+
420
+ def plot_accuracy_training(all_accuracies, scale, run_model_spefic_save_dir, args):
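+ # Plot accuracy per sequence index for each checkpoint, with line colour indicating training progress.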
421
+ scale = 0.5  # NOTE: overrides the scale argument with a fixed value for this plot
422
+ seq_indices = range(args.parity_sequence_length)
423
+ fig, ax = plt.subplots(figsize=(scale*10, scale*5))
424
+ cmap = plt.get_cmap("viridis")
425
+
426
+ for i, acc in enumerate(all_accuracies):
427
+ color = cmap(i / max(len(all_accuracies) - 1, 1))  # guard against division by zero when only one checkpoint is plotted
428
+ ax.plot(seq_indices, acc*100, color=color, alpha=0.7, linewidth=1)
429
+
430
+ num_checkpoints = 5
431
+ checkpoint_percentages = np.linspace(0, 100, num_checkpoints)
432
+
433
+ sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=100))
434
+ sm.set_array([])
435
+ cbar = fig.colorbar(sm, ax=ax)
436
+ cbar.set_label("Training Progress (%)")
437
+ cbar.set_ticks(checkpoint_percentages)
438
+ cbar.set_ticklabels([f"{int(p)}%" for p in checkpoint_percentages])
439
+
440
+ ax.set_xlabel("Sequence Index")
441
+ ax.set_ylabel("Accuracy (%)")
442
+ ax.set_xticks([0, 16, 32, 48, 63])
443
+ ax.grid(True, alpha=0.5)
444
+ ax.set_xlim(0, args.parity_sequence_length - 1)
445
+
446
+ fig.tight_layout(pad=0.1)
447
+ fig.savefig(f"{run_model_spefic_save_dir}/accuracy_vs_seq_element.png", dpi=300, bbox_inches="tight")
448
+ fig.savefig(f"{run_model_spefic_save_dir}/accuracy_vs_seq_element.pdf", format='pdf', bbox_inches="tight")
449
+ plt.close(fig)
450
+
451
+
452
+ def plot_loss_all_runs(training_data, evaluate_every, save_path="train_loss_comparison_parity.png", step=1, scale=1.0, x_max=None):
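+ # Plot mean +/- std training loss across runs, grouped by model type and number of internal ticks.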
453
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
454
+
455
+ grouped = defaultdict(list)
456
+ label_map = {}
457
+ linestyle_map = {}
458
+ iters_map = {}
459
+ model_map = {}
460
+
461
+ for folder, data in training_data.items():
462
+ label, model_type, iters = parse_folder_name(folder)
463
+ if iters is None:
464
+ continue
465
+
466
+ key = f"{model_type}_{iters}"
467
+ grouped[key].append(data["train_losses"])
468
+ label_map[key] = f"{model_type}, {iters} Iters."
469
+ linestyle_map[key] = "--" if model_type == "LSTM" else "-"
470
+ iters_map[key] = iters
471
+ model_map[key] = model_type
472
+
473
+ unique_iters = sorted(set(iters_map.values()))
474
+ base_colors = sns.color_palette("hls", n_colors=len(unique_iters))
475
+ color_lookup = {iters: base_colors[i] for i, iters in enumerate(unique_iters)}
476
+
477
+ legend_entries = []
478
+ global_max_x = 0
479
+ for key in sorted(grouped.keys(), key=lambda k: (iters_map[k], model_map[k])):
480
+ runs = grouped[key]
481
+ if not runs:
482
+ continue
483
+
484
+ iters = iters_map[key]
485
+ color = color_lookup[iters]
486
+ linestyle = linestyle_map[key]
487
+
488
+ min_len = min(len(r) for r in runs)
489
+ trimmed = np.array([r[:min_len] for r in runs])[:, ::step]
490
+
491
+ mean = np.mean(trimmed, axis=0)
492
+ std = np.std(trimmed, axis=0)
493
+ x = np.arange(len(mean)) * step * evaluate_every
494
+ group_max_x = len(mean) * step * evaluate_every
495
+ global_max_x = max(global_max_x, group_max_x)
496
+
497
+ line, = ax.plot(x, mean, color=color, linestyle=linestyle, label=label_map[key])
498
+ ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
499
+
500
+ legend_entries.append((line, label_map[key]))
501
+
502
+ ax.set_xlabel("Training Iterations")
503
+ ax.set_ylabel("Loss")
504
+ ax.grid(True, alpha=0.5)
505
+
506
+ style_legend = [
507
+ Line2D([0], [0], color='black', linestyle='-', label='CTM'),
508
+ Line2D([0], [0], color='black', linestyle='--', label='LSTM')
509
+ ]
510
+ color_legend = [
511
+ Line2D([0], [0], color=color_lookup[it], linestyle='-', label=f"{it} Iters.")
512
+ for it in unique_iters
513
+ ]
514
+
515
+ if not x_max:
516
+ x_max = global_max_x
517
+
518
+ ax.set_xlim([0, x_max])
519
+ ax.set_ylim(bottom=0)
520
+ ax.set_xticks(np.arange(0, x_max + 1, 50000))
521
+ ax.legend(handles=color_legend + style_legend, loc="upper left")
522
+ fig.tight_layout(pad=0.1)
523
+ fig.savefig(save_path, dpi=300)
524
+ fig.savefig(save_path.replace("png", "pdf"), format='pdf')
525
+ plt.close(fig)
526
+
527
+ def plot_accuracy_all_runs(training_data, evaluate_every, save_path="test_accuracy_comparison_parity.png", step=1, scale=1.0, smooth=False, x_max=None):
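+ # Plot mean +/- std test accuracy across runs, grouped by model type and number of internal ticks, with optional moving-average smoothing.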
528
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
529
+
530
+ grouped = defaultdict(list)
531
+ label_map = {}
532
+ linestyle_map = {}
533
+ iters_map = {}
534
+ model_map = {}
535
+
536
+ for folder, data in training_data.items():
537
+ label, model_type, iters = parse_folder_name(folder)
538
+ if iters is None:
539
+ continue
540
+
541
+ key = f"{model_type}_{iters}"
542
+ grouped[key].append(data["test_accuracies"])
543
+ label_map[key] = f"{model_type}, {iters} Iters."
544
+ linestyle_map[key] = "--" if model_type == "LSTM" else "-"
545
+ iters_map[key] = iters
546
+ model_map[key] = model_type
547
+
548
+ unique_iters = sorted(set(iters_map.values()))
549
+ base_colors = sns.color_palette("hls", n_colors=len(unique_iters))
550
+ color_lookup = {iters: base_colors[i] for i, iters in enumerate(unique_iters)}
551
+
552
+ legend_entries = []
553
+ global_max_x = 0
554
+
555
+ for key in sorted(grouped.keys(), key=lambda k: (iters_map[k], model_map[k])):
556
+ runs = grouped[key]
557
+ if not runs:
558
+ continue
559
+
560
+ iters = iters_map[key]
561
+ model = model_map[key]
562
+ color = color_lookup[iters]
563
+ linestyle = linestyle_map[key]
564
+
565
+ min_len = min(len(r) for r in runs)
566
+ trimmed = np.array([r[:min_len] for r in runs])[:, ::step]
567
+
568
+ mean = np.mean(trimmed, axis=0) * 100
569
+ std = np.std(trimmed, axis=0) * 100
570
+
571
+ if smooth:
572
+ window_size = max(1, int(0.05 * len(mean)))
573
+ if window_size % 2 == 0:
574
+ window_size += 1
575
+ kernel = np.ones(window_size) / window_size
576
+
577
+ smoothed_mean = np.convolve(mean, kernel, mode='same')
578
+ smoothed_std = np.convolve(std, kernel, mode='same')
579
+
580
+ valid_start = window_size // 2
581
+ valid_end = len(mean) - window_size // 2
582
+ valid_length = valid_end - valid_start
583
+
584
+ mean = smoothed_mean[valid_start:valid_end]
585
+ std = smoothed_std[valid_start:valid_end]
586
+ x = np.arange(valid_length) * step * evaluate_every
587
+ group_max_x = valid_length * step * evaluate_every
588
+ else:
589
+ x = np.arange(len(mean)) * step * evaluate_every
590
+ group_max_x = len(mean) * step * evaluate_every
591
+
592
+ global_max_x = max(global_max_x, group_max_x)
593
+
594
+ line, = ax.plot(x, mean, color=color, linestyle=linestyle, label=label_map[key])
595
+ ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
596
+ legend_entries.append((line, label_map[key]))
597
+
598
+ if smooth or x_max is None:
599
+ x_max = global_max_x
600
+
601
+ ax.set_xlim([0, x_max])
602
+ ax.set_ylim(top=100)
603
+ ax.set_xticks(np.arange(0, x_max + 1, 50000))
604
+ ax.set_xlabel("Training Iterations")
605
+ ax.set_ylabel("Accuracy (%)")
606
+ ax.grid(True, alpha=0.5)
607
+
608
+ style_legend = [
609
+ Line2D([0], [0], color='black', linestyle='-', label='CTM'),
610
+ Line2D([0], [0], color='black', linestyle='--', label='LSTM')
611
+ ]
612
+ color_legend = [
613
+ Line2D([0], [0], color=color_lookup[it], linestyle='-', label=f"{it} Iters.")
614
+ for it in unique_iters
615
+ ]
616
+ ax.legend(handles=color_legend + style_legend, loc="upper left")
617
+
618
+ fig.tight_layout(pad=0.1)
619
+ fig.savefig(save_path, dpi=300)
620
+ fig.savefig(save_path.replace("png", "pdf"), format='pdf')
621
+ plt.close(fig)
622
+
623
+ def extract_run_name(folder, run_index=None):
624
+ # Try to extract from parent folder
625
+ parent = os.path.basename(os.path.dirname(folder))
626
+ match = re.search(r'run(\d+)', parent, re.IGNORECASE)
627
+ if match:
628
+ return f"Run {int(match.group(1))}"
629
+ # Try current folder name
630
+ basename = os.path.basename(folder)
631
+ match = re.search(r'run(\d+)', basename, re.IGNORECASE)
632
+ if match:
633
+ return f"Run {int(match.group(1))}"
634
+ # Fallback: use run index
635
+ if run_index is not None:
636
+ return f"Run {run_index + 1}"
637
+ raise ValueError(f"Could not extract run number from: {folder}")
638
+
639
+ def plot_loss_individual_runs(training_data, evaluate_every, save_dir, scale=1.0, x_max=None):
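+ # Plot the training loss of each run separately, one figure per model/iteration (and memory length, for CTMs) group.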
640
+
641
+ grouped = defaultdict(list)
642
+ label_map = {}
643
+ iters_map = {}
644
+ model_map = {}
645
+
646
+ base_colors = sns.color_palette("hls", n_colors=3)
647
+ color_lookup = {f"Run {i+1}": base_colors[i] for i in range(3)}
648
+
649
+ for i, (folder, data) in enumerate(training_data.items()):
650
+ checkpoint = load_checkpoint(get_latest_checkpoint_file(folder), device="cpu")
651
+ model_args = get_model_args_from_checkpoint(checkpoint)
652
+ label, model_type, iters = parse_folder_name(folder)
653
+ if iters is None:
654
+ continue
655
+
656
+ if model_type.lower() == "ctm":
657
+ memory_length = getattr(model_args, "memory_length", None)
658
+ if memory_length is None:
659
+ raise ValueError(f"CTM model missing memory_length in checkpoint args from: {folder}")
660
+ key = f"{model_type}_{iters}_{memory_length}".lower()
661
+ else:
662
+ key = f"{model_type}_{iters}".lower()
663
+
664
+ run_name = extract_run_name(folder, run_index=i)
665
+ grouped[key].append((run_name, data["train_losses"]))
666
+ label_map[key] = f"{model_type}, {iters} Iters."
667
+ iters_map[key] = iters
668
+ model_map[key] = model_type
669
+
670
+ for key, runs in grouped.items():
671
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
672
+ for run_name, losses in runs:
673
+ x = np.arange(len(losses)) * evaluate_every
674
+ color = color_lookup.get(run_name, 'gray')
675
+ ax.plot(x, losses, label=run_name, color=color, alpha=0.7)
676
+
677
+ ax.set_xlabel("Training Iterations")
678
+ ax.set_ylabel("Loss")
679
+ ax.set_ylim(bottom=-0.01)
680
+ ax.grid(True, alpha=0.5)
681
+ if x_max:
682
+ ax.set_xlim([0, x_max])
683
+ ax.set_xticks(np.arange(0, x_max + 1, 50000))
684
+ ax.legend()
685
+ fig.tight_layout(pad=0.1)
686
+
687
+ subdir = os.path.join(save_dir, key)
688
+ os.makedirs(subdir, exist_ok=True)
689
+ fname = os.path.join(subdir, f"individual_runs_loss_{key}.png")
690
+ fig.savefig(fname, dpi=300)
691
+ fig.savefig(fname.replace("png", "pdf"), format="pdf")
692
+ plt.close(fig)
693
+
694
+ def plot_accuracy_individual_runs(training_data, evaluate_every, save_dir, scale=1.0, smooth=False, x_max=None):
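+ # Plot the test accuracy of each run separately, one figure per model/iteration (and memory length, for CTMs) group.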
695
+
696
+ grouped = defaultdict(list)
697
+ label_map = {}
698
+ iters_map = {}
699
+ model_map = {}
700
+
701
+ base_colors = sns.color_palette("hls", n_colors=3)
702
+ color_lookup = {f"Run {i+1}": base_colors[i] for i in range(3)}
703
+
704
+ for i, (folder, data) in enumerate(training_data.items()):
705
+ checkpoint = load_checkpoint(get_latest_checkpoint_file(folder), device="cpu")
706
+ model_args = get_model_args_from_checkpoint(checkpoint)
707
+ label, model_type, iters = parse_folder_name(folder)
708
+ if iters is None:
709
+ continue
710
+
711
+ if model_type.lower() == "ctm":
712
+ memory_length = getattr(model_args, "memory_length", None)
713
+ if memory_length is None:
714
+ raise ValueError(f"CTM model missing memory_length in checkpoint args from: {folder}")
715
+ key = f"{model_type}_{iters}_{memory_length}".lower()
716
+ else:
717
+ key = f"{model_type}_{iters}".lower()
718
+
719
+ run_name = extract_run_name(folder, run_index=i)
720
+ grouped[key].append((run_name, data["test_accuracies"]))
721
+ label_map[key] = f"{model_type}, {iters} Iters."
722
+ iters_map[key] = iters
723
+ model_map[key] = model_type
724
+
725
+ for key, runs in grouped.items():
726
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
727
+ for run_name, acc in runs:
728
+ acc = np.array(acc) * 100
729
+ if smooth:
730
+ window_size = max(1, int(0.05 * len(acc)))
731
+ if window_size % 2 == 0:
732
+ window_size += 1
733
+ kernel = np.ones(window_size) / window_size
734
+ acc = np.convolve(acc, kernel, mode="same")
735
+
736
+ x = np.arange(len(acc)) * evaluate_every
737
+ color = color_lookup.get(run_name, 'gray')
738
+ ax.plot(x, acc, label=run_name, color=color, alpha=0.7)
739
+
740
+ ax.set_xlabel("Training Iterations")
741
+ ax.set_ylabel("Accuracy (%)")
742
+ ax.set_ylim([50, 101])
743
+ ax.grid(True, alpha=0.5)
744
+ if x_max:
745
+ ax.set_xlim([0, x_max])
746
+ ax.set_xticks(np.arange(0, x_max + 1, 50000))
747
+ ax.legend()
748
+ fig.tight_layout(pad=0.1)
749
+
750
+ subdir = os.path.join(save_dir, key)
751
+ os.makedirs(subdir, exist_ok=True)
752
+ fname = os.path.join(subdir, f"individual_runs_accuracy_{key}.png")
753
+ fig.savefig(fname, dpi=300)
754
+ fig.savefig(fname.replace("png", "pdf"), format="pdf")
755
+ plt.close(fig)
756
+
757
+ def plot_training_curve_all_runs(all_folders, save_dir, scale, device, smooth=False, x_max=None, plot_individual_runs=True):
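+ # Load the latest checkpoint from every run folder and produce the combined (and optionally per-run) loss and accuracy plots.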
758
+
759
+ all_folders = [folder for folder in all_folders if "certain" not in folder]
760
+
761
+ training_data = {}
762
+ evaluation_intervals = []
763
+ for folder in all_folders:
764
+ latest_checkpoint_path = get_latest_checkpoint_file(folder)
765
+ if latest_checkpoint_path:
766
+ checkpoint = load_checkpoint(latest_checkpoint_path, device=device)
767
+ model_args = get_model_args_from_checkpoint(checkpoint)
768
+ evaluation_intervals.append(model_args.track_every)
769
+
770
+ _, train_losses, test_losses, train_accuracies, test_accuracies = get_accuracy_and_loss_from_checkpoint(checkpoint, device=device)
771
+ training_data[folder] = {
772
+ "train_losses": train_losses,
773
+ "test_losses": test_losses,
774
+ "train_accuracies": train_accuracies,
775
+ "test_accuracies": test_accuracies
776
+ }
777
+ else:
778
+ print(f"No checkpoint found for {folder}")
779
+
780
+ assert len(evaluation_intervals) > 0, "No valid checkpoints found."
781
+ assert all(interval == evaluation_intervals[0] for interval in evaluation_intervals), "Evaluation intervals are not consistent across runs."
782
+
783
+ evaluate_every = evaluation_intervals[0]
784
+
785
+ if plot_individual_runs:
786
+ plot_loss_individual_runs(training_data, evaluate_every, save_dir=save_dir, scale=scale, x_max=x_max)
787
+ plot_accuracy_individual_runs(training_data, evaluate_every, save_dir=save_dir, scale=scale, smooth=smooth, x_max=x_max)
788
+
789
+ plot_loss_all_runs(training_data, evaluate_every, save_path=f"{save_dir}/loss_comparison.png", scale=scale, x_max=x_max)
790
+ plot_accuracy_all_runs(training_data, evaluate_every, save_path=f"{save_dir}/accuracy_comparison.png", scale=scale, smooth=smooth, x_max=x_max)
791
+
792
+ return training_data
793
+
794
+ def plot_accuracy_thinking_time(csv_path, scale, output_dir="analysis/parity"):
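+ # Plot mean accuracy (with pooled std) against the number of internal ticks for CTM and LSTM runs, read from the summary CSV.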
795
+ if not os.path.exists(csv_path):
796
+ raise FileNotFoundError(f"CSV file not found: {csv_path}")
797
+
798
+ df = pd.read_csv(csv_path)
799
+ df["RunName"] = df["Run"].apply(lambda x: os.path.basename(os.path.dirname(x)))
800
+ df["Model"] = df["Run"].apply(lambda x: "CTM" if "ctm" in x.lower() else "LSTM")
801
+
802
+ grouped = df.groupby(["Model", "Num Iterations"])
803
+ summary = grouped.agg(
804
+ mean_accuracy=("Overall Mean Accuracy", "mean"),
805
+ std_accuracy=("Overall Std Accuracy", lambda x: np.sqrt(np.mean(x**2)))
806
+ ).reset_index()
807
+
808
+ summary["mean_accuracy"] *= 100
809
+ summary["std_accuracy"] *= 100
810
+
811
+ fig, ax = plt.subplots(figsize=(scale*5, scale*5))
812
+
813
+ for model in ("CTM", "LSTM"):
814
+ subset = summary[summary["Model"] == model].sort_values(by="Num Iterations")
815
+ linestyle = "-" if model == "CTM" else "--"
816
+ ax.errorbar(
817
+ subset["Num Iterations"],
818
+ subset["mean_accuracy"],
819
+ yerr=subset["std_accuracy"],
820
+ linestyle=linestyle,
821
+ color="black",
822
+ marker='.',
823
+ label=model,
824
+ capsize=3,
825
+ elinewidth=1,
826
+ errorevery=1
827
+ )
828
+
829
+ ax.set_xlabel("Internal Ticks")
830
+ ax.set_ylabel("Accuracy (%)")
831
+ custom_lines = [
832
+ Line2D([0], [0], color='black', linestyle='-', label='CTM'),
833
+ Line2D([0], [0], color='black', linestyle='--', label='LSTM')
834
+ ]
835
+ ax.legend(handles=custom_lines, loc="lower right")
836
+ ax.grid(True, alpha=0.5)
837
+
838
+ os.makedirs(output_dir, exist_ok=True)
839
+ output_path_png = os.path.join(output_dir, "accuracy_vs_thinking_time.png")
840
+ fig.tight_layout(pad=0.1)
841
+ fig.savefig(output_path_png, dpi=300)
842
+ fig.savefig(output_path_png.replace("png", "pdf"), format='pdf')
843
+ plt.close(fig)
844
+
845
+
846
+ def plot_lstm_last_and_certain_accuracy(all_folders, save_path="lstm_last_and_certain_accuracy.png", scale=1.0, step=1, x_max=None):
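+ # Compare LSTM runs that use the final iteration's output ("Final") against runs trained with --use_most_certain_with_lstm ("Certain").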
847
+
848
+ tags = ["lstm_10", "lstm_10_certain", "lstm_25", "lstm_25_certain"]
849
+ folders = [f for f in all_folders if os.path.basename(os.path.normpath(f)).lower() in tags]  # exact match so e.g. "lstm_100" is not caught by the "lstm_10" tag
850
+
851
+ training_data, eval_intervals = {}, []
852
+ for f in folders:
853
+ cp = get_latest_checkpoint_file(f)
854
+ if not cp:
855
+ print(f"⚠️ No checkpoint in {f}")
856
+ continue
857
+ ckpt = load_checkpoint(cp, device="cpu")
858
+ args = get_model_args_from_checkpoint(ckpt)
859
+ eval_intervals.append(args.track_every)
860
+ _, _, _, _, acc = get_accuracy_and_loss_from_checkpoint(ckpt)
861
+ iters = "25" if "25" in f.lower() else "10"
862
+ label = "Certain" if "certain" in f.lower() else "Final"
863
+ training_data.setdefault((iters, label), []).append(acc)
864
+
865
+ assert training_data and all(i == eval_intervals[0] for i in eval_intervals), "Missing or inconsistent eval intervals."
866
+ evaluate_every = eval_intervals[0]
867
+
868
+ keys = sorted(training_data.keys())
869
+ colors = sns.color_palette("hls", n_colors=len(keys))
870
+ style_map = {key: ("--" if key[1] == "Certain" else "-") for key in keys}
871
+ color_map = {key: colors[i] for i, key in enumerate(keys)}
872
+
873
+ fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
874
+ max_x = 0
875
+
876
+ for key in keys:
877
+ runs = training_data[key]
878
+ min_len = min(len(r) for r in runs)
879
+ trimmed = np.stack([r[:min_len] for r in runs], axis=0)[:, ::step]
880
+ mean, std = np.mean(trimmed, 0) * 100, np.std(trimmed, 0) * 100
881
+ x = np.arange(len(mean)) * step * evaluate_every
882
+ ax.plot(x, mean, color=color_map[key], linestyle=style_map[key],
883
+ label=f"{key[0]} Iters, {key[1]}", linewidth=2, alpha=0.7)
884
+ ax.fill_between(x, mean - std, mean + std, color=color_map[key], alpha=0.1)
885
+ max_x = max(max_x, x[-1])
886
+
887
+ ax.set_xlim([0, x_max or max_x])
888
+ ax.set_xticks(np.arange(0, (x_max or max_x) + 1, 50000))
889
+ ax.set_xlabel("Training Iterations")
890
+ ax.set_ylabel("Accuracy (%)")
891
+ ax.grid(True, alpha=0.5)
892
+ ax.legend(loc="lower right")
893
+ fig.tight_layout(pad=0.1)
894
+ fig.savefig(save_path, dpi=300)
895
+ fig.savefig(save_path.replace("png", "pdf"), format="pdf")
896
+ plt.close(fig)
tasks/parity/scripts/train_ctm_100_50.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
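+ # CTM on 64-bit parity: 100 internal ticks, memory length 50.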
2
+ RUN=1
3
+ ITERATIONS=100
4
+ MEMORY_LENGTH=50
5
+ LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --memory_length $MEMORY_LENGTH \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 1024 \
16
+ --d_input 512 \
17
+ --n_synch_out 32 \
18
+ --n_synch_action 32 \
19
+ --synapse_depth 1 \
20
+ --heads 8 \
21
+ --memory_hidden_dims 16 \
22
+ --dropout 0.0 \
23
+ --deep_memory \
24
+ --no-do_normalisation \
25
+ --positional_embedding_type="custom-rotational-1d" \
26
+ --backbone_type="parity_backbone" \
27
+ --no-full_eval \
28
+ --weight_decay 0.0 \
29
+ --gradient_clipping 0.9 \
30
+ --use_scheduler \
31
+ --scheduler_type "cosine" \
32
+ --milestones 0 0 0 \
33
+ --gamma 0 \
34
+ --dataset "parity" \
35
+ --batch_size 64 \
36
+ --batch_size_test 256 \
37
+ --lr=0.0001 \
38
+ --training_iterations 200001 \
39
+ --warmup_steps 500 \
40
+ --track_every 1000 \
41
+ --save_every 10000 \
42
+ --no-reload \
43
+ --no-reload_model_only \
44
+ --device 0 \
45
+ --no-use_amp \
46
+ --neuron_select_type "random"
tasks/parity/scripts/train_ctm_10_5.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
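+ # CTM on 64-bit parity: 10 internal ticks, memory length 5.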
2
+ RUN=1
3
+ ITERATIONS=10
4
+ MEMORY_LENGTH=5
5
+ LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --memory_length $MEMORY_LENGTH \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 1024 \
16
+ --d_input 512 \
17
+ --n_synch_out 32 \
18
+ --n_synch_action 32 \
19
+ --synapse_depth 1 \
20
+ --heads 8 \
21
+ --memory_hidden_dims 16 \
22
+ --dropout 0.0 \
23
+ --deep_memory \
24
+ --no-do_normalisation \
25
+ --positional_embedding_type="custom-rotational-1d" \
26
+ --backbone_type="parity_backbone" \
27
+ --no-full_eval \
28
+ --weight_decay 0.0 \
29
+ --gradient_clipping 0.9 \
30
+ --use_scheduler \
31
+ --scheduler_type "cosine" \
32
+ --milestones 0 0 0 \
33
+ --gamma 0 \
34
+ --dataset "parity" \
35
+ --batch_size 64 \
36
+ --batch_size_test 256 \
37
+ --lr=0.0001 \
38
+ --training_iterations 200001 \
39
+ --warmup_steps 500 \
40
+ --track_every 1000 \
41
+ --save_every 10000 \
42
+ --no-reload \
43
+ --no-reload_model_only \
44
+ --device 0 \
45
+ --no-use_amp \
46
+ --neuron_select_type "random"
tasks/parity/scripts/train_ctm_1_1.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
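+ # CTM on 64-bit parity: 1 internal tick, memory length 1.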
2
+ RUN=1
3
+ ITERATIONS=1
4
+ MEMORY_LENGTH=1
5
+ LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --memory_length $MEMORY_LENGTH \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 1024 \
16
+ --d_input 512 \
17
+ --n_synch_out 32 \
18
+ --n_synch_action 32 \
19
+ --synapse_depth 1 \
20
+ --heads 8 \
21
+ --memory_hidden_dims 16 \
22
+ --dropout 0.0 \
23
+ --deep_memory \
24
+ --no-do_normalisation \
25
+ --positional_embedding_type="custom-rotational-1d" \
26
+ --backbone_type="parity_backbone" \
27
+ --no-full_eval \
28
+ --weight_decay 0.0 \
29
+ --gradient_clipping 0.9 \
30
+ --use_scheduler \
31
+ --scheduler_type "cosine" \
32
+ --milestones 0 0 0 \
33
+ --gamma 0 \
34
+ --dataset "parity" \
35
+ --batch_size 64 \
36
+ --batch_size_test 256 \
37
+ --lr=0.0001 \
38
+ --training_iterations 200001 \
39
+ --warmup_steps 500 \
40
+ --track_every 1000 \
41
+ --save_every 10000 \
42
+ --no-reload \
43
+ --no-reload_model_only \
44
+ --device 0 \
45
+ --no-use_amp \
46
+ --neuron_select_type "random"
tasks/parity/scripts/train_ctm_25_10.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
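+ # CTM on 64-bit parity: 25 internal ticks, memory length 10.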
2
+ RUN=1
3
+ ITERATIONS=25
4
+ MEMORY_LENGTH=10
5
+ LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --memory_length $MEMORY_LENGTH \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 1024 \
16
+ --d_input 512 \
17
+ --n_synch_out 32 \
18
+ --n_synch_action 32 \
19
+ --synapse_depth 1 \
20
+ --heads 8 \
21
+ --memory_hidden_dims 16 \
22
+ --dropout 0.0 \
23
+ --deep_memory \
24
+ --no-do_normalisation \
25
+ --positional_embedding_type="custom-rotational-1d" \
26
+ --backbone_type="parity_backbone" \
27
+ --no-full_eval \
28
+ --weight_decay 0.0 \
29
+ --gradient_clipping 0.9 \
30
+ --use_scheduler \
31
+ --scheduler_type "cosine" \
32
+ --milestones 0 0 0 \
33
+ --gamma 0 \
34
+ --dataset "parity" \
35
+ --batch_size 64 \
36
+ --batch_size_test 256 \
37
+ --lr=0.0001 \
38
+ --training_iterations 200001 \
39
+ --warmup_steps 500 \
40
+ --track_every 1000 \
41
+ --save_every 10000 \
42
+ --no-reload \
43
+ --no-reload_model_only \
44
+ --device 0 \
45
+ --no-use_amp \
46
+ --neuron_select_type "random"
tasks/parity/scripts/train_ctm_50_25.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
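+ # CTM on 64-bit parity: 50 internal ticks, memory length 25.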
2
+ RUN=1
3
+ ITERATIONS=50
4
+ MEMORY_LENGTH=25
5
+ LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --memory_length $MEMORY_LENGTH \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 1024 \
16
+ --d_input 512 \
17
+ --n_synch_out 32 \
18
+ --n_synch_action 32 \
19
+ --synapse_depth 1 \
20
+ --heads 8 \
21
+ --memory_hidden_dims 16 \
22
+ --dropout 0.0 \
23
+ --deep_memory \
24
+ --no-do_normalisation \
25
+ --positional_embedding_type="custom-rotational-1d" \
26
+ --backbone_type="parity_backbone" \
27
+ --no-full_eval \
28
+ --weight_decay 0.0 \
29
+ --gradient_clipping 0.9 \
30
+ --use_scheduler \
31
+ --scheduler_type "cosine" \
32
+ --milestones 0 0 0 \
33
+ --gamma 0 \
34
+ --dataset "parity" \
35
+ --batch_size 64 \
36
+ --batch_size_test 256 \
37
+ --lr=0.0001 \
38
+ --training_iterations 200001 \
39
+ --warmup_steps 500 \
40
+ --track_every 1000 \
41
+ --save_every 10000 \
42
+ --no-reload \
43
+ --no-reload_model_only \
44
+ --device 0 \
45
+ --no-use_amp \
46
+ --neuron_select_type "random"
tasks/parity/scripts/train_ctm_75_25.sh ADDED
@@ -0,0 +1,46 @@
1
+ #!/bin/bash
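+ # CTM on 64-bit parity: 75 internal ticks, memory length 25.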
2
+ RUN=1
3
+ ITERATIONS=75
4
+ MEMORY_LENGTH=25
5
+ LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --memory_length $MEMORY_LENGTH \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 1024 \
16
+ --d_input 512 \
17
+ --n_synch_out 32 \
18
+ --n_synch_action 32 \
19
+ --synapse_depth 1 \
20
+ --heads 8 \
21
+ --memory_hidden_dims 16 \
22
+ --dropout 0.0 \
23
+ --deep_memory \
24
+ --no-do_normalisation \
25
+ --positional_embedding_type="custom-rotational-1d" \
26
+ --backbone_type="parity_backbone" \
27
+ --no-full_eval \
28
+ --weight_decay 0.0 \
29
+ --gradient_clipping 0.9 \
30
+ --use_scheduler \
31
+ --scheduler_type "cosine" \
32
+ --milestones 0 0 0 \
33
+ --gamma 0 \
34
+ --dataset "parity" \
35
+ --batch_size 64 \
36
+ --batch_size_test 256 \
37
+ --lr=0.0001 \
38
+ --training_iterations 200001 \
39
+ --warmup_steps 500 \
40
+ --track_every 1000 \
41
+ --save_every 10000 \
42
+ --no-reload \
43
+ --no-reload_model_only \
44
+ --device 0 \
45
+ --no-use_amp \
46
+ --neuron_select_type "random"
tasks/parity/scripts/train_lstm_1.sh ADDED
@@ -0,0 +1,39 @@
1
+ #!/bin/bash
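+ # LSTM baseline on 64-bit parity: 1 iteration.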
2
+ RUN=1
3
+ ITERATIONS=1
4
+ MODEL_TYPE="lstm"
5
+ LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --model_type $MODEL_TYPE \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 669 \
16
+ --d_input 512 \
17
+ --heads 8 \
18
+ --dropout 0.0 \
19
+ --positional_embedding_type="custom-rotational-1d" \
20
+ --backbone_type="parity_backbone" \
21
+ --no-full_eval \
22
+ --weight_decay 0.0 \
23
+ --gradient_clipping -1 \
24
+ --use_scheduler \
25
+ --scheduler_type "cosine" \
26
+ --milestones 0 0 0 \
27
+ --gamma 0 \
28
+ --dataset "parity" \
29
+ --batch_size 64 \
30
+ --batch_size_test 256 \
31
+ --lr=0.0001 \
32
+ --training_iterations 200001 \
33
+ --warmup_steps 500 \
34
+ --track_every 1000 \
35
+ --save_every 10000 \
36
+ --no-reload \
37
+ --no-reload_model_only \
38
+ --device 0 \
39
+ --no-use_amp
tasks/parity/scripts/train_lstm_10.sh ADDED
@@ -0,0 +1,39 @@
1
+ #!/bin/bash
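+ # LSTM baseline on 64-bit parity: 10 iterations.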
2
+ RUN=1
3
+ ITERATIONS=10
4
+ MODEL_TYPE="lstm"
5
+ LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --model_type $MODEL_TYPE \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 686 \
16
+ --d_input 512 \
17
+ --heads 8 \
18
+ --dropout 0.0 \
19
+ --positional_embedding_type="custom-rotational-1d" \
20
+ --backbone_type="parity_backbone" \
21
+ --no-full_eval \
22
+ --weight_decay 0.0 \
23
+ --gradient_clipping -1 \
24
+ --use_scheduler \
25
+ --scheduler_type "cosine" \
26
+ --milestones 0 0 0 \
27
+ --gamma 0 \
28
+ --dataset "parity" \
29
+ --batch_size 64 \
30
+ --batch_size_test 256 \
31
+ --lr=0.0001 \
32
+ --training_iterations 200001 \
33
+ --warmup_steps 500 \
34
+ --track_every 1000 \
35
+ --save_every 10000 \
36
+ --no-reload \
37
+ --no-reload_model_only \
38
+ --device 0 \
39
+ --no-use_amp
tasks/parity/scripts/train_lstm_100.sh ADDED
@@ -0,0 +1,39 @@
1
+ #!/bin/bash
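+ # LSTM baseline on 64-bit parity: 100 iterations.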
2
+ RUN=1
3
+ ITERATIONS=100
4
+ MODEL_TYPE="lstm"
5
+ LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --model_type $MODEL_TYPE \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 857 \
16
+ --d_input 512 \
17
+ --heads 8 \
18
+ --dropout 0.0 \
19
+ --positional_embedding_type="custom-rotational-1d" \
20
+ --backbone_type="parity_backbone" \
21
+ --no-full_eval \
22
+ --weight_decay 0.0 \
23
+ --gradient_clipping -1 \
24
+ --use_scheduler \
25
+ --scheduler_type "cosine" \
26
+ --milestones 0 0 0 \
27
+ --gamma 0 \
28
+ --dataset "parity" \
29
+ --batch_size 64 \
30
+ --batch_size_test 256 \
31
+ --lr=0.0001 \
32
+ --training_iterations 200001 \
33
+ --warmup_steps 500 \
34
+ --track_every 1000 \
35
+ --save_every 10000 \
36
+ --no-reload \
37
+ --no-reload_model_only \
38
+ --device 0 \
39
+ --no-use_amp
tasks/parity/scripts/train_lstm_10_certain.sh ADDED
@@ -0,0 +1,40 @@
1
+ #!/bin/bash
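+ # LSTM baseline on 64-bit parity: 10 iterations, using the most-certain-iteration output (--use_most_certain_with_lstm).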
2
+ RUN=3
3
+ ITERATIONS=10
4
+ MODEL_TYPE="lstm"
5
+ LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}_certain"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --model_type $MODEL_TYPE \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 686 \
16
+ --d_input 512 \
17
+ --heads 8 \
18
+ --dropout 0.0 \
19
+ --positional_embedding_type="custom-rotational-1d" \
20
+ --backbone_type="parity_backbone" \
21
+ --no-full_eval \
22
+ --weight_decay 0.0 \
23
+ --gradient_clipping -1 \
24
+ --use_scheduler \
25
+ --scheduler_type "cosine" \
26
+ --milestones 0 0 0 \
27
+ --gamma 0 \
28
+ --dataset "parity" \
29
+ --batch_size 64 \
30
+ --batch_size_test 256 \
31
+ --lr=0.0001 \
32
+ --training_iterations 200001 \
33
+ --warmup_steps 500 \
34
+ --track_every 1000 \
35
+ --save_every 10000 \
36
+ --no-reload \
37
+ --no-reload_model_only \
38
+ --device 0 \
39
+ --no-use_amp \
40
+ --use_most_certain_with_lstm \
tasks/parity/scripts/train_lstm_25.sh ADDED
@@ -0,0 +1,39 @@
 
1
+ #!/bin/bash
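+ # LSTM baseline on 64-bit parity: 25 iterations.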
2
+ RUN=1
3
+ ITERATIONS=25
4
+ MODEL_TYPE="lstm"
5
+ LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --model_type $MODEL_TYPE \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 706 \
16
+ --d_input 512 \
17
+ --heads 8 \
18
+ --dropout 0.0 \
19
+ --positional_embedding_type="custom-rotational-1d" \
20
+ --backbone_type="parity_backbone" \
21
+ --no-full_eval \
22
+ --weight_decay 0.0 \
23
+ --gradient_clipping -1 \
24
+ --use_scheduler \
25
+ --scheduler_type "cosine" \
26
+ --milestones 0 0 0 \
27
+ --gamma 0 \
28
+ --dataset "parity" \
29
+ --batch_size 64 \
30
+ --batch_size_test 256 \
31
+ --lr=0.0001 \
32
+ --training_iterations 200001 \
33
+ --warmup_steps 500 \
34
+ --track_every 1000 \
35
+ --save_every 10000 \
36
+ --no-reload \
37
+ --no-reload_model_only \
38
+ --device 0 \
39
+ --no-use_amp \
tasks/parity/scripts/train_lstm_25_certain.sh ADDED
@@ -0,0 +1,40 @@
 
1
+ #!/bin/bash
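+ # LSTM baseline on 64-bit parity: 25 iterations, using the most-certain-iteration output (--use_most_certain_with_lstm).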
2
+ RUN=3
3
+ ITERATIONS=25
4
+ MODEL_TYPE="lstm"
5
+ LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}_certain"
6
+ SEED=$((RUN - 1))
7
+
8
+ python -m tasks.parity.train \
9
+ --log_dir $LOG_DIR \
10
+ --seed $SEED \
11
+ --iterations $ITERATIONS \
12
+ --model_type $MODEL_TYPE \
13
+ --parity_sequence_length 64 \
14
+ --n_test_batches 20 \
15
+ --d_model 706 \
16
+ --d_input 512 \
17
+ --heads 8 \
18
+ --dropout 0.0 \
19
+ --positional_embedding_type="custom-rotational-1d" \
20
+ --backbone_type="parity_backbone" \
21
+ --no-full_eval \
22
+ --weight_decay 0.0 \
23
+ --gradient_clipping -1 \
24
+ --use_scheduler \
25
+ --scheduler_type "cosine" \
26
+ --milestones 0 0 0 \
27
+ --gamma 0 \
28
+ --dataset "parity" \
29
+ --batch_size 64 \
30
+ --batch_size_test 256 \
31
+ --lr=0.0001 \
32
+ --training_iterations 200001 \
33
+ --warmup_steps 500 \
34
+ --track_every 1000 \
35
+ --save_every 10000 \
36
+ --no-reload \
37
+ --no-reload_model_only \
38
+ --device 0 \
39
+ --no-use_amp \
40
+ --use_most_certain_with_lstm