Commit
·
68b32f4
0
Parent(s):
Welcome to the CTM. This is the first commit of the public repo. Enjoy!
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +18 -0
- README.md +134 -0
- data/custom_datasets.py +324 -0
- examples/01_mnist.ipynb +0 -0
- models/README.md +7 -0
- models/constants.py +10 -0
- models/ctm.py +552 -0
- models/ctm_qamnist.py +205 -0
- models/ctm_rl.py +192 -0
- models/ctm_sort.py +126 -0
- models/ff.py +75 -0
- models/lstm.py +244 -0
- models/lstm_qamnist.py +184 -0
- models/lstm_rl.py +96 -0
- models/modules.py +692 -0
- models/resnet.py +374 -0
- models/utils.py +122 -0
- requirements.txt +15 -0
- tasks/image_classification/README.md +29 -0
- tasks/image_classification/analysis/README.md +12 -0
- tasks/image_classification/analysis/run_imagenet_analysis.py +972 -0
- tasks/image_classification/imagenet_classes.py +1007 -0
- tasks/image_classification/plotting.py +494 -0
- tasks/image_classification/scripts/train_cifar10.sh +286 -0
- tasks/image_classification/scripts/train_imagenet.sh +38 -0
- tasks/image_classification/train.py +685 -0
- tasks/image_classification/train_distributed.py +799 -0
- tasks/mazes/README.md +10 -0
- tasks/mazes/analysis/README.md +10 -0
- tasks/mazes/analysis/run.py +407 -0
- tasks/mazes/plotting.py +198 -0
- tasks/mazes/scripts/train_ctm.sh +35 -0
- tasks/mazes/train.py +698 -0
- tasks/mazes/train_distributed.py +782 -0
- tasks/parity/README.md +16 -0
- tasks/parity/analysis/make_blog_gifs.py +263 -0
- tasks/parity/analysis/run.py +269 -0
- tasks/parity/plotting.py +896 -0
- tasks/parity/scripts/train_ctm_100_50.sh +46 -0
- tasks/parity/scripts/train_ctm_10_5.sh +46 -0
- tasks/parity/scripts/train_ctm_1_1.sh +46 -0
- tasks/parity/scripts/train_ctm_25_10.sh +46 -0
- tasks/parity/scripts/train_ctm_50_25.sh +46 -0
- tasks/parity/scripts/train_ctm_75_25.sh +46 -0
- tasks/parity/scripts/train_lstm_1.sh +39 -0
- tasks/parity/scripts/train_lstm_10.sh +39 -0
- tasks/parity/scripts/train_lstm_100.sh +39 -0
- tasks/parity/scripts/train_lstm_10_certain.sh +40 -0
- tasks/parity/scripts/train_lstm_25.sh +39 -0
- tasks/parity/scripts/train_lstm_25_certain.sh +40 -0
.gitignore
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*/__pycache__
|
| 2 |
+
logs
|
| 3 |
+
.DS_Store
|
| 4 |
+
*.png
|
| 5 |
+
*.pdf
|
| 6 |
+
*.gif
|
| 7 |
+
*.out
|
| 8 |
+
*.pyc
|
| 9 |
+
*.env
|
| 10 |
+
*.pt
|
| 11 |
+
*.mp4
|
| 12 |
+
.vscode*
|
| 13 |
+
*outputs*
|
| 14 |
+
data/*
|
| 15 |
+
!assets/*.gif
|
| 16 |
+
!data/custom_datasets.py
|
| 17 |
+
examples/*
|
| 18 |
+
!examples/01_mnist.ipynb
|
README.md
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🕰️ The Continuous Thought Machine
|
| 2 |
+
|
| 3 |
+
📚 [PAPER: Technical Report](https://pub.sakana.ai/ctm/paper) | 📝 [Blog](https://sakana.ai/ctm/) | 🕹️ [Interactive Website](https:pub.sakana.ai/ctm)
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
We present the Continuous Thought Machine (CTM), a model designed to unfold and then leverage neural activity as the underlying mechanism for observation and action. The CTM has two core innovations:
|
| 8 |
+
|
| 9 |
+
1. Neuron-level temporal processing, where each neuron uses unique weight parameters to process a history of incoming signals, enabling fine-grained temporal dynamics.
|
| 10 |
+
|
| 11 |
+
2. Neural synchronisation, employed as a direct latent representation for modulating data and producing outputs, thus directly encoding information in the timing of neural activity.
|
| 12 |
+
|
| 13 |
+
We demonstrate the CTM's strong performance and versatility across a range of challenging tasks, including ImageNet classification, solving 2D mazes, sorting, parity computation, question-answering, and RL tasks.
|
| 14 |
+
|
| 15 |
+
We provide all necessary code to reproduce our results and invite others to build upon and use CTMs in their own work.
|
| 16 |
+
|
| 17 |
+
## Repo structure
|
| 18 |
+
```
|
| 19 |
+
├── tasks
|
| 20 |
+
│ ├── image_classification
|
| 21 |
+
│ │ ├── train.py # Training code for image classification (cifar, imagenet)
|
| 22 |
+
│ │ ├── imagenet_classes.py # Helper for imagenet class names
|
| 23 |
+
│ │ ├── plotting.py # Plotting utils specific to this task
|
| 24 |
+
│ │ └── analysis
|
| 25 |
+
│ │ ├──run_imagenet_analysis.py # ImageNet eval and visualisation code
|
| 26 |
+
│ │ └──outputs/ # Folder for outputs of analysis
|
| 27 |
+
│ ├── mazes
|
| 28 |
+
│ │ ├── train.py # Training code for solving 2D mazes (by way of a route; see paper)
|
| 29 |
+
│ │ └── plotting.py # Plotting utils specific to this task
|
| 30 |
+
│ │ └── analysis
|
| 31 |
+
│ │ ├──run.py # Maze analysis code
|
| 32 |
+
│ │ └──outputs/ # Folder for outputs of analysis
|
| 33 |
+
│ ├── sort
|
| 34 |
+
│ │ ├── train.py # Training code for sorting
|
| 35 |
+
│ │ └── utils.py # Sort specific utils (e.g., CTC decode)
|
| 36 |
+
│ ├── parity
|
| 37 |
+
│ │ ├── train.py # Training code for parity task
|
| 38 |
+
│ │ ├── utils.py # Parity-specific helper functions
|
| 39 |
+
│ │ ├── plotting.py # Plotting utils specific to this task
|
| 40 |
+
│ │ ├── scripts/
|
| 41 |
+
│ │ │ └── *.sh # Training scripts for different experimental setups
|
| 42 |
+
│ │ └── analysis/
|
| 43 |
+
│ │ └── run.py # Entry point for parity analysis
|
| 44 |
+
│ ├── qamnist
|
| 45 |
+
│ │ ├── train.py # Training code for QAMNIST task (quantized MNIST)
|
| 46 |
+
│ │ ├── utils.py # QAMNIST-specific helper functions
|
| 47 |
+
│ │ ├── plotting.py # Plotting utils specific to this task
|
| 48 |
+
│ │ ├── scripts/
|
| 49 |
+
│ │ │ └── *.sh # Training scripts for different experimental setups
|
| 50 |
+
│ │ └── analysis/
|
| 51 |
+
│ │ └── run.py # Entry point for QAMNIST analysis
|
| 52 |
+
│ └── rl
|
| 53 |
+
│ ├── train.py # Training code for RL environments
|
| 54 |
+
│ ├── utils.py # RL-specific helper functions
|
| 55 |
+
│ ├── plotting.py # Plotting utils specific to this task
|
| 56 |
+
│ ├── envs.py # Custom RL environment wrappers
|
| 57 |
+
│ ├── scripts/
|
| 58 |
+
│ │ ├── 4rooms/
|
| 59 |
+
│ │ │ └── *.sh # Training scripts for MiniGrid-FourRooms-v0 environment
|
| 60 |
+
│ │ ├── acrobot/
|
| 61 |
+
│ │ │ └── *.sh # Training scripts for Acrobot-v1 environment
|
| 62 |
+
│ │ └── cartpole/
|
| 63 |
+
│ │ └── *.sh # Training scripts for CartPole-v1 environment
|
| 64 |
+
│ └── analysis/
|
| 65 |
+
│ └── run.py # Entry point for RL analysis
|
| 66 |
+
├── data # This is where data will be saved and downloaded to
|
| 67 |
+
│ └── custom_datasets.py # Custom datasets (e.g., Mazes), sort
|
| 68 |
+
├── models
|
| 69 |
+
│ ├── ctm.py # Main model code, used for: image classification, solving mazes, sort
|
| 70 |
+
│ ├── ctm_*.py # Other model code, standalone adjustments for other tasks
|
| 71 |
+
│ ├── ff.py # feed-forward (simple) baseline code (e.g., for image classification)
|
| 72 |
+
│ ├── lstm.py # LSTM baseline code (e.g., for image classification)
|
| 73 |
+
│ ├── lstm_*.py # Other baseline code, standalone adjustments for other tasks
|
| 74 |
+
│ ├── modules.py # Helper modules, including Neuron-level models and the Synapse UNET
|
| 75 |
+
│ ├── utils.py # Helper functions (e.g., synch decay)
|
| 76 |
+
│ └── resnet.py # Wrapper for ResNet featuriser
|
| 77 |
+
├── utils
|
| 78 |
+
│ ├── housekeeping.py # Helper functions for keeping things neat
|
| 79 |
+
│ ├── losses.py # Loss functions for various tasks (mostly with reshaping stuff)
|
| 80 |
+
│ └── schedulers.py # Helper wrappers for learning rate schedulers
|
| 81 |
+
└── checkpoints
|
| 82 |
+
└── imagenet, mazes, ... # Checkpoint directories (see google drive link for files)
|
| 83 |
+
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
## Setup
|
| 87 |
+
To set up the environment using conda:
|
| 88 |
+
|
| 89 |
+
```
|
| 90 |
+
conda create --name=ctm python=3.12
|
| 91 |
+
conda activate ctm
|
| 92 |
+
pip install -r requirements.txt
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
If there are issues with PyTorch versions, the following can be ran:
|
| 96 |
+
```
|
| 97 |
+
pip uninstall torch
|
| 98 |
+
pip install torch --index-url https://download.pytorch.org/whl/cu121
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## Model training
|
| 102 |
+
Each task has its own (set of) training code. See for instance [tasks/image_classification/train.py](tasks/image_classification/train.py). We have set it up like this to ensure ease-of-use as opposed to clinical efficiency. This code is for researchers and we hope to have it shared in a way that fosters collaboration and learning.
|
| 103 |
+
|
| 104 |
+
While we have provided reasonable defaults in the argparsers of each training setup, scripts to replicate the setups in the paper will typically be found in the accompanying script folders. If you simply want to dive in, run the following as a module (setup like this to make it easy to run many high-level training scripts from the top directory):
|
| 105 |
+
|
| 106 |
+
```
|
| 107 |
+
python -m tasks.image_classification.train
|
| 108 |
+
```
|
| 109 |
+
For debugging in VSCode, this configuration example might be helpful to you:
|
| 110 |
+
```
|
| 111 |
+
{
|
| 112 |
+
"name": "Debug: train image classifier",
|
| 113 |
+
"type": "debugpy",
|
| 114 |
+
"request": "launch",
|
| 115 |
+
"module": "tasks.image_classification.train",
|
| 116 |
+
"console": "integratedTerminal",
|
| 117 |
+
"justMyCode": false
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
## Running analyses
|
| 123 |
+
|
| 124 |
+
We also provide analysis and plotting code to replicate many of the plots in our paper. See `tasks/.../analysis/*` for more details on that. We als provide some data (e.g., the mazes we generated for training) and checkpoints (see [here](#checkpoints-and-data))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
## Checkpoints and data
|
| 128 |
+
You can download the data and checkpoints from here: https://drive.google.com/drive/folders/1f4N0ndIDrRvac5fUnWof33KWhvz8iqo_?usp=drive_link
|
| 129 |
+
|
| 130 |
+
Checkpoints go in the `checkpoints` folder. For instance, when properly populated, the checkpoints folder will have the maze checkpoint in `checkpoints/mazes/...`
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
data/custom_datasets.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision.datasets import ImageFolder
|
| 3 |
+
from torch.utils.data import Dataset
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from datasets import load_dataset
|
| 9 |
+
|
| 10 |
+
class SortDataset(Dataset):
|
| 11 |
+
def __init__(self, N):
|
| 12 |
+
self.N = N
|
| 13 |
+
def __len__(self):
|
| 14 |
+
return 10000000
|
| 15 |
+
def __getitem__(self, idx):
|
| 16 |
+
data = torch.zeros(self.N).normal_()
|
| 17 |
+
ordering = torch.argsort(data)
|
| 18 |
+
inputs = data
|
| 19 |
+
return (inputs), (ordering)
|
| 20 |
+
|
| 21 |
+
class QAMNISTDataset(Dataset):
|
| 22 |
+
"""A QAMNIST dataset that includes plus and minus operations on MNIST digits."""
|
| 23 |
+
def __init__(self, base_dataset, num_images, num_images_delta, num_repeats_per_input, num_operations, num_operations_delta):
|
| 24 |
+
self.base_dataset = base_dataset
|
| 25 |
+
|
| 26 |
+
self.num_images = num_images
|
| 27 |
+
self.num_images_delta = num_images_delta
|
| 28 |
+
self.num_images_range = self._calculate_num_images_range()
|
| 29 |
+
|
| 30 |
+
self.operators = ["+", "-"]
|
| 31 |
+
self.num_operations = num_operations
|
| 32 |
+
self.num_operations_delta = num_operations_delta
|
| 33 |
+
self.num_operations_range = self._calculate_num_operations_range()
|
| 34 |
+
|
| 35 |
+
self.num_repeats_per_input = num_repeats_per_input
|
| 36 |
+
|
| 37 |
+
self.current_num_digits = num_images
|
| 38 |
+
self.current_num_operations = num_operations
|
| 39 |
+
|
| 40 |
+
self.modulo_base = 10
|
| 41 |
+
|
| 42 |
+
self.output_range = [0, 9]
|
| 43 |
+
|
| 44 |
+
def _calculate_num_images_range(self):
|
| 45 |
+
min_val = self.num_images - self.num_images_delta
|
| 46 |
+
max_val = self.num_images + self.num_images_delta
|
| 47 |
+
assert min_val >= 1, f"Minimum number of images must be at least 1, got {min_val}"
|
| 48 |
+
return [min_val, max_val]
|
| 49 |
+
|
| 50 |
+
def _calculate_num_operations_range(self):
|
| 51 |
+
min_val = self.num_operations - self.num_operations_delta
|
| 52 |
+
max_val = self.num_operations + self.num_operations_delta
|
| 53 |
+
assert min_val >= 1, f"Minimum number of operations must be at least 1, got {min_val}"
|
| 54 |
+
return [min_val, max_val]
|
| 55 |
+
|
| 56 |
+
def set_num_digits(self, num_digits):
|
| 57 |
+
self.current_num_digits = num_digits
|
| 58 |
+
|
| 59 |
+
def set_num_operations(self, num_operations):
|
| 60 |
+
self.current_num_operations = num_operations
|
| 61 |
+
|
| 62 |
+
def _get_target_and_question(self, targets):
|
| 63 |
+
question = []
|
| 64 |
+
equations = []
|
| 65 |
+
num_digits = self.current_num_digits
|
| 66 |
+
num_operations = self.current_num_operations
|
| 67 |
+
|
| 68 |
+
# Select the initial digit
|
| 69 |
+
selection_idx = np.random.randint(num_digits)
|
| 70 |
+
first_digit = targets[selection_idx]
|
| 71 |
+
question.extend([selection_idx] * self.num_repeats_per_input)
|
| 72 |
+
# Set current_value to the initial digit (mod is applied in each operation)
|
| 73 |
+
current_value = first_digit % self.modulo_base
|
| 74 |
+
|
| 75 |
+
# For each operation, build an equation line
|
| 76 |
+
for _ in range(num_operations):
|
| 77 |
+
# Choose the operator ('+' or '-')
|
| 78 |
+
operator_idx = np.random.randint(len(self.operators))
|
| 79 |
+
operator = self.operators[operator_idx]
|
| 80 |
+
encoded_operator = -(operator_idx + 1) # -1 for '+', -2 for '-'
|
| 81 |
+
question.extend([encoded_operator] * self.num_repeats_per_input)
|
| 82 |
+
|
| 83 |
+
# Choose the next digit
|
| 84 |
+
selection_idx = np.random.randint(num_digits)
|
| 85 |
+
digit = targets[selection_idx]
|
| 86 |
+
question.extend([selection_idx] * self.num_repeats_per_input)
|
| 87 |
+
|
| 88 |
+
# Compute the new value with immediate modulo reduction
|
| 89 |
+
if operator == '+':
|
| 90 |
+
new_value = (current_value + digit) % self.modulo_base
|
| 91 |
+
else: # operator is '-'
|
| 92 |
+
new_value = (current_value - digit) % self.modulo_base
|
| 93 |
+
|
| 94 |
+
# Build the equation string for this step
|
| 95 |
+
equations.append(f"({current_value} {operator} {digit}) mod {self.modulo_base} = {new_value}")
|
| 96 |
+
# Update current value for the next operation
|
| 97 |
+
current_value = new_value
|
| 98 |
+
|
| 99 |
+
target = current_value
|
| 100 |
+
question_readable = "\n".join(equations)
|
| 101 |
+
return target, question, question_readable
|
| 102 |
+
|
| 103 |
+
def __len__(self):
|
| 104 |
+
return len(self.base_dataset)
|
| 105 |
+
|
| 106 |
+
def __getitem__(self, idx):
|
| 107 |
+
images, targets = [],[]
|
| 108 |
+
for _ in range(self.current_num_digits):
|
| 109 |
+
image, target = self.base_dataset[np.random.randint(self.__len__())]
|
| 110 |
+
images.append(image)
|
| 111 |
+
targets.append(target)
|
| 112 |
+
|
| 113 |
+
observations = torch.repeat_interleave(torch.stack(images, 0), repeats=self.num_repeats_per_input, dim=0)
|
| 114 |
+
target, question, question_readable = self._get_target_and_question(targets)
|
| 115 |
+
return observations, question, question_readable, target
|
| 116 |
+
|
| 117 |
+
class ImageNet(Dataset):
|
| 118 |
+
def __init__(self, which_split, transform):
|
| 119 |
+
"""
|
| 120 |
+
Most simple form of the custom dataset structure.
|
| 121 |
+
Args:
|
| 122 |
+
base_dataset (Dataset): The base dataset to sample from.
|
| 123 |
+
N (int): The number of images to construct into an observable sequence.
|
| 124 |
+
R (int): number of repeats
|
| 125 |
+
operators (list): list of operators from which to sample
|
| 126 |
+
action to take on observations (str): can be 'global' to compute operator over full observations, or 'select_K', where K=integer.
|
| 127 |
+
"""
|
| 128 |
+
dataset = load_dataset('imagenet-1k', split=which_split, trust_remote_code=True)
|
| 129 |
+
|
| 130 |
+
self.transform = transform
|
| 131 |
+
self.base_dataset = dataset
|
| 132 |
+
|
| 133 |
+
def __len__(self):
|
| 134 |
+
return len(self.base_dataset)
|
| 135 |
+
|
| 136 |
+
def __getitem__(self, idx):
|
| 137 |
+
data_item = self.base_dataset[idx]
|
| 138 |
+
image = self.transform(data_item['image'].convert('RGB'))
|
| 139 |
+
target = data_item['label']
|
| 140 |
+
return image, target
|
| 141 |
+
|
| 142 |
+
class MazeImageFolder(ImageFolder):
|
| 143 |
+
"""
|
| 144 |
+
A custom dataset class that extends the ImageFolder class.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
root (string): Root directory path.
|
| 148 |
+
transform (callable, optional): A function/transform that takes in
|
| 149 |
+
a sample and returns a transformed version.
|
| 150 |
+
E.g, ``transforms.RandomCrop`` for images.
|
| 151 |
+
target_transform (callable, optional): A function/transform that takes
|
| 152 |
+
in the target and transforms it.
|
| 153 |
+
loader (callable, optional): A function to load an image given its path.
|
| 154 |
+
is_valid_file (callable, optional): A function that takes path of an Image file
|
| 155 |
+
and check if the file is a valid file (used to check of corrupt files)
|
| 156 |
+
|
| 157 |
+
Attributes:
|
| 158 |
+
classes (list): List of the class names.
|
| 159 |
+
class_to_idx (dict): Dict with items (class_name, class_index).
|
| 160 |
+
imgs (list): List of (image path, class_index) tuples
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
def __init__(self, root, transform=None, target_transform=None,
|
| 164 |
+
loader=Image.open,
|
| 165 |
+
is_valid_file=None,
|
| 166 |
+
which_set='train',
|
| 167 |
+
augment_p=0.5,
|
| 168 |
+
maze_route_length=10,
|
| 169 |
+
trunc=False,
|
| 170 |
+
expand_range=True):
|
| 171 |
+
super(MazeImageFolder, self).__init__(root, transform, target_transform, loader, is_valid_file)
|
| 172 |
+
self.which_set = which_set
|
| 173 |
+
self.augment_p = augment_p
|
| 174 |
+
self.maze_route_length = maze_route_length
|
| 175 |
+
self.all_paths = {}
|
| 176 |
+
self.trunc = trunc
|
| 177 |
+
self.expand_range = expand_range
|
| 178 |
+
|
| 179 |
+
self._preload()
|
| 180 |
+
print('Solving all mazes...')
|
| 181 |
+
for index in range(len(self.preloaded_samples)):
|
| 182 |
+
path = self.get_solution(self.preloaded_samples[index])
|
| 183 |
+
self.all_paths[index] = path
|
| 184 |
+
|
| 185 |
+
def _preload(self):
|
| 186 |
+
preloaded_samples = []
|
| 187 |
+
with tqdm(total=self.__len__(), initial=0, leave=True, position=0, dynamic_ncols=True) as pbar:
|
| 188 |
+
|
| 189 |
+
for index in range(self.__len__()):
|
| 190 |
+
pbar.set_description('Loading mazes')
|
| 191 |
+
path, target = self.samples[index]
|
| 192 |
+
sample = self.loader(path)
|
| 193 |
+
sample = np.array(sample).astype(np.float32)/255
|
| 194 |
+
preloaded_samples.append(sample)
|
| 195 |
+
pbar.update(1)
|
| 196 |
+
if self.trunc and index == 999: break
|
| 197 |
+
self.preloaded_samples = preloaded_samples
|
| 198 |
+
|
| 199 |
+
def __len__(self):
|
| 200 |
+
if hasattr(self, 'preloaded_samples') and self.preloaded_samples is not None:
|
| 201 |
+
return len(self.preloaded_samples)
|
| 202 |
+
else:
|
| 203 |
+
return super().__len__()
|
| 204 |
+
|
| 205 |
+
def get_solution(self, x):
|
| 206 |
+
x = np.copy(x)
|
| 207 |
+
# Find start (red) and end (green) pixel coordinates
|
| 208 |
+
start_coords = np.argwhere((x == [1, 0, 0]).all(axis=2))
|
| 209 |
+
end_coords = np.argwhere((x == [0, 1, 0]).all(axis=2))
|
| 210 |
+
|
| 211 |
+
if len(start_coords) == 0 or len(end_coords) == 0:
|
| 212 |
+
print("Start or end point not found.")
|
| 213 |
+
return None
|
| 214 |
+
|
| 215 |
+
start_y, start_x = start_coords[0]
|
| 216 |
+
end_y, end_x = end_coords[0]
|
| 217 |
+
|
| 218 |
+
current_y, current_x = start_y, start_x
|
| 219 |
+
path = [4] * self.maze_route_length
|
| 220 |
+
|
| 221 |
+
pi = 0
|
| 222 |
+
while (current_y, current_x) != (end_y, end_x):
|
| 223 |
+
next_y, next_x = -1, -1 # Initialize to invalid coordinates
|
| 224 |
+
direction = -1 # Initialize to an invalid direction
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Check Up
|
| 228 |
+
if current_y > 0 and ((x[current_y - 1, current_x] == [0, 0, 1]).all() or (x[current_y - 1, current_x] == [0, 1, 0]).all()):
|
| 229 |
+
next_y, next_x = current_y - 1, current_x
|
| 230 |
+
direction = 0
|
| 231 |
+
|
| 232 |
+
# Check Down
|
| 233 |
+
elif current_y < x.shape[0] - 1 and ((x[current_y + 1, current_x] == [0, 0, 1]).all() or (x[current_y + 1, current_x] == [0, 1, 0]).all()):
|
| 234 |
+
next_y, next_x = current_y + 1, current_x
|
| 235 |
+
direction = 1
|
| 236 |
+
|
| 237 |
+
# Check Left
|
| 238 |
+
elif current_x > 0 and ((x[current_y, current_x - 1] == [0, 0, 1]).all() or (x[current_y, current_x - 1] == [0, 1, 0]).all()):
|
| 239 |
+
next_y, next_x = current_y, current_x - 1
|
| 240 |
+
direction = 2
|
| 241 |
+
|
| 242 |
+
# Check Right
|
| 243 |
+
elif current_x < x.shape[1] - 1 and ((x[current_y, current_x + 1] == [0, 0, 1]).all() or (x[current_y, current_x + 1] == [0, 1, 0]).all()):
|
| 244 |
+
next_y, next_x = current_y, current_x + 1
|
| 245 |
+
direction = 3
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
path[pi] = direction
|
| 249 |
+
pi += 1
|
| 250 |
+
|
| 251 |
+
x[current_y, current_x] = [255,255,255] # mark the current as white to avoid going in circles
|
| 252 |
+
current_y, current_x = next_y, next_x
|
| 253 |
+
if pi == len(path):
|
| 254 |
+
break
|
| 255 |
+
|
| 256 |
+
return np.array(path)
|
| 257 |
+
|
| 258 |
+
def __getitem__(self, index):
|
| 259 |
+
"""
|
| 260 |
+
Args:
|
| 261 |
+
index (int): Index
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
tuple: (sample, target) where target is class_index of the target class.
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
sample = np.copy(self.preloaded_samples[index])
|
| 268 |
+
|
| 269 |
+
path = np.copy(self.all_paths[index])
|
| 270 |
+
|
| 271 |
+
if self.which_set == 'train':
|
| 272 |
+
# Randomly rotate -90 or +90 degrees
|
| 273 |
+
if random.random() < self.augment_p:
|
| 274 |
+
which_rot = random.choice([-1, 1])
|
| 275 |
+
sample = np.rot90(sample, k=which_rot, axes=(0, 1))
|
| 276 |
+
for pi in range(len(path)):
|
| 277 |
+
if path[pi] == 0: path[pi] = 3 if which_rot == -1 else 2
|
| 278 |
+
elif path[pi] == 1: path[pi] = 2 if which_rot == -1 else 3
|
| 279 |
+
elif path[pi] == 2: path[pi] = 0 if which_rot == -1 else 1
|
| 280 |
+
elif path[pi] == 3: path[pi] = 1 if which_rot == -1 else 0
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
# Random horizontal flip
|
| 284 |
+
if random.random() < self.augment_p:
|
| 285 |
+
sample = np.fliplr(sample)
|
| 286 |
+
for pi in range(len(path)):
|
| 287 |
+
if path[pi] == 2: path[pi] = 3
|
| 288 |
+
elif path[pi] == 3: path[pi] = 2
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# Random vertical flip
|
| 292 |
+
if random.random() < self.augment_p:
|
| 293 |
+
sample = np.flipud(sample)
|
| 294 |
+
for pi in range(len(path)):
|
| 295 |
+
if path[pi] == 0: path[pi] = 1
|
| 296 |
+
elif path[pi] == 1: path[pi] = 0
|
| 297 |
+
|
| 298 |
+
sample = torch.from_numpy(np.copy(sample)).permute(2,0,1)
|
| 299 |
+
|
| 300 |
+
blue_mask = (sample[0] == 0) & (sample[1] == 0) & (sample[2] == 1)
|
| 301 |
+
|
| 302 |
+
sample[:, blue_mask] = 1
|
| 303 |
+
target = path
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
if not self.expand_range:
|
| 307 |
+
return sample, target
|
| 308 |
+
return (sample*2)-1, (target)
|
| 309 |
+
|
| 310 |
+
class ParityDataset(Dataset):
|
| 311 |
+
def __init__(self, sequence_length=64, length=100000):
|
| 312 |
+
self.sequence_length = sequence_length
|
| 313 |
+
self.length = length
|
| 314 |
+
|
| 315 |
+
def __len__(self):
|
| 316 |
+
return self.length
|
| 317 |
+
|
| 318 |
+
def __getitem__(self, idx):
|
| 319 |
+
vector = 2 * torch.randint(0, 2, (self.sequence_length,)) - 1
|
| 320 |
+
vector = vector.float()
|
| 321 |
+
negatives = (vector == -1).to(torch.long)
|
| 322 |
+
cumsum = torch.cumsum(negatives, dim=0)
|
| 323 |
+
target = (cumsum % 2 != 0).to(torch.long)
|
| 324 |
+
return vector, target
|
examples/01_mnist.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models/README.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Continuous Thought Machines
|
| 2 |
+
## Models
|
| 3 |
+
|
| 4 |
+
This folder contains all model-related code.
|
| 5 |
+
|
| 6 |
+
Some notes for clarity:
|
| 7 |
+
1. The resnet structure we used (see resnet.py) has a few minor changes that enable constraining the receptive field of the features yielded. We do this because we want the CTM (or baseline methods) to learn a process whereby they gather information. Neural networks that use SGD will find the [path of least resistence](https://era.ed.ac.uk/handle/1842/39606), even if that path doesn't result in actually intelligent behaviour. Constraining the receptive field helps to prevent this, a bit.
|
models/constants.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
VALID_NEURON_SELECT_TYPES = ['first-last', 'random', 'random-pairing']
|
| 2 |
+
|
| 3 |
+
VALID_BACKBONE_TYPES = [
|
| 4 |
+
f'resnet{depth}-{i}' for depth in [18, 34, 50, 101, 152] for i in range(1, 5)
|
| 5 |
+
] + ['shallow-wide', 'parity_backbone']
|
| 6 |
+
|
| 7 |
+
VALID_POSITIONAL_EMBEDDING_TYPES = [
|
| 8 |
+
'learnable-fourier', 'multi-learnable-fourier',
|
| 9 |
+
'custom-rotational', 'custom-rotational-1d'
|
| 10 |
+
]
|
models/ctm.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
from models.modules import ParityBackbone, SynapseUNET, Squeeze, SuperLinear, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
|
| 7 |
+
from models.resnet import prepare_resnet_backbone
|
| 8 |
+
from models.utils import compute_normalized_entropy
|
| 9 |
+
|
| 10 |
+
from models.constants import (
|
| 11 |
+
VALID_NEURON_SELECT_TYPES,
|
| 12 |
+
VALID_BACKBONE_TYPES,
|
| 13 |
+
VALID_POSITIONAL_EMBEDDING_TYPES
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
class ContinuousThoughtMachine(nn.Module):
|
| 17 |
+
"""
|
| 18 |
+
Continuous Thought Machine (CTM).
|
| 19 |
+
|
| 20 |
+
Technical report: TODO:LINK
|
| 21 |
+
|
| 22 |
+
Technical report (web version): TODO:LINK
|
| 23 |
+
|
| 24 |
+
Blog: TODO:LINK
|
| 25 |
+
|
| 26 |
+
Thought takes time and reasoning is a process.
|
| 27 |
+
|
| 28 |
+
The CTM consists of three main ideas:
|
| 29 |
+
1. The use of internal recurrence, enabling a dimension over which a concept analogous to thought can occur.
|
| 30 |
+
1. Neuron-level models, that compute post-activations by applying private (i.e., on a per-neuron basis) MLP
|
| 31 |
+
models to a history of incoming pre-activations.
|
| 32 |
+
2. Synchronisation as representation, where the neural activity over time is tracked and used to compute how
|
| 33 |
+
pairs of neurons synchronise with one another over time. This measure of synchronisation is the representation
|
| 34 |
+
with which the CTM takes action and makes predictions.
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
iterations (int): Number of internal 'thought' ticks (T, in paper).
|
| 39 |
+
d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
|
| 40 |
+
NOTE: Note that this is NOT the representation used for action or prediction, but rather that which
|
| 41 |
+
is fully internal to the model and not directly connected to data.
|
| 42 |
+
d_input (int): Dimensionality of projected attention outputs or direct input features.
|
| 43 |
+
heads (int): Number of attention heads.
|
| 44 |
+
n_synch_out (int): Number of neurons used for output synchronisation (D_out, in paper).
|
| 45 |
+
n_synch_action (int): Number of neurons used for action/attention synchronisation (D_action, in paper).
|
| 46 |
+
synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
|
| 47 |
+
memory_length (int): History length for Neuron-Level Models (M, in paper).
|
| 48 |
+
deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
|
| 49 |
+
NOTE: we almost always use deep NLMs, but a linear NLM is faster.
|
| 50 |
+
memory_hidden_dims (int): Hidden dimension size for deep NLMs.
|
| 51 |
+
do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
|
| 52 |
+
NOTE: we never set this to true in the paper. If you set this to true you will get strange behaviour,
|
| 53 |
+
but you can potentially encourage more periodic behaviour in the dynamics. Untested; be careful.
|
| 54 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 55 |
+
positional_embedding_type (str): Type of positional embedding for backbone features.
|
| 56 |
+
out_dims (int): Output dimension size.
|
| 57 |
+
NOTE: projected from synchronisation!
|
| 58 |
+
prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
|
| 59 |
+
NOTE: this is used to compute certainty and is needed when applying softmax for probabilities
|
| 60 |
+
dropout (float): Dropout rate.
|
| 61 |
+
neuron_select_type (str): Neuron selection strategy ('first-last', 'random', 'random-pairing').
|
| 62 |
+
NOTE: some of this is legacy from our experimentation, but all three strategies are valid and useful.
|
| 63 |
+
We dilineate exactly which strategies we use per experiment in the paper.
|
| 64 |
+
- first-last: build a 'dense' sync matrix for output from the first D_out neurons and action from the
|
| 65 |
+
last D_action neurons. Flatten this matrix into the synchronisation representation.
|
| 66 |
+
This approach shares relationships for neurons and bottlenecks the gradients through them.
|
| 67 |
+
NOTE: the synchronisation size will be (D_out/action * (D_out/action + 1))/2
|
| 68 |
+
- random: randomly select D_out neurons for the 'i' side pairings, and also D_out for the 'j' side pairings,
|
| 69 |
+
also pairing those accross densely, resulting in a bottleneck roughly 2x as wide.
|
| 70 |
+
NOTE: the synchronisation size will be (D_out/action * (D_out/action + 1))/2
|
| 71 |
+
- random-pairing (DEFAULT!): randomly select D_out neurons and pair these with another D_out neurons.
|
| 72 |
+
This results in much less bottlenecking and is the most up-to-date variant.
|
| 73 |
+
NOTE: the synchronisation size will be D_out in this case; better control.
|
| 74 |
+
n_random_pairing_self (int): Number of neurons to select for self-to-self synch when random-pairing is used.
|
| 75 |
+
NOTE: when using random-pairing, i-to-i (self) synchronisation is rare, meaning that 'recovering a
|
| 76 |
+
snapshot representation' (see paper) is difficult. This alleviates that.
|
| 77 |
+
NOTE: works fine when set to 0.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self,
|
| 81 |
+
iterations,
|
| 82 |
+
d_model,
|
| 83 |
+
d_input,
|
| 84 |
+
heads,
|
| 85 |
+
n_synch_out,
|
| 86 |
+
n_synch_action,
|
| 87 |
+
synapse_depth,
|
| 88 |
+
memory_length,
|
| 89 |
+
deep_nlms,
|
| 90 |
+
memory_hidden_dims,
|
| 91 |
+
do_layernorm_nlm,
|
| 92 |
+
backbone_type,
|
| 93 |
+
positional_embedding_type,
|
| 94 |
+
out_dims,
|
| 95 |
+
prediction_reshaper=[-1],
|
| 96 |
+
dropout=0,
|
| 97 |
+
dropout_nlm=None,
|
| 98 |
+
neuron_select_type='random-pairing',
|
| 99 |
+
n_random_pairing_self=0,
|
| 100 |
+
):
|
| 101 |
+
super(ContinuousThoughtMachine, self).__init__()
|
| 102 |
+
|
| 103 |
+
# --- Core Parameters ---
|
| 104 |
+
self.iterations = iterations
|
| 105 |
+
self.d_model = d_model
|
| 106 |
+
self.d_input = d_input
|
| 107 |
+
self.memory_length = memory_length
|
| 108 |
+
self.prediction_reshaper = prediction_reshaper
|
| 109 |
+
self.n_synch_out = n_synch_out
|
| 110 |
+
self.n_synch_action = n_synch_action
|
| 111 |
+
self.backbone_type = backbone_type
|
| 112 |
+
self.out_dims = out_dims
|
| 113 |
+
self.positional_embedding_type = positional_embedding_type
|
| 114 |
+
self.neuron_select_type = neuron_select_type
|
| 115 |
+
self.memory_length = memory_length
|
| 116 |
+
dropout_nlm = dropout if dropout_nlm is None else dropout_nlm
|
| 117 |
+
|
| 118 |
+
# --- Assertions ---
|
| 119 |
+
self.verify_args()
|
| 120 |
+
|
| 121 |
+
# --- Input Processing ---
|
| 122 |
+
d_backbone = self.get_d_backbone()
|
| 123 |
+
self.set_initial_rgb()
|
| 124 |
+
self.set_backbone()
|
| 125 |
+
self.positional_embedding = self.get_positional_embedding(d_backbone)
|
| 126 |
+
self.kv_proj = nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input)) if heads else None
|
| 127 |
+
self.q_proj = nn.LazyLinear(self.d_input) if heads else None
|
| 128 |
+
self.attention = nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True) if heads else None
|
| 129 |
+
|
| 130 |
+
# --- Core CTM Modules ---
|
| 131 |
+
self.synapses = self.get_synapses(synapse_depth, d_model, dropout)
|
| 132 |
+
self.trace_processor = self.get_neuron_level_models(deep_nlms, do_layernorm_nlm, memory_length, memory_hidden_dims, d_model, dropout_nlm)
|
| 133 |
+
|
| 134 |
+
# --- Start States ---
|
| 135 |
+
self.register_parameter('start_activated_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model)))))
|
| 136 |
+
self.register_parameter('start_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length)))))
|
| 137 |
+
|
| 138 |
+
# --- Synchronisation ---
|
| 139 |
+
self.neuron_select_type_out, self.neuron_select_type_action = self.get_neuron_select_type()
|
| 140 |
+
self.synch_representation_size_action = self.calculate_synch_representation_size(self.n_synch_action)
|
| 141 |
+
self.synch_representation_size_out = self.calculate_synch_representation_size(self.n_synch_out)
|
| 142 |
+
|
| 143 |
+
for synch_type, size in (('action', self.synch_representation_size_action), ('out', self.synch_representation_size_out)):
|
| 144 |
+
print(f"Synch representation size {synch_type}: {size}")
|
| 145 |
+
if self.synch_representation_size_action: # if not zero
|
| 146 |
+
self.set_synchronisation_parameters('action', self.n_synch_action, n_random_pairing_self)
|
| 147 |
+
self.set_synchronisation_parameters('out', self.n_synch_out, n_random_pairing_self)
|
| 148 |
+
|
| 149 |
+
# --- Output Procesing ---
|
| 150 |
+
self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
|
| 151 |
+
|
| 152 |
+
# --- Core CTM Methods ---
|
| 153 |
+
|
| 154 |
+
def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, r, synch_type):
|
| 155 |
+
"""
|
| 156 |
+
Computes synchronisation to be used as a vector representation.
|
| 157 |
+
|
| 158 |
+
A neuron has what we call a 'trace', which is a history (time series) that changes with internal
|
| 159 |
+
recurrence. i.e., it gets longer with every internal tick. There are pre-activation traces
|
| 160 |
+
that are used in the NLMs and post-activation traces that, in theory, are used in this method.
|
| 161 |
+
|
| 162 |
+
We define sychronisation between neuron i and j as the dot product between their respective
|
| 163 |
+
time series. Since there can be many internal ticks, this process can be quite compute heavy as it
|
| 164 |
+
involves many dot products that repeat computation at each step.
|
| 165 |
+
|
| 166 |
+
Therefore, in practice, we update the synchronisation based on the current post-activations,
|
| 167 |
+
which we call the 'activated state' here. This is possible because the inputs to synchronisation
|
| 168 |
+
are only updated recurrently at each step, meaning that there is a linear recurrence we can
|
| 169 |
+
leverage.
|
| 170 |
+
|
| 171 |
+
See Appendix TODO of the Technical Report (TODO:LINK) for the maths that enables this method.
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
if synch_type == 'action': # Get action parameters
|
| 175 |
+
n_synch = self.n_synch_action
|
| 176 |
+
neuron_indices_left = self.action_neuron_indices_left
|
| 177 |
+
neuron_indices_right = self.action_neuron_indices_right
|
| 178 |
+
elif synch_type == 'out': # Get input parameters
|
| 179 |
+
n_synch = self.n_synch_out
|
| 180 |
+
neuron_indices_left = self.out_neuron_indices_left
|
| 181 |
+
neuron_indices_right = self.out_neuron_indices_right
|
| 182 |
+
|
| 183 |
+
if self.neuron_select_type in ('first-last', 'random'):
|
| 184 |
+
# For first-last and random, we compute the pairwise sync between all selected neurons
|
| 185 |
+
if self.neuron_select_type == 'first-last':
|
| 186 |
+
if synch_type == 'action': # Use last n_synch neurons for action
|
| 187 |
+
selected_left = selected_right = activated_state[:, -n_synch:]
|
| 188 |
+
elif synch_type == 'out': # Use first n_synch neurons for out
|
| 189 |
+
selected_left = selected_right = activated_state[:, :n_synch]
|
| 190 |
+
else: # Use the randomly selected neurons
|
| 191 |
+
selected_left = activated_state[:, neuron_indices_left]
|
| 192 |
+
selected_right = activated_state[:, neuron_indices_right]
|
| 193 |
+
|
| 194 |
+
# Compute outer product of selected neurons
|
| 195 |
+
outer = selected_left.unsqueeze(2) * selected_right.unsqueeze(1)
|
| 196 |
+
# Resulting matrix is symmetric, so we only need the upper triangle
|
| 197 |
+
i, j = torch.triu_indices(n_synch, n_synch)
|
| 198 |
+
pairwise_product = outer[:, i, j]
|
| 199 |
+
|
| 200 |
+
elif self.neuron_select_type == 'random-pairing':
|
| 201 |
+
# For random-pairing, we compute the sync between specific pairs of neurons
|
| 202 |
+
left = activated_state[:, neuron_indices_left]
|
| 203 |
+
right = activated_state[:, neuron_indices_right]
|
| 204 |
+
pairwise_product = left * right
|
| 205 |
+
else:
|
| 206 |
+
raise ValueError("Invalid neuron selection type")
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
# Compute synchronisation recurrently
|
| 211 |
+
if decay_alpha is None or decay_beta is None:
|
| 212 |
+
decay_alpha = pairwise_product
|
| 213 |
+
decay_beta = torch.ones_like(pairwise_product)
|
| 214 |
+
else:
|
| 215 |
+
decay_alpha = r * decay_alpha + pairwise_product
|
| 216 |
+
decay_beta = r * decay_beta + 1
|
| 217 |
+
|
| 218 |
+
synchronisation = decay_alpha / (torch.sqrt(decay_beta))
|
| 219 |
+
return synchronisation, decay_alpha, decay_beta
|
| 220 |
+
|
| 221 |
+
def compute_features(self, x):
|
| 222 |
+
"""
|
| 223 |
+
Compute the key-value features from the input data using the backbone.
|
| 224 |
+
"""
|
| 225 |
+
initial_rgb = self.initial_rgb(x)
|
| 226 |
+
self.kv_features = self.backbone(initial_rgb)
|
| 227 |
+
pos_emb = self.positional_embedding(self.kv_features)
|
| 228 |
+
combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
|
| 229 |
+
kv = self.kv_proj(combined_features)
|
| 230 |
+
return kv
|
| 231 |
+
|
| 232 |
+
def compute_certainty(self, current_prediction):
|
| 233 |
+
"""
|
| 234 |
+
Compute the certainty of the current prediction.
|
| 235 |
+
|
| 236 |
+
We define certainty as being 1-normalised entropy.
|
| 237 |
+
|
| 238 |
+
For legacy reasons we stack that in a 2D vector as this can be used for optimisation later.
|
| 239 |
+
"""
|
| 240 |
+
B = current_prediction.size(0)
|
| 241 |
+
reshaped_pred = current_prediction.reshape([B] + self.prediction_reshaper)
|
| 242 |
+
ne = compute_normalized_entropy(reshaped_pred)
|
| 243 |
+
current_certainty = torch.stack((ne, 1-ne), -1)
|
| 244 |
+
return current_certainty
|
| 245 |
+
|
| 246 |
+
# --- Setup Methods ---
|
| 247 |
+
|
| 248 |
+
def set_initial_rgb(self):
|
| 249 |
+
"""
|
| 250 |
+
This is largely to accommodate training on grescale images and is legacy, but it
|
| 251 |
+
doesn't hurt the model in any way that we can tell.
|
| 252 |
+
"""
|
| 253 |
+
if 'resnet' in self.backbone_type:
|
| 254 |
+
self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
|
| 255 |
+
else:
|
| 256 |
+
self.initial_rgb = nn.Identity()
|
| 257 |
+
|
| 258 |
+
def get_d_backbone(self):
|
| 259 |
+
"""
|
| 260 |
+
Get the dimensionality of the backbone output, to be used for positional embedding setup.
|
| 261 |
+
|
| 262 |
+
This is a little bit complicated for resnets, but the logic should be easy enough to read below.
|
| 263 |
+
"""
|
| 264 |
+
if self.backbone_type == 'shallow-wide':
|
| 265 |
+
return 2048
|
| 266 |
+
elif self.backbone_type == 'parity_backbone':
|
| 267 |
+
return self.d_input
|
| 268 |
+
elif 'resnet' in self.backbone_type:
|
| 269 |
+
if '18' in self.backbone_type or '34' in self.backbone_type:
|
| 270 |
+
if self.backbone_type.split('-')[1]=='1': return 64
|
| 271 |
+
elif self.backbone_type.split('-')[1]=='2': return 128
|
| 272 |
+
elif self.backbone_type.split('-')[1]=='3': return 256
|
| 273 |
+
elif self.backbone_type.split('-')[1]=='4': return 512
|
| 274 |
+
else:
|
| 275 |
+
raise NotImplementedError
|
| 276 |
+
else:
|
| 277 |
+
if self.backbone_type.split('-')[1]=='1': return 256
|
| 278 |
+
elif self.backbone_type.split('-')[1]=='2': return 512
|
| 279 |
+
elif self.backbone_type.split('-')[1]=='3': return 1024
|
| 280 |
+
elif self.backbone_type.split('-')[1]=='4': return 2048
|
| 281 |
+
else:
|
| 282 |
+
raise NotImplementedError
|
| 283 |
+
elif self.backbone_type == 'none':
|
| 284 |
+
return None
|
| 285 |
+
else:
|
| 286 |
+
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
|
| 287 |
+
|
| 288 |
+
def set_backbone(self):
|
| 289 |
+
"""
|
| 290 |
+
Set the backbone module based on the specified type.
|
| 291 |
+
"""
|
| 292 |
+
if self.backbone_type == 'shallow-wide':
|
| 293 |
+
self.backbone = ShallowWide()
|
| 294 |
+
elif self.backbone_type == 'parity_backbone':
|
| 295 |
+
d_backbone = self.get_d_backbone()
|
| 296 |
+
self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
|
| 297 |
+
elif 'resnet' in self.backbone_type:
|
| 298 |
+
self.backbone = prepare_resnet_backbone(self.backbone_type)
|
| 299 |
+
elif self.backbone_type == 'none':
|
| 300 |
+
self.backbone = nn.Identity()
|
| 301 |
+
else:
|
| 302 |
+
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
|
| 303 |
+
|
| 304 |
+
def get_positional_embedding(self, d_backbone):
|
| 305 |
+
"""
|
| 306 |
+
Get the positional embedding module.
|
| 307 |
+
|
| 308 |
+
For Imagenet and mazes we used NO positional embedding, and largely don't think
|
| 309 |
+
that it is necessary as the CTM can build up its own internal world model when
|
| 310 |
+
observing.
|
| 311 |
+
|
| 312 |
+
LearnableFourierPositionalEncoding:
|
| 313 |
+
Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
|
| 314 |
+
Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
|
| 315 |
+
Provides positional information for 2D feature maps.
|
| 316 |
+
|
| 317 |
+
(MultiLearnableFourierPositionalEncoding uses multiple feature scales)
|
| 318 |
+
|
| 319 |
+
CustomRotationalEmbedding:
|
| 320 |
+
Simple sinusoidal embedding to encourage interpretability
|
| 321 |
+
"""
|
| 322 |
+
if self.positional_embedding_type == 'learnable-fourier':
|
| 323 |
+
return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
|
| 324 |
+
elif self.positional_embedding_type == 'multi-learnable-fourier':
|
| 325 |
+
return MultiLearnableFourierPositionalEncoding(d_backbone)
|
| 326 |
+
elif self.positional_embedding_type == 'custom-rotational':
|
| 327 |
+
return CustomRotationalEmbedding(d_backbone)
|
| 328 |
+
elif self.positional_embedding_type == 'custom-rotational-1d':
|
| 329 |
+
return CustomRotationalEmbedding1D(d_backbone)
|
| 330 |
+
elif self.positional_embedding_type == 'none':
|
| 331 |
+
return lambda x: 0 # Default no-op
|
| 332 |
+
else:
|
| 333 |
+
raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
|
| 334 |
+
|
| 335 |
+
def get_neuron_level_models(self, deep_nlms, do_layernorm_nlm, memory_length, memory_hidden_dims, d_model, dropout):
|
| 336 |
+
"""
|
| 337 |
+
Neuron level models are one of the core innovations of the CTM. They apply separate MLPs/linears to
|
| 338 |
+
each neuron.
|
| 339 |
+
NOTE: the name 'SuperLinear' is largely legacy, but its purpose is to apply separate linear layers
|
| 340 |
+
per neuron. It is sort of a 'grouped linear' function, where the group size is equal to 1.
|
| 341 |
+
One could make the group size bigger and use fewer parameters, but that is future work.
|
| 342 |
+
|
| 343 |
+
NOTE: We used GLU() nonlinearities because they worked well in practice.
|
| 344 |
+
"""
|
| 345 |
+
if deep_nlms:
|
| 346 |
+
return nn.Sequential(
|
| 347 |
+
nn.Sequential(
|
| 348 |
+
SuperLinear(in_dims=memory_length, out_dims=2 * memory_hidden_dims, N=d_model,
|
| 349 |
+
do_norm=do_layernorm_nlm, dropout=dropout),
|
| 350 |
+
nn.GLU(),
|
| 351 |
+
SuperLinear(in_dims=memory_hidden_dims, out_dims=2, N=d_model,
|
| 352 |
+
do_norm=do_layernorm_nlm, dropout=dropout),
|
| 353 |
+
nn.GLU(),
|
| 354 |
+
Squeeze(-1)
|
| 355 |
+
)
|
| 356 |
+
)
|
| 357 |
+
else:
|
| 358 |
+
return nn.Sequential(
|
| 359 |
+
nn.Sequential(
|
| 360 |
+
SuperLinear(in_dims=memory_length, out_dims=2, N=d_model,
|
| 361 |
+
do_norm=do_layernorm_nlm, dropout=dropout),
|
| 362 |
+
nn.GLU(),
|
| 363 |
+
Squeeze(-1)
|
| 364 |
+
)
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
def get_synapses(self, synapse_depth, d_model, dropout):
|
| 368 |
+
"""
|
| 369 |
+
The synapse model is the recurrent model in the CTM. It's purpose is to share information
|
| 370 |
+
across neurons. If using depth of 1, this is just a simple single layer with nonlinearity and layernomr.
|
| 371 |
+
For deeper synapse models we use a U-NET structure with many skip connections. In practice this performs
|
| 372 |
+
better as it enables multi-level information mixing.
|
| 373 |
+
|
| 374 |
+
The intuition with having a deep UNET model for synapses is that the action of synaptic connections is
|
| 375 |
+
not necessarily a linear one, and that approximate a synapose 'update' step in the brain is non trivial.
|
| 376 |
+
Hence, we set it up so that the CTM can learn some complex internal rule instead of trying to approximate
|
| 377 |
+
it ourselves.
|
| 378 |
+
"""
|
| 379 |
+
if synapse_depth == 1:
|
| 380 |
+
return nn.Sequential(
|
| 381 |
+
nn.Dropout(dropout),
|
| 382 |
+
nn.LazyLinear(d_model * 2),
|
| 383 |
+
nn.GLU(),
|
| 384 |
+
nn.LayerNorm(d_model)
|
| 385 |
+
)
|
| 386 |
+
else:
|
| 387 |
+
return SynapseUNET(d_model, synapse_depth, 16, dropout) # hard-coded minimum width of 16; future work TODO.
|
| 388 |
+
|
| 389 |
+
def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
|
| 390 |
+
"""
|
| 391 |
+
1. Set the buffers for selecting neurons so that these indices are saved into the model state_dict.
|
| 392 |
+
2. Set the parameters for learnable exponential decay when computing synchronisation between all
|
| 393 |
+
neurons.
|
| 394 |
+
"""
|
| 395 |
+
assert synch_type in ('out', 'action'), f"Invalid synch_type: {synch_type}"
|
| 396 |
+
left, right = self.initialize_left_right_neurons(synch_type, self.d_model, n_synch, n_random_pairing_self)
|
| 397 |
+
synch_representation_size = self.synch_representation_size_action if synch_type == 'action' else self.synch_representation_size_out
|
| 398 |
+
self.register_buffer(f'{synch_type}_neuron_indices_left', left)
|
| 399 |
+
self.register_buffer(f'{synch_type}_neuron_indices_right', right)
|
| 400 |
+
self.register_parameter(f'decay_params_{synch_type}', nn.Parameter(torch.zeros(synch_representation_size), requires_grad=True))
|
| 401 |
+
|
| 402 |
+
def initialize_left_right_neurons(self, synch_type, d_model, n_synch, n_random_pairing_self=0):
|
| 403 |
+
"""
|
| 404 |
+
Initialize the left and right neuron indices based on the neuron selection type.
|
| 405 |
+
This complexity is owing to legacy experiments, but we retain that these types of
|
| 406 |
+
neuron selections are interesting to experiment with.
|
| 407 |
+
"""
|
| 408 |
+
if self.neuron_select_type=='first-last':
|
| 409 |
+
if synch_type == 'out':
|
| 410 |
+
neuron_indices_left = neuron_indices_right = torch.arange(0, n_synch)
|
| 411 |
+
elif synch_type == 'action':
|
| 412 |
+
neuron_indices_left = neuron_indices_right = torch.arange(d_model-n_synch, d_model)
|
| 413 |
+
|
| 414 |
+
elif self.neuron_select_type=='random':
|
| 415 |
+
neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
|
| 416 |
+
neuron_indices_right = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
|
| 417 |
+
|
| 418 |
+
elif self.neuron_select_type=='random-pairing':
|
| 419 |
+
assert n_synch > n_random_pairing_self, f"Need at least {n_random_pairing_self} pairs for {self.neuron_select_type}"
|
| 420 |
+
neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
|
| 421 |
+
neuron_indices_right = torch.concatenate((neuron_indices_left[:n_random_pairing_self], torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch-n_random_pairing_self))))
|
| 422 |
+
|
| 423 |
+
device = self.start_activated_state.device
|
| 424 |
+
return neuron_indices_left.to(device), neuron_indices_right.to(device)
|
| 425 |
+
|
| 426 |
+
def get_neuron_select_type(self):
|
| 427 |
+
"""
|
| 428 |
+
Another helper method to accomodate our legacy neuron selection types.
|
| 429 |
+
TODO: additional experimentation and possible removal of 'first-last' and 'random'
|
| 430 |
+
"""
|
| 431 |
+
print(f"Using neuron select type: {self.neuron_select_type}")
|
| 432 |
+
if self.neuron_select_type == 'first-last':
|
| 433 |
+
neuron_select_type_out, neuron_select_type_action = 'first', 'last'
|
| 434 |
+
elif self.neuron_select_type in ('random', 'random-pairing'):
|
| 435 |
+
neuron_select_type_out = neuron_select_type_action = self.neuron_select_type
|
| 436 |
+
else:
|
| 437 |
+
raise ValueError(f"Invalid neuron selection type: {self.neuron_select_type}")
|
| 438 |
+
return neuron_select_type_out, neuron_select_type_action
|
| 439 |
+
|
| 440 |
+
# --- Utilty Methods ---
|
| 441 |
+
|
| 442 |
+
def verify_args(self):
|
| 443 |
+
"""
|
| 444 |
+
Verify the validity of the input arguments to ensure consistent behaviour.
|
| 445 |
+
Specifically when selecting neurons for sychronisation using 'first-last' or 'random',
|
| 446 |
+
one needs the right number of neurons
|
| 447 |
+
"""
|
| 448 |
+
assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
|
| 449 |
+
f"Invalid neuron selection type: {self.neuron_select_type}"
|
| 450 |
+
|
| 451 |
+
assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
|
| 452 |
+
f"Invalid backbone_type: {self.backbone_type}"
|
| 453 |
+
|
| 454 |
+
assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
|
| 455 |
+
f"Invalid positional_embedding_type: {self.positional_embedding_type}"
|
| 456 |
+
|
| 457 |
+
if self.neuron_select_type == 'first-last':
|
| 458 |
+
assert self.d_model >= (self.n_synch_out + self.n_synch_action), \
|
| 459 |
+
"d_model must be >= n_synch_out + n_synch_action for neuron subsets"
|
| 460 |
+
|
| 461 |
+
if self.backbone_type=='none' and self.positional_embedding_type!='none':
|
| 462 |
+
raise AssertionError("There should be no positional embedding if there is no backbone.")
|
| 463 |
+
|
| 464 |
+
def calculate_synch_representation_size(self, n_synch):
|
| 465 |
+
"""
|
| 466 |
+
Calculate the size of the synchronisation representation based on neuron selection type.
|
| 467 |
+
"""
|
| 468 |
+
if self.neuron_select_type == 'random-pairing':
|
| 469 |
+
synch_representation_size = n_synch
|
| 470 |
+
elif self.neuron_select_type in ('first-last', 'random'):
|
| 471 |
+
synch_representation_size = (n_synch * (n_synch + 1)) // 2
|
| 472 |
+
else:
|
| 473 |
+
raise ValueError(f"Invalid neuron selection type: {self.neuron_select_type}")
|
| 474 |
+
return synch_representation_size
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def forward(self, x, track=False):
|
| 480 |
+
B = x.size(0)
|
| 481 |
+
device = x.device
|
| 482 |
+
|
| 483 |
+
# --- Tracking Initialization ---
|
| 484 |
+
pre_activations_tracking = []
|
| 485 |
+
post_activations_tracking = []
|
| 486 |
+
synch_out_tracking = []
|
| 487 |
+
synch_action_tracking = []
|
| 488 |
+
attention_tracking = []
|
| 489 |
+
|
| 490 |
+
# --- Featurise Input Data ---
|
| 491 |
+
kv = self.compute_features(x)
|
| 492 |
+
|
| 493 |
+
# --- Initialise Recurrent State ---
|
| 494 |
+
state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
|
| 495 |
+
activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
|
| 496 |
+
|
| 497 |
+
# --- Prepare Storage for Outputs per Iteration ---
|
| 498 |
+
predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=torch.float32)
|
| 499 |
+
certainties = torch.empty(B, 2, self.iterations, device=device, dtype=torch.float32)
|
| 500 |
+
|
| 501 |
+
# --- Initialise Recurrent Synch Values ---
|
| 502 |
+
decay_alpha_action, decay_beta_action = None, None
|
| 503 |
+
r_action, r_out = torch.exp(-torch.clamp(self.decay_params_action, 0, 15)).unsqueeze(0).repeat(B, 1), torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
|
| 504 |
+
_, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
|
| 505 |
+
# Compute learned weighting for synchronisation
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
# --- Recurrent Loop ---
|
| 509 |
+
for stepi in range(self.iterations):
|
| 510 |
+
|
| 511 |
+
# --- Calculate Synchronisation for Input Data Interaction ---
|
| 512 |
+
synchronisation_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
|
| 513 |
+
|
| 514 |
+
# --- Interact with Data via Attention ---
|
| 515 |
+
q = self.q_proj(synchronisation_action).unsqueeze(1)
|
| 516 |
+
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
|
| 517 |
+
attn_out = attn_out.squeeze(1)
|
| 518 |
+
pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
|
| 519 |
+
|
| 520 |
+
# --- Apply Synapses ---
|
| 521 |
+
state = self.synapses(pre_synapse_input)
|
| 522 |
+
# The 'state_trace' is the history of incoming pre-activations
|
| 523 |
+
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
|
| 524 |
+
|
| 525 |
+
# --- Apply Neuron-Level Models ---
|
| 526 |
+
activated_state = self.trace_processor(state_trace)
|
| 527 |
+
# One would also keep an 'activated_state_trace' as the history of outgoing post-activations
|
| 528 |
+
# BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
|
| 529 |
+
# done using only the currect activated state (see compute_synchronisation method for explanation)
|
| 530 |
+
|
| 531 |
+
# --- Calculate Synchronisation for Output Predictions ---
|
| 532 |
+
synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
|
| 533 |
+
|
| 534 |
+
# --- Get Predictions and Certainties ---
|
| 535 |
+
current_prediction = self.output_projector(synchronisation_out)
|
| 536 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 537 |
+
|
| 538 |
+
predictions[..., stepi] = current_prediction
|
| 539 |
+
certainties[..., stepi] = current_certainty
|
| 540 |
+
|
| 541 |
+
# --- Tracking ---
|
| 542 |
+
if track:
|
| 543 |
+
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
|
| 544 |
+
post_activations_tracking.append(activated_state.detach().cpu().numpy())
|
| 545 |
+
attention_tracking.append(attn_weights.detach().cpu().numpy())
|
| 546 |
+
synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
|
| 547 |
+
synch_action_tracking.append(synchronisation_action.detach().cpu().numpy())
|
| 548 |
+
|
| 549 |
+
# --- Return Values ---
|
| 550 |
+
if track:
|
| 551 |
+
return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
|
| 552 |
+
return predictions, certainties, synchronisation_out
|
models/ctm_qamnist.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from models.ctm import ContinuousThoughtMachine
|
| 4 |
+
from models.modules import MNISTBackbone, QAMNISTIndexEmbeddings, QAMNISTOperatorEmbeddings
|
| 5 |
+
|
| 6 |
+
class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
|
| 7 |
+
def __init__(self,
|
| 8 |
+
iterations,
|
| 9 |
+
d_model,
|
| 10 |
+
d_input,
|
| 11 |
+
heads,
|
| 12 |
+
n_synch_out,
|
| 13 |
+
n_synch_action,
|
| 14 |
+
synapse_depth,
|
| 15 |
+
memory_length,
|
| 16 |
+
deep_nlms,
|
| 17 |
+
memory_hidden_dims,
|
| 18 |
+
do_layernorm_nlm,
|
| 19 |
+
out_dims,
|
| 20 |
+
iterations_per_digit,
|
| 21 |
+
iterations_per_question_part,
|
| 22 |
+
iterations_for_answering,
|
| 23 |
+
prediction_reshaper=[-1],
|
| 24 |
+
dropout=0,
|
| 25 |
+
neuron_select_type='first-last',
|
| 26 |
+
n_random_pairing_self=256
|
| 27 |
+
):
|
| 28 |
+
super().__init__(
|
| 29 |
+
iterations=iterations,
|
| 30 |
+
d_model=d_model,
|
| 31 |
+
d_input=d_input,
|
| 32 |
+
heads=heads,
|
| 33 |
+
n_synch_out=n_synch_out,
|
| 34 |
+
n_synch_action=n_synch_action,
|
| 35 |
+
synapse_depth=synapse_depth,
|
| 36 |
+
memory_length=memory_length,
|
| 37 |
+
deep_nlms=deep_nlms,
|
| 38 |
+
memory_hidden_dims=memory_hidden_dims,
|
| 39 |
+
do_layernorm_nlm=do_layernorm_nlm,
|
| 40 |
+
out_dims=out_dims,
|
| 41 |
+
prediction_reshaper=prediction_reshaper,
|
| 42 |
+
dropout=dropout,
|
| 43 |
+
neuron_select_type=neuron_select_type,
|
| 44 |
+
n_random_pairing_self=n_random_pairing_self,
|
| 45 |
+
backbone_type='none',
|
| 46 |
+
positional_embedding_type='none',
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# --- Core Parameters ---
|
| 50 |
+
self.iterations_per_digit = iterations_per_digit
|
| 51 |
+
self.iterations_per_question_part = iterations_per_question_part
|
| 52 |
+
self.iterations_for_answering = iterations_for_answering
|
| 53 |
+
|
| 54 |
+
# --- Setup Methods ---
|
| 55 |
+
|
| 56 |
+
def set_initial_rgb(self):
|
| 57 |
+
"""Set the initial RGB values for the backbone."""
|
| 58 |
+
return None
|
| 59 |
+
|
| 60 |
+
def get_d_backbone(self):
|
| 61 |
+
"""Get the dimensionality of the backbone output."""
|
| 62 |
+
return self.d_input
|
| 63 |
+
|
| 64 |
+
def set_backbone(self):
|
| 65 |
+
"""Set the backbone module based on the specified type."""
|
| 66 |
+
self.backbone_digit = MNISTBackbone(self.d_input)
|
| 67 |
+
self.index_backbone = QAMNISTIndexEmbeddings(50, self.d_input)
|
| 68 |
+
self.operator_backbone = QAMNISTOperatorEmbeddings(2, self.d_input)
|
| 69 |
+
pass
|
| 70 |
+
|
| 71 |
+
# --- Utilty Methods ---
|
| 72 |
+
|
| 73 |
+
def determine_step_type(self, total_iterations_for_digits, total_iterations_for_question, stepi: int):
|
| 74 |
+
"""Determine whether the current step is for digits, questions, or answers."""
|
| 75 |
+
is_digit_step = stepi < total_iterations_for_digits
|
| 76 |
+
is_question_step = total_iterations_for_digits <= stepi < total_iterations_for_digits + total_iterations_for_question
|
| 77 |
+
is_answer_step = stepi >= total_iterations_for_digits + total_iterations_for_question
|
| 78 |
+
return is_digit_step, is_question_step, is_answer_step
|
| 79 |
+
|
| 80 |
+
def determine_index_operator_step_type(self, total_iterations_for_digits, stepi: int):
|
| 81 |
+
"""Determine whether the current step is for index or operator."""
|
| 82 |
+
step_within_questions = stepi - total_iterations_for_digits
|
| 83 |
+
if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
|
| 84 |
+
is_index_step = True
|
| 85 |
+
is_operator_step = False
|
| 86 |
+
else:
|
| 87 |
+
is_index_step = False
|
| 88 |
+
is_operator_step = True
|
| 89 |
+
return is_index_step, is_operator_step
|
| 90 |
+
|
| 91 |
+
def get_kv_for_step(self, total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input=None, prev_kv=None):
|
| 92 |
+
"""Get the key-value for the current step."""
|
| 93 |
+
is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
|
| 94 |
+
|
| 95 |
+
if is_digit_step:
|
| 96 |
+
current_input = x[:, stepi]
|
| 97 |
+
if prev_input is not None and torch.equal(current_input, prev_input):
|
| 98 |
+
return prev_kv, prev_input
|
| 99 |
+
kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
|
| 100 |
+
|
| 101 |
+
elif is_question_step:
|
| 102 |
+
offset = stepi - total_iterations_for_digits
|
| 103 |
+
current_input = z[:, offset]
|
| 104 |
+
if prev_input is not None and torch.equal(current_input, prev_input):
|
| 105 |
+
return prev_kv, prev_input
|
| 106 |
+
is_index_step, is_operator_step = self.determine_index_operator_step_type(total_iterations_for_digits, stepi)
|
| 107 |
+
if is_index_step:
|
| 108 |
+
kv = self.index_backbone(current_input)
|
| 109 |
+
elif is_operator_step:
|
| 110 |
+
kv = self.operator_backbone(current_input)
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError("Invalid step type for question processing.")
|
| 113 |
+
|
| 114 |
+
elif is_answer_step:
|
| 115 |
+
current_input = None
|
| 116 |
+
kv = torch.zeros((x.size(0), self.d_input), device=x.device)
|
| 117 |
+
|
| 118 |
+
else:
|
| 119 |
+
raise ValueError("Invalid step type.")
|
| 120 |
+
|
| 121 |
+
return kv, current_input
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def forward(self, x, z, track=False):
|
| 127 |
+
B = x.size(0)
|
| 128 |
+
device = x.device
|
| 129 |
+
|
| 130 |
+
# --- Tracking Initialization ---
|
| 131 |
+
pre_activations_tracking = []
|
| 132 |
+
post_activations_tracking = []
|
| 133 |
+
attention_tracking = []
|
| 134 |
+
embedding_tracking = []
|
| 135 |
+
|
| 136 |
+
total_iterations_for_digits = x.size(1)
|
| 137 |
+
total_iterations_for_question = z.size(1)
|
| 138 |
+
total_iterations = total_iterations_for_digits + total_iterations_for_question + self.iterations_for_answering
|
| 139 |
+
|
| 140 |
+
# --- Initialise Recurrent State ---
|
| 141 |
+
state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
|
| 142 |
+
activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
|
| 143 |
+
|
| 144 |
+
# --- Storage for outputs per iteration ---
|
| 145 |
+
predictions = torch.empty(B, self.out_dims, total_iterations, device=device, dtype=x.dtype)
|
| 146 |
+
certainties = torch.empty(B, 2, total_iterations, device=device, dtype=x.dtype)
|
| 147 |
+
|
| 148 |
+
# --- Initialise Recurrent Synch Values ---
|
| 149 |
+
decay_alpha_action, decay_beta_action = None, None
|
| 150 |
+
r_action, r_out = torch.exp(-torch.clamp(self.decay_params_action, 0, 15)).unsqueeze(0).repeat(B, 1), torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
|
| 151 |
+
_, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
|
| 152 |
+
|
| 153 |
+
prev_input = None
|
| 154 |
+
prev_kv = None
|
| 155 |
+
|
| 156 |
+
# --- Recurrent Loop ---
|
| 157 |
+
for stepi in range(total_iterations):
|
| 158 |
+
is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
|
| 159 |
+
|
| 160 |
+
kv, prev_input = self.get_kv_for_step(total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input, prev_kv)
|
| 161 |
+
prev_kv = kv
|
| 162 |
+
|
| 163 |
+
synchronization_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
|
| 164 |
+
|
| 165 |
+
# --- Interact with Data via Attention ---
|
| 166 |
+
attn_weights = None
|
| 167 |
+
if is_digit_step:
|
| 168 |
+
q = self.q_proj(synchronization_action).unsqueeze(1)
|
| 169 |
+
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
|
| 170 |
+
attn_out = attn_out.squeeze(1)
|
| 171 |
+
pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
|
| 172 |
+
else:
|
| 173 |
+
kv = kv.squeeze(1)
|
| 174 |
+
pre_synapse_input = torch.concatenate((kv, activated_state), dim=-1)
|
| 175 |
+
|
| 176 |
+
# --- Apply Synapses ---
|
| 177 |
+
state = self.synapses(pre_synapse_input)
|
| 178 |
+
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
|
| 179 |
+
|
| 180 |
+
# --- Apply NLMs ---
|
| 181 |
+
activated_state = self.trace_processor(state_trace)
|
| 182 |
+
|
| 183 |
+
# --- Calculate Synchronisation for Output Predictions ---
|
| 184 |
+
synchronization_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
|
| 185 |
+
|
| 186 |
+
# --- Get Predictions and Certainties ---
|
| 187 |
+
current_prediction = self.output_projector(synchronization_out)
|
| 188 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 189 |
+
|
| 190 |
+
predictions[..., stepi] = current_prediction
|
| 191 |
+
certainties[..., stepi] = current_certainty
|
| 192 |
+
|
| 193 |
+
# --- Tracking ---
|
| 194 |
+
if track:
|
| 195 |
+
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
|
| 196 |
+
post_activations_tracking.append(activated_state.detach().cpu().numpy())
|
| 197 |
+
if attn_weights is not None:
|
| 198 |
+
attention_tracking.append(attn_weights.detach().cpu().numpy())
|
| 199 |
+
if is_question_step:
|
| 200 |
+
embedding_tracking.append(kv.detach().cpu().numpy())
|
| 201 |
+
|
| 202 |
+
# --- Return Values ---
|
| 203 |
+
if track:
|
| 204 |
+
return predictions, certainties, synchronization_out, np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
|
| 205 |
+
return predictions, certainties, synchronization_out
|
models/ctm_rl.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
from models.ctm import ContinuousThoughtMachine
|
| 6 |
+
from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET
|
| 7 |
+
from models.utils import compute_decay
|
| 8 |
+
from models.constants import VALID_NEURON_SELECT_TYPES
|
| 9 |
+
|
| 10 |
+
class ContinuousThoughtMachineRL(ContinuousThoughtMachine):
|
| 11 |
+
def __init__(self,
|
| 12 |
+
iterations,
|
| 13 |
+
d_model,
|
| 14 |
+
d_input,
|
| 15 |
+
n_synch_out,
|
| 16 |
+
synapse_depth,
|
| 17 |
+
memory_length,
|
| 18 |
+
deep_nlms,
|
| 19 |
+
memory_hidden_dims,
|
| 20 |
+
do_layernorm_nlm,
|
| 21 |
+
backbone_type,
|
| 22 |
+
prediction_reshaper=[-1],
|
| 23 |
+
dropout=0,
|
| 24 |
+
neuron_select_type='first-last',
|
| 25 |
+
):
|
| 26 |
+
super().__init__(
|
| 27 |
+
iterations=iterations,
|
| 28 |
+
d_model=d_model,
|
| 29 |
+
d_input=d_input,
|
| 30 |
+
heads=0, # Set heads to 0 will return None
|
| 31 |
+
n_synch_out=n_synch_out,
|
| 32 |
+
n_synch_action=0,
|
| 33 |
+
synapse_depth=synapse_depth,
|
| 34 |
+
memory_length=memory_length,
|
| 35 |
+
deep_nlms=deep_nlms,
|
| 36 |
+
memory_hidden_dims=memory_hidden_dims,
|
| 37 |
+
do_layernorm_nlm=do_layernorm_nlm,
|
| 38 |
+
out_dims=0,
|
| 39 |
+
prediction_reshaper=prediction_reshaper,
|
| 40 |
+
dropout=dropout,
|
| 41 |
+
neuron_select_type=neuron_select_type,
|
| 42 |
+
backbone_type=backbone_type,
|
| 43 |
+
n_random_pairing_self=0,
|
| 44 |
+
positional_embedding_type='none',
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
# --- Use a minimal CTM w/out input (action) synch ---
|
| 48 |
+
self.neuron_select_type_action = None
|
| 49 |
+
self.synch_representation_size_action = None
|
| 50 |
+
|
| 51 |
+
# --- Start dynamics with a learned activated state trace ---
|
| 52 |
+
self.register_parameter('start_activated_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length))), requires_grad=True))
|
| 53 |
+
self.start_activated_state = None
|
| 54 |
+
|
| 55 |
+
self.register_buffer('diagonal_mask_out', torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool)))
|
| 56 |
+
|
| 57 |
+
self.attention = None # Should already be None because super(... heads=0... )
|
| 58 |
+
self.q_proj = None # Should already be None because super(... heads=0... )
|
| 59 |
+
self.kv_proj = None # Should already be None because super(... heads=0... )
|
| 60 |
+
self.output_projector = None
|
| 61 |
+
|
| 62 |
+
# --- Core CTM Methods ---
|
| 63 |
+
|
| 64 |
+
def compute_synchronisation(self, activated_state_trace):
|
| 65 |
+
"""Compute the synchronisation between neurons."""
|
| 66 |
+
assert self.neuron_select_type == "first-last", "only fisrst-last neuron selection is supported here"
|
| 67 |
+
# For RL tasks we track a sliding window of activations from which we compute synchronisation
|
| 68 |
+
S = activated_state_trace.permute(0, 2, 1)
|
| 69 |
+
diagonal_mask = self.diagonal_mask_out.to(S.device)
|
| 70 |
+
decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4))
|
| 71 |
+
synchronisation = ((decay.unsqueeze(0) *(S[:,:,-self.n_synch_out:].unsqueeze(-1) * S[:,:,-self.n_synch_out:].unsqueeze(-2))[:,:,diagonal_mask]).sum(1))/torch.sqrt(decay.unsqueeze(0).sum(1,))
|
| 72 |
+
return synchronisation
|
| 73 |
+
|
| 74 |
+
# --- Setup Methods ---
|
| 75 |
+
|
| 76 |
+
def set_initial_rgb(self):
|
| 77 |
+
"""Set the initial RGB values for the backbone."""
|
| 78 |
+
return None
|
| 79 |
+
|
| 80 |
+
def get_d_backbone(self):
|
| 81 |
+
"""Get the dimensionality of the backbone output."""
|
| 82 |
+
return self.d_input
|
| 83 |
+
|
| 84 |
+
def set_backbone(self):
|
| 85 |
+
"""Set the backbone module based on the specified type."""
|
| 86 |
+
if self.backbone_type == 'navigation-backbone':
|
| 87 |
+
self.backbone = MiniGridBackbone(self.d_input)
|
| 88 |
+
elif self.backbone_type == 'classic-control-backbone':
|
| 89 |
+
self.backbone = ClassicControlBackbone(self.d_input)
|
| 90 |
+
else:
|
| 91 |
+
raise NotImplemented('The only backbone supported for RL are for navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
|
| 92 |
+
pass
|
| 93 |
+
|
| 94 |
+
def get_positional_embedding(self, d_backbone):
|
| 95 |
+
"""Get the positional embedding module."""
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_synapses(self, synapse_depth, d_model, dropout):
|
| 100 |
+
"""
|
| 101 |
+
Get the synapse module.
|
| 102 |
+
|
| 103 |
+
We found in our early experimentation that a single Linear, GLU and LayerNorm block performed worse than two blocks.
|
| 104 |
+
For that reason we set the default synapse depth to two blocks.
|
| 105 |
+
|
| 106 |
+
TODO: This is legacy and needs further experimentation to iron out.
|
| 107 |
+
"""
|
| 108 |
+
if synapse_depth == 1:
|
| 109 |
+
return nn.Sequential(
|
| 110 |
+
nn.Dropout(dropout),
|
| 111 |
+
nn.LazyLinear(d_model*2),
|
| 112 |
+
nn.GLU(),
|
| 113 |
+
nn.LayerNorm(d_model),
|
| 114 |
+
nn.LazyLinear(d_model*2),
|
| 115 |
+
nn.GLU(),
|
| 116 |
+
nn.LayerNorm(d_model)
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
return SynapseUNET(d_model, synapse_depth, 16, dropout)
|
| 120 |
+
|
| 121 |
+
def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
|
| 122 |
+
"""Set the parameters for the synchronisation of neurons."""
|
| 123 |
+
if synch_type == 'action':
|
| 124 |
+
pass
|
| 125 |
+
elif synch_type == 'out':
|
| 126 |
+
left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self)
|
| 127 |
+
self.register_buffer(f'out_neuron_indices_left', left)
|
| 128 |
+
self.register_buffer(f'out_neuron_indices_right', right)
|
| 129 |
+
self.register_parameter(f'decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True))
|
| 130 |
+
pass
|
| 131 |
+
else:
|
| 132 |
+
raise ValueError(f"Invalid synch_type: {synch_type}")
|
| 133 |
+
|
| 134 |
+
# --- Utilty Methods ---
|
| 135 |
+
|
| 136 |
+
def verify_args(self):
|
| 137 |
+
"""Verify the validity of the input arguments."""
|
| 138 |
+
assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
|
| 139 |
+
f"Invalid neuron selection type: {self.neuron_select_type}"
|
| 140 |
+
assert self.neuron_select_type != 'random-pairing', \
|
| 141 |
+
f"Random pairing is not supported for RL."
|
| 142 |
+
assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
|
| 143 |
+
f"Invalid backbone_type: {self.backbone_type}"
|
| 144 |
+
assert self.d_model >= (self.n_synch_out), \
|
| 145 |
+
"d_model must be >= n_synch_out for neuron subsets"
|
| 146 |
+
pass
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def forward(self, x, hidden_states, track=False):
|
| 152 |
+
|
| 153 |
+
# --- Tracking Initialization ---
|
| 154 |
+
pre_activations_tracking = []
|
| 155 |
+
post_activations_tracking = []
|
| 156 |
+
|
| 157 |
+
# --- Featurise Input Data ---
|
| 158 |
+
features = self.backbone(x)
|
| 159 |
+
|
| 160 |
+
# --- Get Recurrent State ---
|
| 161 |
+
state_trace, activated_state_trace = hidden_states
|
| 162 |
+
|
| 163 |
+
# --- Recurrent Loop ---
|
| 164 |
+
for stepi in range(self.iterations):
|
| 165 |
+
|
| 166 |
+
pre_synapse_input = torch.concatenate((features.reshape(x.size(0), -1), activated_state_trace[:,:,-1]), -1)
|
| 167 |
+
|
| 168 |
+
# --- Apply Synapses ---
|
| 169 |
+
state = self.synapses(pre_synapse_input)
|
| 170 |
+
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
|
| 171 |
+
|
| 172 |
+
# --- Apply NLMs ---
|
| 173 |
+
activated_state = self.trace_processor(state_trace)
|
| 174 |
+
activated_state_trace = torch.concatenate((activated_state_trace[:,:,1:], activated_state.unsqueeze(-1)), -1)
|
| 175 |
+
|
| 176 |
+
# --- Tracking ---
|
| 177 |
+
if track:
|
| 178 |
+
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
|
| 179 |
+
post_activations_tracking.append(activated_state.detach().cpu().numpy())
|
| 180 |
+
|
| 181 |
+
hidden_states = (
|
| 182 |
+
state_trace,
|
| 183 |
+
activated_state_trace,
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
# --- Calculate Output Synchronisation ---
|
| 187 |
+
synchronisation_out = self.compute_synchronisation(activated_state_trace)
|
| 188 |
+
|
| 189 |
+
# --- Return Values ---
|
| 190 |
+
if track:
|
| 191 |
+
return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking)
|
| 192 |
+
return synchronisation_out, hidden_states
|
models/ctm_sort.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from models.ctm import ContinuousThoughtMachine
|
| 4 |
+
|
| 5 |
+
class ContinuousThoughtMachineSORT(ContinuousThoughtMachine):
|
| 6 |
+
"""
|
| 7 |
+
Slight adaption of the CTM to work with the sort task.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
def __init__(self,
|
| 11 |
+
iterations,
|
| 12 |
+
d_model,
|
| 13 |
+
d_input,
|
| 14 |
+
heads,
|
| 15 |
+
n_synch_out,
|
| 16 |
+
n_synch_action,
|
| 17 |
+
synapse_depth,
|
| 18 |
+
memory_length,
|
| 19 |
+
deep_nlms,
|
| 20 |
+
memory_hidden_dims,
|
| 21 |
+
do_layernorm_nlm,
|
| 22 |
+
backbone_type,
|
| 23 |
+
positional_embedding_type,
|
| 24 |
+
out_dims,
|
| 25 |
+
prediction_reshaper=[-1],
|
| 26 |
+
dropout=0,
|
| 27 |
+
dropout_nlm=None,
|
| 28 |
+
neuron_select_type='random-pairing',
|
| 29 |
+
n_random_pairing_self=0,
|
| 30 |
+
):
|
| 31 |
+
super().__init__(
|
| 32 |
+
iterations=iterations,
|
| 33 |
+
d_model=d_model,
|
| 34 |
+
d_input=d_input,
|
| 35 |
+
heads=0,
|
| 36 |
+
n_synch_out=n_synch_out,
|
| 37 |
+
n_synch_action=0,
|
| 38 |
+
synapse_depth=synapse_depth,
|
| 39 |
+
memory_length=memory_length,
|
| 40 |
+
deep_nlms=deep_nlms,
|
| 41 |
+
memory_hidden_dims=memory_hidden_dims,
|
| 42 |
+
do_layernorm_nlm=do_layernorm_nlm,
|
| 43 |
+
backbone_type='none',
|
| 44 |
+
positional_embedding_type='none',
|
| 45 |
+
out_dims=out_dims,
|
| 46 |
+
prediction_reshaper=prediction_reshaper,
|
| 47 |
+
dropout=dropout,
|
| 48 |
+
dropout_nlm=dropout_nlm,
|
| 49 |
+
neuron_select_type=neuron_select_type,
|
| 50 |
+
n_random_pairing_self=n_random_pairing_self,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# --- Use a minimal CTM w/out input (action) synch ---
|
| 54 |
+
self.neuron_select_type_action = None
|
| 55 |
+
self.synch_representation_size_action = None
|
| 56 |
+
|
| 57 |
+
self.attention = None # Should already be None because super(... heads=0... )
|
| 58 |
+
self.q_proj = None # Should already be None because super(... heads=0... )
|
| 59 |
+
self.kv_proj = None # Should already be None because super(... heads=0... )
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def forward(self, x, track=False):
|
| 65 |
+
B = x.size(0)
|
| 66 |
+
device = x.device
|
| 67 |
+
|
| 68 |
+
# --- Tracking Initialization ---
|
| 69 |
+
pre_activations_tracking = []
|
| 70 |
+
post_activations_tracking = []
|
| 71 |
+
synch_out_tracking = []
|
| 72 |
+
attention_tracking = []
|
| 73 |
+
|
| 74 |
+
# --- For SORT: no need to featurise data ---
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# --- Initialise Recurrent State ---
|
| 78 |
+
state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
|
| 79 |
+
activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
|
| 80 |
+
|
| 81 |
+
# --- Prepare Storage for Outputs per Iteration ---
|
| 82 |
+
predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
|
| 83 |
+
certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
|
| 84 |
+
|
| 85 |
+
# --- Initialise Recurrent Synch Values ---
|
| 86 |
+
r_out = torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
|
| 87 |
+
_, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
|
| 88 |
+
# Compute learned weighting for synchronisation
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# --- Recurrent Loop ---
|
| 92 |
+
for stepi in range(self.iterations):
|
| 93 |
+
|
| 94 |
+
pre_synapse_input = torch.concatenate((x, activated_state), dim=-1)
|
| 95 |
+
|
| 96 |
+
# --- Apply Synapses ---
|
| 97 |
+
state = self.synapses(pre_synapse_input)
|
| 98 |
+
# The 'state_trace' is the history of incoming pre-activations
|
| 99 |
+
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
|
| 100 |
+
|
| 101 |
+
# --- Apply Neuron-Level Models ---
|
| 102 |
+
activated_state = self.trace_processor(state_trace)
|
| 103 |
+
# One would also keep an 'activated_state_trace' as the history of outgoing post-activations
|
| 104 |
+
# BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
|
| 105 |
+
# done using only the currect activated state (see compute_synchronisation method for explanation)
|
| 106 |
+
|
| 107 |
+
# --- Calculate Synchronisation for Output Predictions ---
|
| 108 |
+
synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
|
| 109 |
+
|
| 110 |
+
# --- Get Predictions and Certainties ---
|
| 111 |
+
current_prediction = self.output_projector(synchronisation_out)
|
| 112 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 113 |
+
|
| 114 |
+
predictions[..., stepi] = current_prediction
|
| 115 |
+
certainties[..., stepi] = current_certainty
|
| 116 |
+
|
| 117 |
+
# --- Tracking ---
|
| 118 |
+
if track:
|
| 119 |
+
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
|
| 120 |
+
post_activations_tracking.append(activated_state.detach().cpu().numpy())
|
| 121 |
+
synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
|
| 122 |
+
|
| 123 |
+
# --- Return Values ---
|
| 124 |
+
if track:
|
| 125 |
+
return predictions, certainties, np.array(synch_out_tracking), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
|
| 126 |
+
return predictions, certainties, synchronisation_out
|
models/ff.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
# Local imports (Assuming these contain necessary custom modules)
|
| 4 |
+
from models.modules import *
|
| 5 |
+
from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class FFBaseline(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
LSTM Baseline.
|
| 11 |
+
|
| 12 |
+
Wrapper that lets us use the same backbone as the CTM and LSTM baselines, with a
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
d_model (int): workaround that projects final layer to this space so that parameter-matching is plausible.
|
| 17 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 18 |
+
out_dims (int): Dimensionality of the final output projection.
|
| 19 |
+
dropout (float): dropout in last layer
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self,
|
| 23 |
+
d_model,
|
| 24 |
+
backbone_type,
|
| 25 |
+
out_dims,
|
| 26 |
+
dropout=0,
|
| 27 |
+
):
|
| 28 |
+
super(FFBaseline, self).__init__()
|
| 29 |
+
|
| 30 |
+
# --- Core Parameters ---
|
| 31 |
+
self.d_model = d_model
|
| 32 |
+
self.backbone_type = backbone_type
|
| 33 |
+
self.out_dims = out_dims
|
| 34 |
+
|
| 35 |
+
# --- Input Assertions ---
|
| 36 |
+
assert backbone_type in ['resnet18-1', 'resnet18-2', 'resnet18-3', 'resnet18-4',
|
| 37 |
+
'resnet34-1', 'resnet34-2', 'resnet34-3', 'resnet34-4',
|
| 38 |
+
'resnet50-1', 'resnet50-2', 'resnet50-3', 'resnet50-4',
|
| 39 |
+
'resnet101-1', 'resnet101-2', 'resnet101-3', 'resnet101-4',
|
| 40 |
+
'resnet152-1', 'resnet152-2', 'resnet152-3', 'resnet152-4',
|
| 41 |
+
'none', 'shallow-wide', 'parity_backbone'], f"Invalid backbone_type: {backbone_type}"
|
| 42 |
+
|
| 43 |
+
# --- Backbone / Feature Extraction ---
|
| 44 |
+
self.initial_rgb = Identity() # Placeholder, potentially replaced if using ResNet
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
|
| 48 |
+
resnet_family = resnet18 # Default
|
| 49 |
+
if '34' in self.backbone_type: resnet_family = resnet34
|
| 50 |
+
if '50' in self.backbone_type: resnet_family = resnet50
|
| 51 |
+
if '101' in self.backbone_type: resnet_family = resnet101
|
| 52 |
+
if '152' in self.backbone_type: resnet_family = resnet152
|
| 53 |
+
|
| 54 |
+
# Determine which ResNet blocks to keep
|
| 55 |
+
block_num_str = self.backbone_type.split('-')[-1]
|
| 56 |
+
hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]
|
| 57 |
+
|
| 58 |
+
self.backbone = resnet_family(
|
| 59 |
+
3, # initial_rgb handles input channels now
|
| 60 |
+
hyper_blocks_to_keep,
|
| 61 |
+
stride=2,
|
| 62 |
+
pretrained=False,
|
| 63 |
+
progress=True,
|
| 64 |
+
device="cpu", # Initialise on CPU, move later via .to(device)
|
| 65 |
+
do_initial_max_pool=True,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# At this point we will have a 4D tensor of features: [B, C, H, W]
|
| 70 |
+
# The following lets us scale up the resnet with d_model until it matches the CTM
|
| 71 |
+
self.output_projector = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), Squeeze(-1), Squeeze(-1), nn.LazyLinear(d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, out_dims))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def forward(self, x):
|
| 75 |
+
return self.output_projector((self.backbone(self.initial_rgb(x))))
|
models/lstm.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
from models.modules import ParityBackbone, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
|
| 7 |
+
from models.resnet import prepare_resnet_backbone
|
| 8 |
+
from models.utils import compute_normalized_entropy
|
| 9 |
+
|
| 10 |
+
from models.constants import (
|
| 11 |
+
VALID_BACKBONE_TYPES,
|
| 12 |
+
VALID_POSITIONAL_EMBEDDING_TYPES
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
class LSTMBaseline(nn.Module):
|
| 16 |
+
"""
|
| 17 |
+
LSTM Baseline
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
iterations (int): Number of internal 'thought' steps (T, in paper).
|
| 21 |
+
d_model (int): Core dimensionality of the latent space.
|
| 22 |
+
d_input (int): Dimensionality of projected attention outputs or direct input features.
|
| 23 |
+
heads (int): Number of attention heads.
|
| 24 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 25 |
+
positional_embedding_type (str): Type of positional embedding for backbone features.
|
| 26 |
+
out_dims (int): Dimensionality of the final output projection.
|
| 27 |
+
prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
|
| 28 |
+
dropout (float): Dropout rate.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(self,
|
| 32 |
+
iterations,
|
| 33 |
+
d_model,
|
| 34 |
+
d_input,
|
| 35 |
+
heads,
|
| 36 |
+
backbone_type,
|
| 37 |
+
num_layers,
|
| 38 |
+
positional_embedding_type,
|
| 39 |
+
out_dims,
|
| 40 |
+
prediction_reshaper=[-1],
|
| 41 |
+
dropout=0,
|
| 42 |
+
):
|
| 43 |
+
super(LSTMBaseline, self).__init__()
|
| 44 |
+
|
| 45 |
+
# --- Core Parameters ---
|
| 46 |
+
self.iterations = iterations
|
| 47 |
+
self.d_model = d_model
|
| 48 |
+
self.d_input = d_input
|
| 49 |
+
self.prediction_reshaper = prediction_reshaper
|
| 50 |
+
self.backbone_type = backbone_type
|
| 51 |
+
self.positional_embedding_type = positional_embedding_type
|
| 52 |
+
self.out_dims = out_dims
|
| 53 |
+
|
| 54 |
+
# --- Assertions ---
|
| 55 |
+
self.verify_args()
|
| 56 |
+
|
| 57 |
+
# --- Input Processing ---
|
| 58 |
+
d_backbone = self.get_d_backbone()
|
| 59 |
+
|
| 60 |
+
self.set_initial_rgb()
|
| 61 |
+
self.set_backbone()
|
| 62 |
+
self.positional_embedding = self.get_positional_embedding(d_backbone)
|
| 63 |
+
self.kv_proj = self.get_kv_proj()
|
| 64 |
+
self.lstm = nn.LSTM(d_input, d_model, num_layers, batch_first=True, dropout=dropout)
|
| 65 |
+
self.q_proj = self.get_q_proj()
|
| 66 |
+
self.attention = self.get_attention(heads, dropout)
|
| 67 |
+
self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
|
| 68 |
+
|
| 69 |
+
# --- Start States ---
|
| 70 |
+
self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 71 |
+
self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# --- Core LSTM Methods ---
|
| 76 |
+
|
| 77 |
+
def compute_features(self, x):
|
| 78 |
+
"""Applies backbone and positional embedding to input."""
|
| 79 |
+
x = self.initial_rgb(x)
|
| 80 |
+
self.kv_features = self.backbone(x)
|
| 81 |
+
pos_emb = self.positional_embedding(self.kv_features)
|
| 82 |
+
combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
|
| 83 |
+
kv = self.kv_proj(combined_features)
|
| 84 |
+
return kv
|
| 85 |
+
|
| 86 |
+
def compute_certainty(self, current_prediction):
|
| 87 |
+
"""Compute the certainty of the current prediction."""
|
| 88 |
+
B = current_prediction.size(0)
|
| 89 |
+
reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper)
|
| 90 |
+
ne = compute_normalized_entropy(reshaped_pred)
|
| 91 |
+
current_certainty = torch.stack((ne, 1-ne), -1)
|
| 92 |
+
return current_certainty
|
| 93 |
+
|
| 94 |
+
# --- Setup Methods ---
|
| 95 |
+
|
| 96 |
+
def set_initial_rgb(self):
|
| 97 |
+
"""Set the initial RGB processing module based on the backbone type."""
|
| 98 |
+
if 'resnet' in self.backbone_type:
|
| 99 |
+
self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
|
| 100 |
+
else:
|
| 101 |
+
self.initial_rgb = nn.Identity()
|
| 102 |
+
|
| 103 |
+
def get_d_backbone(self):
|
| 104 |
+
"""
|
| 105 |
+
Get the dimensionality of the backbone output, to be used for positional embedding setup.
|
| 106 |
+
|
| 107 |
+
This is a little bit complicated for resnets, but the logic should be easy enough to read below.
|
| 108 |
+
"""
|
| 109 |
+
if self.backbone_type == 'shallow-wide':
|
| 110 |
+
return 2048
|
| 111 |
+
elif self.backbone_type == 'parity_backbone':
|
| 112 |
+
return self.d_input
|
| 113 |
+
elif 'resnet' in self.backbone_type:
|
| 114 |
+
if '18' in self.backbone_type or '34' in self.backbone_type:
|
| 115 |
+
if self.backbone_type.split('-')[1]=='1': return 64
|
| 116 |
+
elif self.backbone_type.split('-')[1]=='2': return 128
|
| 117 |
+
elif self.backbone_type.split('-')[1]=='3': return 256
|
| 118 |
+
elif self.backbone_type.split('-')[1]=='4': return 512
|
| 119 |
+
else:
|
| 120 |
+
raise NotImplementedError
|
| 121 |
+
else:
|
| 122 |
+
if self.backbone_type.split('-')[1]=='1': return 256
|
| 123 |
+
elif self.backbone_type.split('-')[1]=='2': return 512
|
| 124 |
+
elif self.backbone_type.split('-')[1]=='3': return 1024
|
| 125 |
+
elif self.backbone_type.split('-')[1]=='4': return 2048
|
| 126 |
+
else:
|
| 127 |
+
raise NotImplementedError
|
| 128 |
+
elif self.backbone_type == 'none':
|
| 129 |
+
return None
|
| 130 |
+
else:
|
| 131 |
+
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
|
| 132 |
+
|
| 133 |
+
def set_backbone(self):
|
| 134 |
+
"""Set the backbone module based on the specified type."""
|
| 135 |
+
if self.backbone_type == 'shallow-wide':
|
| 136 |
+
self.backbone = ShallowWide()
|
| 137 |
+
elif self.backbone_type == 'parity_backbone':
|
| 138 |
+
d_backbone = self.get_d_backbone()
|
| 139 |
+
self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
|
| 140 |
+
elif 'resnet' in self.backbone_type:
|
| 141 |
+
self.backbone = prepare_resnet_backbone(self.backbone_type)
|
| 142 |
+
elif self.backbone_type == 'none':
|
| 143 |
+
self.backbone = nn.Identity()
|
| 144 |
+
else:
|
| 145 |
+
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
|
| 146 |
+
|
| 147 |
+
def get_positional_embedding(self, d_backbone):
|
| 148 |
+
"""Get the positional embedding module."""
|
| 149 |
+
if self.positional_embedding_type == 'learnable-fourier':
|
| 150 |
+
return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
|
| 151 |
+
elif self.positional_embedding_type == 'multi-learnable-fourier':
|
| 152 |
+
return MultiLearnableFourierPositionalEncoding(d_backbone)
|
| 153 |
+
elif self.positional_embedding_type == 'custom-rotational':
|
| 154 |
+
return CustomRotationalEmbedding(d_backbone)
|
| 155 |
+
elif self.positional_embedding_type == 'custom-rotational-1d':
|
| 156 |
+
return CustomRotationalEmbedding1D(d_backbone)
|
| 157 |
+
elif self.positional_embedding_type == 'none':
|
| 158 |
+
return lambda x: 0 # Default no-op
|
| 159 |
+
else:
|
| 160 |
+
raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
|
| 161 |
+
|
| 162 |
+
def get_attention(self, heads, dropout):
|
| 163 |
+
"""Get the attention module."""
|
| 164 |
+
return nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True)
|
| 165 |
+
|
| 166 |
+
def get_kv_proj(self):
|
| 167 |
+
"""Get the key-value projection module."""
|
| 168 |
+
return nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input))
|
| 169 |
+
|
| 170 |
+
def get_q_proj(self):
|
| 171 |
+
"""Get the query projection module."""
|
| 172 |
+
return nn.LazyLinear(self.d_input)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def verify_args(self):
|
| 176 |
+
"""Verify the validity of the input arguments."""
|
| 177 |
+
|
| 178 |
+
assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
|
| 179 |
+
f"Invalid backbone_type: {self.backbone_type}"
|
| 180 |
+
|
| 181 |
+
assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
|
| 182 |
+
f"Invalid positional_embedding_type: {self.positional_embedding_type}"
|
| 183 |
+
|
| 184 |
+
if self.backbone_type=='none' and self.positional_embedding_type!='none':
|
| 185 |
+
raise AssertionError("There should be no positional embedding if there is no backbone.")
|
| 186 |
+
|
| 187 |
+
pass
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def forward(self, x, track=False):
|
| 193 |
+
"""
|
| 194 |
+
Forward pass - Reverted to structure closer to user's working version.
|
| 195 |
+
Executes T=iterations steps.
|
| 196 |
+
"""
|
| 197 |
+
B = x.size(0)
|
| 198 |
+
device = x.device
|
| 199 |
+
|
| 200 |
+
# --- Tracking Initialization ---
|
| 201 |
+
activations_tracking = []
|
| 202 |
+
attention_tracking = []
|
| 203 |
+
|
| 204 |
+
# --- Featurise Input Data ---
|
| 205 |
+
kv = self.compute_features(x)
|
| 206 |
+
|
| 207 |
+
# --- Initialise Recurrent State ---
|
| 208 |
+
hn = torch.repeat_interleave(self.start_hidden_state.unsqueeze(1), x.size(0), 1)
|
| 209 |
+
cn = torch.repeat_interleave(self.start_cell_state.unsqueeze(1), x.size(0), 1)
|
| 210 |
+
state_trace = [hn[-1]]
|
| 211 |
+
|
| 212 |
+
# --- Prepare Storage for Outputs per Iteration ---
|
| 213 |
+
predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
|
| 214 |
+
certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
|
| 215 |
+
|
| 216 |
+
# --- Recurrent Loop ---
|
| 217 |
+
for stepi in range(self.iterations):
|
| 218 |
+
|
| 219 |
+
# --- Interact with Data via Attention ---
|
| 220 |
+
q = self.q_proj(hn[-1].unsqueeze(1))
|
| 221 |
+
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
|
| 222 |
+
lstm_input = attn_out
|
| 223 |
+
|
| 224 |
+
# --- Apply LSTM ---
|
| 225 |
+
hidden_state, (hn,cn) = self.lstm(lstm_input, (hn, cn))
|
| 226 |
+
hidden_state = hidden_state.squeeze(1)
|
| 227 |
+
state_trace.append(hidden_state)
|
| 228 |
+
|
| 229 |
+
# --- Get Predictions and Certainties ---
|
| 230 |
+
current_prediction = self.output_projector(hidden_state)
|
| 231 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 232 |
+
|
| 233 |
+
predictions[..., stepi] = current_prediction
|
| 234 |
+
certainties[..., stepi] = current_certainty
|
| 235 |
+
|
| 236 |
+
# --- Tracking ---
|
| 237 |
+
if track:
|
| 238 |
+
activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
|
| 239 |
+
attention_tracking.append(attn_weights.detach().cpu().numpy())
|
| 240 |
+
|
| 241 |
+
# --- Return Values ---
|
| 242 |
+
if track:
|
| 243 |
+
return predictions, certainties, None, np.zeros_like(activations_tracking), np.array(activations_tracking), np.array(attention_tracking)
|
| 244 |
+
return predictions, certainties, None
|
models/lstm_qamnist.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F # Used for GLU if not in modules
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
# Local imports (Assuming these contain necessary custom modules)
|
| 8 |
+
from models.modules import *
|
| 9 |
+
from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
|
| 10 |
+
|
| 11 |
+
class LSTMBaseline(nn.Module):
|
| 12 |
+
"""
|
| 13 |
+
LSTM Baseline
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
iterations (int): Number of internal 'thought' steps (T, in paper).
|
| 17 |
+
d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
|
| 18 |
+
d_input (int): Dimensionality of projected attention outputs or direct input features.
|
| 19 |
+
heads (int): Number of attention heads.
|
| 20 |
+
n_synch_out (int): Number of neurons used for output synchronisation (No, in paper).
|
| 21 |
+
n_synch_action (int): Number of neurons used for action/attention synchronisation (Ni, in paper).
|
| 22 |
+
synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
|
| 23 |
+
memory_length (int): History length for Neuron-Level Models (M, in paper).
|
| 24 |
+
deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
|
| 25 |
+
memory_hidden_dims (int): Hidden dimension size for deep NLMs.
|
| 26 |
+
do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
|
| 27 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 28 |
+
positional_embedding_type (str): Type of positional embedding for backbone features.
|
| 29 |
+
out_dims (int): Dimensionality of the final output projection.
|
| 30 |
+
prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
|
| 31 |
+
dropout (float): Dropout rate.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self,
|
| 35 |
+
iterations,
|
| 36 |
+
d_model,
|
| 37 |
+
d_input,
|
| 38 |
+
heads,
|
| 39 |
+
out_dims,
|
| 40 |
+
iterations_per_digit,
|
| 41 |
+
iterations_per_question_part,
|
| 42 |
+
iterations_for_answering,
|
| 43 |
+
prediction_reshaper=[-1],
|
| 44 |
+
dropout=0,
|
| 45 |
+
):
|
| 46 |
+
super(LSTMBaseline, self).__init__()
|
| 47 |
+
|
| 48 |
+
# --- Core Parameters ---
|
| 49 |
+
self.iterations = iterations
|
| 50 |
+
self.d_model = d_model
|
| 51 |
+
self.prediction_reshaper = prediction_reshaper
|
| 52 |
+
self.out_dims = out_dims
|
| 53 |
+
self.d_input = d_input
|
| 54 |
+
self.backbone_type = 'qamnist_backbone'
|
| 55 |
+
self.iterations_per_digit = iterations_per_digit
|
| 56 |
+
self.iterations_per_question_part = iterations_per_question_part
|
| 57 |
+
self.total_iterations_for_answering = iterations_for_answering
|
| 58 |
+
|
| 59 |
+
# --- Backbone / Feature Extraction ---
|
| 60 |
+
self.backbone_digit = MNISTBackbone(d_input)
|
| 61 |
+
self.index_backbone = QAMNISTIndexEmbeddings(50, d_input)
|
| 62 |
+
self.operator_backbone = QAMNISTOperatorEmbeddings(2, d_input)
|
| 63 |
+
|
| 64 |
+
# --- Core CTM Modules ---
|
| 65 |
+
self.lstm_cell = nn.LSTMCell(d_input, d_model)
|
| 66 |
+
self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 67 |
+
self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 68 |
+
|
| 69 |
+
# Attention
|
| 70 |
+
self.q_proj = nn.LazyLinear(d_input)
|
| 71 |
+
self.kv_proj = nn.Sequential(nn.LazyLinear(d_input), nn.LayerNorm(d_input))
|
| 72 |
+
self.attention = nn.MultiheadAttention(d_input, heads, dropout, batch_first=True)
|
| 73 |
+
|
| 74 |
+
# Output Projection
|
| 75 |
+
self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
|
| 76 |
+
|
| 77 |
+
def compute_certainty(self, current_prediction):
|
| 78 |
+
"""Compute the certainty of the current prediction."""
|
| 79 |
+
B = current_prediction.size(0)
|
| 80 |
+
reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper)
|
| 81 |
+
ne = compute_normalized_entropy(reshaped_pred)
|
| 82 |
+
current_certainty = torch.stack((ne, 1-ne), -1)
|
| 83 |
+
return current_certainty
|
| 84 |
+
|
| 85 |
+
def get_kv_for_step(self, stepi, x, z, thought_steps, prev_input=None, prev_kv=None):
|
| 86 |
+
is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
|
| 87 |
+
|
| 88 |
+
if is_digit_step:
|
| 89 |
+
current_input = x[:, stepi]
|
| 90 |
+
if prev_input is not None and torch.equal(current_input, prev_input):
|
| 91 |
+
return prev_kv, prev_input
|
| 92 |
+
kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
|
| 93 |
+
|
| 94 |
+
elif is_question_step:
|
| 95 |
+
offset = stepi - thought_steps.total_iterations_for_digits
|
| 96 |
+
current_input = z[:, offset].squeeze(0)
|
| 97 |
+
if prev_input is not None and torch.equal(current_input, prev_input):
|
| 98 |
+
return prev_kv, prev_input
|
| 99 |
+
is_index_step, is_operator_step = thought_steps.determine_answer_step_type(stepi)
|
| 100 |
+
if is_index_step:
|
| 101 |
+
kv = self.kv_proj(self.index_backbone(current_input))
|
| 102 |
+
elif is_operator_step:
|
| 103 |
+
kv = self.kv_proj(self.operator_backbone(current_input))
|
| 104 |
+
else:
|
| 105 |
+
raise ValueError("Invalid step type for question processing.")
|
| 106 |
+
|
| 107 |
+
elif is_answer_step:
|
| 108 |
+
current_input = None
|
| 109 |
+
kv = torch.zeros((x.size(0), self.d_input), device=x.device)
|
| 110 |
+
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError("Invalid step type.")
|
| 113 |
+
|
| 114 |
+
return kv, current_input
|
| 115 |
+
|
| 116 |
+
def forward(self, x, z, track=False):
|
| 117 |
+
"""
|
| 118 |
+
Forward pass - Reverted to structure closer to user's working version.
|
| 119 |
+
Executes T=iterations steps.
|
| 120 |
+
"""
|
| 121 |
+
B = x.size(0) # Batch size
|
| 122 |
+
|
| 123 |
+
# --- Tracking Initialization ---
|
| 124 |
+
activations_tracking = []
|
| 125 |
+
attention_tracking = [] # Note: reshaping this correctly requires knowing num_heads
|
| 126 |
+
embedding_tracking = []
|
| 127 |
+
|
| 128 |
+
thought_steps = ThoughtSteps(self.iterations_per_digit, self.iterations_per_question_part, self.total_iterations_for_answering, x.size(1), z.size(1))
|
| 129 |
+
|
| 130 |
+
# --- Step 2: Initialise Recurrent State ---
|
| 131 |
+
hidden_state = torch.repeat_interleave(self.start_hidden_state.unsqueeze(0), x.size(0), 0)
|
| 132 |
+
cell_state = torch.repeat_interleave(self.start_cell_state.unsqueeze(0), x.size(0), 0)
|
| 133 |
+
|
| 134 |
+
state_trace = [hidden_state]
|
| 135 |
+
|
| 136 |
+
device = hidden_state.device
|
| 137 |
+
|
| 138 |
+
# Storage for outputs per iteration
|
| 139 |
+
predictions = torch.empty(B, self.out_dims, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
|
| 140 |
+
certainties = torch.empty(B, 2, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
|
| 141 |
+
|
| 142 |
+
prev_input = None
|
| 143 |
+
prev_kv = None
|
| 144 |
+
|
| 145 |
+
# --- Recurrent Loop (T=iterations steps) ---
|
| 146 |
+
for stepi in range(thought_steps.total_iterations):
|
| 147 |
+
|
| 148 |
+
is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
|
| 149 |
+
kv, prev_input = self.get_kv_for_step(stepi, x, z, thought_steps, prev_input, prev_kv)
|
| 150 |
+
prev_kv = kv
|
| 151 |
+
|
| 152 |
+
# --- Interact with Data via Attention ---
|
| 153 |
+
attn_weights = None
|
| 154 |
+
if is_digit_step:
|
| 155 |
+
q = self.q_proj(hidden_state).unsqueeze(1)
|
| 156 |
+
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
|
| 157 |
+
lstm_input = attn_out.squeeze(1)
|
| 158 |
+
else:
|
| 159 |
+
lstm_input = kv
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
|
| 164 |
+
state_trace.append(hidden_state)
|
| 165 |
+
|
| 166 |
+
# --- Get Predictions and Certainties ---
|
| 167 |
+
current_prediction = self.output_projector(hidden_state)
|
| 168 |
+
current_certainty = self.compute_certainty(current_prediction)
|
| 169 |
+
|
| 170 |
+
predictions[..., stepi] = current_prediction
|
| 171 |
+
certainties[..., stepi] = current_certainty
|
| 172 |
+
|
| 173 |
+
# --- Tracking ---
|
| 174 |
+
if track:
|
| 175 |
+
activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
|
| 176 |
+
if attn_weights is not None:
|
| 177 |
+
attention_tracking.append(attn_weights.detach().cpu().numpy())
|
| 178 |
+
if is_question_step:
|
| 179 |
+
embedding_tracking.append(kv.detach().cpu().numpy())
|
| 180 |
+
|
| 181 |
+
# --- Return Values ---
|
| 182 |
+
if track:
|
| 183 |
+
return predictions, certainties, None, np.array(activations_tracking), np.array(activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
|
| 184 |
+
return predictions, certainties, None
|
models/lstm_rl.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F # Used for GLU if not in modules
|
| 4 |
+
import numpy as np
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
# Local imports (Assuming these contain necessary custom modules)
|
| 8 |
+
from models.modules import *
|
| 9 |
+
from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LSTMBaseline(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
LSTM Baseline
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
iterations (int): Number of internal 'thought' steps (T, in paper).
|
| 19 |
+
d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
|
| 20 |
+
d_input (int): Dimensionality of projected attention outputs or direct input features.
|
| 21 |
+
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
def __init__(self,
|
| 25 |
+
iterations,
|
| 26 |
+
d_model,
|
| 27 |
+
d_input,
|
| 28 |
+
backbone_type,
|
| 29 |
+
):
|
| 30 |
+
super(LSTMBaseline, self).__init__()
|
| 31 |
+
|
| 32 |
+
# --- Core Parameters ---
|
| 33 |
+
self.iterations = iterations
|
| 34 |
+
self.d_model = d_model
|
| 35 |
+
self.backbone_type = backbone_type
|
| 36 |
+
|
| 37 |
+
# --- Input Assertions ---
|
| 38 |
+
assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), f"Invalid backbone_type: {backbone_type}"
|
| 39 |
+
|
| 40 |
+
# --- Backbone / Feature Extraction ---
|
| 41 |
+
if self.backbone_type == 'navigation-backbone':
|
| 42 |
+
grid_size = 7
|
| 43 |
+
self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size)
|
| 44 |
+
lstm_cell_input_dim = grid_size * grid_size * d_input
|
| 45 |
+
|
| 46 |
+
elif self.backbone_type == 'classic-control-backbone':
|
| 47 |
+
self.backbone = ClassicControlBackbone(d_input=d_input)
|
| 48 |
+
lstm_cell_input_dim = d_input
|
| 49 |
+
|
| 50 |
+
else:
|
| 51 |
+
raise NotImplemented('The only backbone supported for RL are for navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
|
| 52 |
+
|
| 53 |
+
# --- Core LSTM Modules ---
|
| 54 |
+
self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model)
|
| 55 |
+
self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 56 |
+
self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
|
| 57 |
+
|
| 58 |
+
def compute_features(self, x):
|
| 59 |
+
"""Applies backbone and positional embedding to input."""
|
| 60 |
+
return self.backbone(x)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def forward(self, x, hidden_states, track=False):
|
| 64 |
+
"""
|
| 65 |
+
Forward pass - Reverted to structure closer to user's working version.
|
| 66 |
+
Executes T=iterations steps.
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
# --- Tracking Initialization ---
|
| 70 |
+
activations_tracking = []
|
| 71 |
+
|
| 72 |
+
# --- Featurise Input Data ---
|
| 73 |
+
features = self.compute_features(x)
|
| 74 |
+
|
| 75 |
+
hidden_state = hidden_states[0]
|
| 76 |
+
cell_state = hidden_states[1]
|
| 77 |
+
|
| 78 |
+
# --- Recurrent Loop ---
|
| 79 |
+
for stepi in range(self.iterations):
|
| 80 |
+
|
| 81 |
+
lstm_input = features.reshape(x.size(0), -1)
|
| 82 |
+
hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
|
| 83 |
+
|
| 84 |
+
# --- Tracking ---
|
| 85 |
+
if track:
|
| 86 |
+
activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
|
| 87 |
+
|
| 88 |
+
hidden_states = (
|
| 89 |
+
hidden_state,
|
| 90 |
+
cell_state
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# --- Return Values ---
|
| 94 |
+
if track:
|
| 95 |
+
return hidden_state, hidden_states, np.array(activations_tracking), np.array(activations_tracking)
|
| 96 |
+
return hidden_state, hidden_states
|
models/modules.py
ADDED
|
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F # Used for GLU
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
# Assuming 'add_coord_dim' is defined in models.utils
|
| 8 |
+
from models.utils import add_coord_dim
|
| 9 |
+
|
| 10 |
+
# --- Basic Utility Modules ---
|
| 11 |
+
|
| 12 |
+
class Identity(nn.Module):
|
| 13 |
+
"""
|
| 14 |
+
Identity Module.
|
| 15 |
+
|
| 16 |
+
Returns the input tensor unchanged. Useful as a placeholder or a no-op layer
|
| 17 |
+
in nn.Sequential containers or conditional network parts.
|
| 18 |
+
"""
|
| 19 |
+
def __init__(self):
|
| 20 |
+
super().__init__()
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
return x
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Squeeze(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
Squeeze Module.
|
| 29 |
+
|
| 30 |
+
Removes a specified dimension of size 1 from the input tensor.
|
| 31 |
+
Useful for incorporating tensor dimension squeezing within nn.Sequential.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
dim (int): The dimension to squeeze.
|
| 35 |
+
"""
|
| 36 |
+
def __init__(self, dim):
|
| 37 |
+
super().__init__()
|
| 38 |
+
self.dim = dim
|
| 39 |
+
|
| 40 |
+
def forward(self, x):
|
| 41 |
+
return x.squeeze(self.dim)
|
| 42 |
+
|
| 43 |
+
# --- Core CTM Component Modules ---
|
| 44 |
+
|
| 45 |
+
class SynapseUNET(nn.Module):
|
| 46 |
+
"""
|
| 47 |
+
UNET-style architecture for the Synapse Model (f_theta1 in the paper).
|
| 48 |
+
|
| 49 |
+
This module implements the connections between neurons in the CTM's latent
|
| 50 |
+
space. It processes the combined input (previous post-activation state z^t
|
| 51 |
+
and attention output o^t) to produce the pre-activations (a^t) for the
|
| 52 |
+
next internal tick (Eq. 1 in the paper).
|
| 53 |
+
|
| 54 |
+
While a simpler Linear or MLP layer can be used, the paper notes
|
| 55 |
+
that this U-Net structure empirically performed better, suggesting benefit
|
| 56 |
+
from more flexible synaptic connections[cite: 79, 80]. This implementation
|
| 57 |
+
uses `depth` points in linspace and creates `depth-1` down/up blocks.
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
in_dims (int): Number of input dimensions (d_model + d_input).
|
| 61 |
+
out_dims (int): Number of output dimensions (d_model).
|
| 62 |
+
depth (int): Determines structure size; creates `depth-1` down/up blocks.
|
| 63 |
+
minimum_width (int): Smallest channel width at the U-Net bottleneck.
|
| 64 |
+
dropout (float): Dropout rate applied within down/up projections.
|
| 65 |
+
"""
|
| 66 |
+
def __init__(self,
|
| 67 |
+
out_dims,
|
| 68 |
+
depth,
|
| 69 |
+
minimum_width=16,
|
| 70 |
+
dropout=0.0):
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.width_out = out_dims
|
| 73 |
+
self.n_deep = depth # Store depth just for reference if needed
|
| 74 |
+
|
| 75 |
+
# Define UNET structure based on depth
|
| 76 |
+
# Creates `depth` width values, leading to `depth-1` blocks
|
| 77 |
+
widths = np.linspace(out_dims, minimum_width, depth)
|
| 78 |
+
|
| 79 |
+
# Initial projection layer
|
| 80 |
+
self.first_projection = nn.Sequential(
|
| 81 |
+
nn.LazyLinear(int(widths[0])), # Project to the first width
|
| 82 |
+
nn.LayerNorm(int(widths[0])),
|
| 83 |
+
nn.SiLU()
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Downward path (encoding layers)
|
| 87 |
+
self.down_projections = nn.ModuleList()
|
| 88 |
+
self.up_projections = nn.ModuleList()
|
| 89 |
+
self.skip_lns = nn.ModuleList()
|
| 90 |
+
num_blocks = len(widths) - 1 # Number of down/up blocks created
|
| 91 |
+
|
| 92 |
+
for i in range(num_blocks):
|
| 93 |
+
# Down block: widths[i] -> widths[i+1]
|
| 94 |
+
self.down_projections.append(nn.Sequential(
|
| 95 |
+
nn.Dropout(dropout),
|
| 96 |
+
nn.Linear(int(widths[i]), int(widths[i+1])),
|
| 97 |
+
nn.LayerNorm(int(widths[i+1])),
|
| 98 |
+
nn.SiLU()
|
| 99 |
+
))
|
| 100 |
+
# Up block: widths[i+1] -> widths[i]
|
| 101 |
+
# Note: Up blocks are added in order matching down blocks conceptually,
|
| 102 |
+
# but applied in reverse order in the forward pass.
|
| 103 |
+
self.up_projections.append(nn.Sequential(
|
| 104 |
+
nn.Dropout(dropout),
|
| 105 |
+
nn.Linear(int(widths[i+1]), int(widths[i])),
|
| 106 |
+
nn.LayerNorm(int(widths[i])),
|
| 107 |
+
nn.SiLU()
|
| 108 |
+
))
|
| 109 |
+
# Skip connection LayerNorm operates on width[i]
|
| 110 |
+
self.skip_lns.append(nn.LayerNorm(int(widths[i])))
|
| 111 |
+
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
# Initial projection
|
| 114 |
+
out_first = self.first_projection(x)
|
| 115 |
+
|
| 116 |
+
# Downward path, storing outputs for skip connections
|
| 117 |
+
outs_down = [out_first]
|
| 118 |
+
for layer in self.down_projections:
|
| 119 |
+
outs_down.append(layer(outs_down[-1]))
|
| 120 |
+
# outs_down contains [level_0, level_1, ..., level_depth-1=bottleneck] outputs
|
| 121 |
+
|
| 122 |
+
# Upward path, starting from the bottleneck output
|
| 123 |
+
outs_up = outs_down[-1] # Bottleneck activation
|
| 124 |
+
num_blocks = len(self.up_projections) # Should be depth - 1
|
| 125 |
+
|
| 126 |
+
for i in range(num_blocks):
|
| 127 |
+
# Apply up projection in reverse order relative to down blocks
|
| 128 |
+
# up_projection[num_blocks - 1 - i] processes deeper features first
|
| 129 |
+
up_layer_idx = num_blocks - 1 - i
|
| 130 |
+
out_up = self.up_projections[up_layer_idx](outs_up)
|
| 131 |
+
|
| 132 |
+
# Get corresponding skip connection from downward path
|
| 133 |
+
# skip_connection index = num_blocks - 1 - i (same as up_layer_idx)
|
| 134 |
+
# This matches the output width of the up_projection[up_layer_idx]
|
| 135 |
+
skip_idx = up_layer_idx
|
| 136 |
+
skip_connection = outs_down[skip_idx]
|
| 137 |
+
|
| 138 |
+
# Add skip connection and apply LayerNorm corresponding to this level
|
| 139 |
+
# skip_lns index also corresponds to the level = skip_idx
|
| 140 |
+
outs_up = self.skip_lns[skip_idx](out_up + skip_connection)
|
| 141 |
+
|
| 142 |
+
# The final output after all up-projections
|
| 143 |
+
return outs_up
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class SuperLinear(nn.Module):
|
| 147 |
+
"""
|
| 148 |
+
SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.
|
| 149 |
+
|
| 150 |
+
This layer is the core component enabling Neuron-Level Models (NLMs),
|
| 151 |
+
referred to as g_theta_d in the paper (Eq. 3). It applies N independent
|
| 152 |
+
linear transformations (or small MLPs when used sequentially) to corresponding
|
| 153 |
+
slices of the input tensor along a specified dimension (typically the neuron
|
| 154 |
+
or feature dimension).
|
| 155 |
+
|
| 156 |
+
How it works for NLMs:
|
| 157 |
+
- The input `x` is expected to be the pre-activation history for each neuron,
|
| 158 |
+
shaped (batch_size, n_neurons=N, history_length=in_dims).
|
| 159 |
+
- This layer holds unique weights (`w1`) and biases (`b1`) for *each* of the `N` neurons.
|
| 160 |
+
`w1` has shape (in_dims, out_dims, N), `b1` has shape (1, N, out_dims).
|
| 161 |
+
- `torch.einsum('bni,iog->bno', x, self.w1)` performs N independent matrix
|
| 162 |
+
multiplications in parallel (mapping from dim `i` to `o` for each neuron `n`):
|
| 163 |
+
- For each neuron `n` (from 0 to N-1):
|
| 164 |
+
- It takes the neuron's history `x[:, n, :]` (shape B, in_dims).
|
| 165 |
+
- Multiplies it by the neuron's unique weight matrix `self.w1[:, :, n]` (shape in_dims, out_dims).
|
| 166 |
+
- Resulting in `out[:, n, :]` (shape B, out_dims).
|
| 167 |
+
- The unique bias `self.b1[:, n, :]` is added.
|
| 168 |
+
- The result is squeezed on the last dim (if out_dims=1) and scaled by `T`.
|
| 169 |
+
|
| 170 |
+
This allows each neuron `d` to process its temporal history `A_d^t` using
|
| 171 |
+
its private parameters `theta_d` to produce the post-activation `z_d^{t+1}`,
|
| 172 |
+
enabling the fine-grained temporal dynamics central to the CTM[cite: 7, 30, 85].
|
| 173 |
+
It's typically used within the `trace_processor` module of the main CTM class.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
in_dims (int): Input dimension (typically `memory_length`).
|
| 177 |
+
out_dims (int): Output dimension per neuron.
|
| 178 |
+
N (int): Number of independent linear models (typically `d_model`).
|
| 179 |
+
T (float): Initial value for learnable temperature/scaling factor applied to output.
|
| 180 |
+
do_norm (bool): Apply Layer Normalization to the input history before linear transform.
|
| 181 |
+
dropout (float): Dropout rate applied to the input.
|
| 182 |
+
"""
|
| 183 |
+
def __init__(self,
|
| 184 |
+
in_dims,
|
| 185 |
+
out_dims,
|
| 186 |
+
N,
|
| 187 |
+
T=1.0,
|
| 188 |
+
do_norm=False,
|
| 189 |
+
dropout=0):
|
| 190 |
+
super().__init__()
|
| 191 |
+
# N is the number of neurons (d_model), in_dims is the history length (memory_length)
|
| 192 |
+
self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
|
| 193 |
+
self.in_dims = in_dims # Corresponds to memory_length
|
| 194 |
+
# LayerNorm applied across the history dimension for each neuron independently
|
| 195 |
+
self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity()
|
| 196 |
+
self.do_norm = do_norm
|
| 197 |
+
|
| 198 |
+
# Initialize weights and biases
|
| 199 |
+
# w1 shape: (memory_length, out_dims, d_model)
|
| 200 |
+
self.register_parameter('w1', nn.Parameter(
|
| 201 |
+
torch.empty((in_dims, out_dims, N)).uniform_(
|
| 202 |
+
-1/math.sqrt(in_dims + out_dims),
|
| 203 |
+
1/math.sqrt(in_dims + out_dims)
|
| 204 |
+
), requires_grad=True)
|
| 205 |
+
)
|
| 206 |
+
# b1 shape: (1, d_model, out_dims)
|
| 207 |
+
self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
|
| 208 |
+
# Learnable temperature/scaler T
|
| 209 |
+
self.register_parameter('T', nn.Parameter(torch.Tensor([T])))
|
| 210 |
+
|
| 211 |
+
def forward(self, x):
|
| 212 |
+
"""
|
| 213 |
+
Args:
|
| 214 |
+
x (torch.Tensor): Input tensor, expected shape (B, N, in_dims)
|
| 215 |
+
where B=batch, N=d_model, in_dims=memory_length.
|
| 216 |
+
Returns:
|
| 217 |
+
torch.Tensor: Output tensor, shape (B, N) after squeeze(-1).
|
| 218 |
+
"""
|
| 219 |
+
# Input shape: (B, D, M) where D=d_model=N neurons in CTM, M=history/memory length
|
| 220 |
+
out = self.dropout(x)
|
| 221 |
+
# LayerNorm across the memory_length dimension (dim=-1)
|
| 222 |
+
out = self.layernorm(out) # Shape remains (B, N, M)
|
| 223 |
+
|
| 224 |
+
# Apply N independent linear models using einsum
|
| 225 |
+
# einsum('BDM,MHD->BDH', ...)
|
| 226 |
+
# x: (B=batch size, D=N neurons, one NLM per each of these, M=history/memory length)
|
| 227 |
+
# w1: (M, H=hidden dims if using MLP, otherwise output, D=N neurons, parallel)
|
| 228 |
+
# b1: (1, D=N neurons, H)
|
| 229 |
+
# einsum result: (B, D, H)
|
| 230 |
+
# Applying bias requires matching shapes, b1 is broadcasted.
|
| 231 |
+
out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1
|
| 232 |
+
|
| 233 |
+
# Squeeze the output dimension (assumed to be 1 usually) and scale by T
|
| 234 |
+
# This matches the original code's structure exactly.
|
| 235 |
+
out = out.squeeze(-1) / self.T
|
| 236 |
+
return out
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
# --- Backbone Modules ---
|
| 240 |
+
|
| 241 |
+
class ParityBackbone(nn.Module):
|
| 242 |
+
def __init__(self, n_embeddings, d_embedding):
|
| 243 |
+
super(ParityBackbone, self).__init__()
|
| 244 |
+
self.embedding = nn.Embedding(n_embeddings, d_embedding)
|
| 245 |
+
|
| 246 |
+
def forward(self, x):
|
| 247 |
+
"""
|
| 248 |
+
Maps -1 (negative parity) to 0 and 1 (positive) to 1
|
| 249 |
+
"""
|
| 250 |
+
x = (x == 1).long()
|
| 251 |
+
return self.embedding(x.long()).transpose(1, 2) # Transpose for compatibility with other backbones
|
| 252 |
+
|
| 253 |
+
class QAMNISTOperatorEmbeddings(nn.Module):
|
| 254 |
+
def __init__(self, num_operator_types, d_projection):
|
| 255 |
+
super(QAMNISTOperatorEmbeddings, self).__init__()
|
| 256 |
+
self.embedding = nn.Embedding(num_operator_types, d_projection)
|
| 257 |
+
|
| 258 |
+
def forward(self, x):
|
| 259 |
+
# -1 for plus and -2 for minus
|
| 260 |
+
return self.embedding(-x - 1)
|
| 261 |
+
|
| 262 |
+
class QAMNISTIndexEmbeddings(torch.nn.Module):
|
| 263 |
+
def __init__(self, max_seq_length, embedding_dim):
|
| 264 |
+
super().__init__()
|
| 265 |
+
self.max_seq_length = max_seq_length
|
| 266 |
+
self.embedding_dim = embedding_dim
|
| 267 |
+
|
| 268 |
+
embedding = torch.zeros(max_seq_length, embedding_dim)
|
| 269 |
+
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
|
| 270 |
+
div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
|
| 271 |
+
|
| 272 |
+
embedding[:, 0::2] = torch.sin(position * div_term)
|
| 273 |
+
embedding[:, 1::2] = torch.cos(position * div_term)
|
| 274 |
+
|
| 275 |
+
self.register_buffer('embedding', embedding)
|
| 276 |
+
|
| 277 |
+
def forward(self, x):
|
| 278 |
+
return self.embedding[x]
|
| 279 |
+
|
| 280 |
+
class ThoughtSteps:
|
| 281 |
+
"""
|
| 282 |
+
Helper class for managing "thought steps" in the ctm_qamnist pipeline.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
iterations_per_digit (int): Number of iterations for each digit.
|
| 286 |
+
iterations_per_question_part (int): Number of iterations for each question part.
|
| 287 |
+
total_iterations_for_answering (int): Total number of iterations for answering.
|
| 288 |
+
total_iterations_for_digits (int): Total number of iterations for digits.
|
| 289 |
+
total_iterations_for_question (int): Total number of iterations for question.
|
| 290 |
+
"""
|
| 291 |
+
def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
|
| 292 |
+
self.iterations_per_digit = iterations_per_digit
|
| 293 |
+
self.iterations_per_question_part = iterations_per_question_part
|
| 294 |
+
self.total_iterations_for_digits = total_iterations_for_digits
|
| 295 |
+
self.total_iterations_for_question = total_iterations_for_question
|
| 296 |
+
self.total_iterations_for_answering = total_iterations_for_answering
|
| 297 |
+
self.total_iterations = self.total_iterations_for_digits + self.total_iterations_for_question + self.total_iterations_for_answering
|
| 298 |
+
|
| 299 |
+
def determine_step_type(self, stepi: int):
|
| 300 |
+
is_digit_step = stepi < self.total_iterations_for_digits
|
| 301 |
+
is_question_step = self.total_iterations_for_digits <= stepi < self.total_iterations_for_digits + self.total_iterations_for_question
|
| 302 |
+
is_answer_step = stepi >= self.total_iterations_for_digits + self.total_iterations_for_question
|
| 303 |
+
return is_digit_step, is_question_step, is_answer_step
|
| 304 |
+
|
| 305 |
+
def determine_answer_step_type(self, stepi: int):
|
| 306 |
+
step_within_questions = stepi - self.total_iterations_for_digits
|
| 307 |
+
if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
|
| 308 |
+
is_index_step = True
|
| 309 |
+
is_operator_step = False
|
| 310 |
+
else:
|
| 311 |
+
is_index_step = False
|
| 312 |
+
is_operator_step = True
|
| 313 |
+
return is_index_step, is_operator_step
|
| 314 |
+
|
| 315 |
+
class MNISTBackbone(nn.Module):
|
| 316 |
+
"""
|
| 317 |
+
Simple backbone for MNIST feature extraction.
|
| 318 |
+
"""
|
| 319 |
+
def __init__(self, d_input):
|
| 320 |
+
super(MNISTBackbone, self).__init__()
|
| 321 |
+
self.layers = nn.Sequential(
|
| 322 |
+
nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
|
| 323 |
+
nn.BatchNorm2d(d_input),
|
| 324 |
+
nn.ReLU(),
|
| 325 |
+
nn.MaxPool2d(2, 2),
|
| 326 |
+
nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
|
| 327 |
+
nn.BatchNorm2d(d_input),
|
| 328 |
+
nn.ReLU(),
|
| 329 |
+
nn.MaxPool2d(2, 2),
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
def forward(self, x):
|
| 333 |
+
return self.layers(x)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class MiniGridBackbone(nn.Module):
|
| 337 |
+
def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
|
| 338 |
+
super().__init__()
|
| 339 |
+
self.object_embedding = nn.Embedding(num_objects, embedding_dim)
|
| 340 |
+
self.color_embedding = nn.Embedding(num_colors, embedding_dim)
|
| 341 |
+
self.state_embedding = nn.Embedding(num_states, embedding_dim)
|
| 342 |
+
|
| 343 |
+
self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)
|
| 344 |
+
|
| 345 |
+
self.project_to_d_projection = nn.Sequential(
|
| 346 |
+
nn.Linear(embedding_dim * 4, d_input * 2),
|
| 347 |
+
nn.GLU(),
|
| 348 |
+
nn.LayerNorm(d_input),
|
| 349 |
+
nn.Linear(d_input, d_input * 2),
|
| 350 |
+
nn.GLU(),
|
| 351 |
+
nn.LayerNorm(d_input)
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
def forward(self, x):
|
| 355 |
+
x = x.long()
|
| 356 |
+
B, H, W, C = x.size()
|
| 357 |
+
|
| 358 |
+
object_idx = x[:,:,:, 0]
|
| 359 |
+
color_idx = x[:,:,:, 1]
|
| 360 |
+
state_idx = x[:,:,:, 2]
|
| 361 |
+
|
| 362 |
+
obj_embed = self.object_embedding(object_idx)
|
| 363 |
+
color_embed = self.color_embedding(color_idx)
|
| 364 |
+
state_embed = self.state_embedding(state_idx)
|
| 365 |
+
|
| 366 |
+
pos_idx = torch.arange(H * W, device=x.device).view(1, H, W).expand(B, -1, -1)
|
| 367 |
+
pos_embed = self.position_embedding(pos_idx)
|
| 368 |
+
|
| 369 |
+
out = self.project_to_d_projection(torch.cat([obj_embed, color_embed, state_embed, pos_embed], dim=-1))
|
| 370 |
+
return out
|
| 371 |
+
|
| 372 |
+
class ClassicControlBackbone(nn.Module):
|
| 373 |
+
def __init__(self, d_input):
|
| 374 |
+
super().__init__()
|
| 375 |
+
self.input_projector = nn.Sequential(
|
| 376 |
+
nn.Flatten(),
|
| 377 |
+
nn.LazyLinear(d_input * 2),
|
| 378 |
+
nn.GLU(),
|
| 379 |
+
nn.LayerNorm(d_input),
|
| 380 |
+
nn.LazyLinear(d_input * 2),
|
| 381 |
+
nn.GLU(),
|
| 382 |
+
nn.LayerNorm(d_input)
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
def forward(self, x):
|
| 386 |
+
return self.input_projector(x)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class ShallowWide(nn.Module):
|
| 390 |
+
"""
|
| 391 |
+
Simple, wide, shallow convolutional backbone for image feature extraction.
|
| 392 |
+
|
| 393 |
+
Alternative to ResNet, uses grouped convolutions and GLU activations.
|
| 394 |
+
Fixed structure, useful for specific experiments.
|
| 395 |
+
"""
|
| 396 |
+
def __init__(self):
|
| 397 |
+
super(ShallowWide, self).__init__()
|
| 398 |
+
# LazyConv2d infers input channels
|
| 399 |
+
self.layers = nn.Sequential(
|
| 400 |
+
nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1), # Output channels = 4096
|
| 401 |
+
nn.GLU(dim=1), # Halves channels to 2048
|
| 402 |
+
nn.BatchNorm2d(2048),
|
| 403 |
+
# Grouped convolution maintains width but processes groups independently
|
| 404 |
+
nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32),
|
| 405 |
+
nn.GLU(dim=1), # Halves channels to 2048
|
| 406 |
+
nn.BatchNorm2d(2048)
|
| 407 |
+
)
|
| 408 |
+
def forward(self, x):
|
| 409 |
+
return self.layers(x)
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class PretrainedResNetWrapper(nn.Module):
|
| 413 |
+
"""
|
| 414 |
+
Wrapper to use standard pre-trained ResNet models from torchvision.
|
| 415 |
+
|
| 416 |
+
Loads a specified ResNet architecture pre-trained on ImageNet, removes the
|
| 417 |
+
final classification layer (fc), average pooling, and optionally later layers
|
| 418 |
+
(e.g., layer4), allowing it to be used as a feature extractor backbone.
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50').
|
| 422 |
+
fine_tune (bool): If False, freezes the weights of the pre-trained backbone.
|
| 423 |
+
"""
|
| 424 |
+
def __init__(self, resnet_type, fine_tune=True):
|
| 425 |
+
super(PretrainedResNetWrapper, self).__init__()
|
| 426 |
+
self.resnet_type = resnet_type
|
| 427 |
+
self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)
|
| 428 |
+
|
| 429 |
+
if not fine_tune:
|
| 430 |
+
for param in self.backbone.parameters():
|
| 431 |
+
param.requires_grad = False
|
| 432 |
+
|
| 433 |
+
# Remove final layers to use as feature extractor
|
| 434 |
+
self.backbone.avgpool = Identity()
|
| 435 |
+
self.backbone.fc = Identity()
|
| 436 |
+
# Keep layer4 by default, user can modify instance if needed
|
| 437 |
+
# self.backbone.layer4 = Identity()
|
| 438 |
+
|
| 439 |
+
def forward(self, x):
|
| 440 |
+
# Get features from the modified ResNet
|
| 441 |
+
out = self.backbone(x)
|
| 442 |
+
|
| 443 |
+
# Reshape output to (B, C, H, W) - This is heuristic based on original comment.
|
| 444 |
+
# User might need to adjust this based on which layers are kept/removed.
|
| 445 |
+
# Infer C based on ResNet type (example values)
|
| 446 |
+
nc = 256 if ('18' in self.resnet_type or '34' in self.resnet_type) else 512 if '50' in self.resnet_type else 1024 if '101' in self.resnet_type else 2048 # Approx for layer3/4 output channel numbers
|
| 447 |
+
# Infer H, W assuming output is flattened C * H * W
|
| 448 |
+
num_features = out.shape[-1]
|
| 449 |
+
# This calculation assumes nc is correct and feature map is square
|
| 450 |
+
wh_squared = num_features / nc
|
| 451 |
+
if wh_squared < 0 or not float(wh_squared).is_integer():
|
| 452 |
+
print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
|
| 453 |
+
# Return potentially flattened features if reshape fails
|
| 454 |
+
return out
|
| 455 |
+
wh = int(np.sqrt(wh_squared))
|
| 456 |
+
|
| 457 |
+
return out.reshape(x.size(0), nc, wh, wh)
|
| 458 |
+
|
| 459 |
+
# --- Positional Encoding Modules ---
|
| 460 |
+
|
| 461 |
+
class LearnableFourierPositionalEncoding(nn.Module):
|
| 462 |
+
"""
|
| 463 |
+
Learnable Fourier Feature Positional Encoding.
|
| 464 |
+
|
| 465 |
+
Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
|
| 466 |
+
Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
|
| 467 |
+
Provides positional information for 2D feature maps.
|
| 468 |
+
|
| 469 |
+
Args:
|
| 470 |
+
d_model (int): The output dimension of the positional encoding (D).
|
| 471 |
+
G (int): Positional groups (default 1).
|
| 472 |
+
M (int): Dimensionality of input coordinates (default 2 for H, W).
|
| 473 |
+
F_dim (int): Dimension of the Fourier features.
|
| 474 |
+
H_dim (int): Hidden dimension of the MLP.
|
| 475 |
+
gamma (float): Initialization scale for the Fourier projection weights (Wr).
|
| 476 |
+
"""
|
| 477 |
+
def __init__(self, d_model,
|
| 478 |
+
G=1, M=2,
|
| 479 |
+
F_dim=256,
|
| 480 |
+
H_dim=128,
|
| 481 |
+
gamma=1/2.5,
|
| 482 |
+
):
|
| 483 |
+
super().__init__()
|
| 484 |
+
self.G = G
|
| 485 |
+
self.M = M
|
| 486 |
+
self.F_dim = F_dim
|
| 487 |
+
self.H_dim = H_dim
|
| 488 |
+
self.D = d_model
|
| 489 |
+
self.gamma = gamma
|
| 490 |
+
|
| 491 |
+
self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
|
| 492 |
+
self.mlp = nn.Sequential(
|
| 493 |
+
nn.Linear(self.F_dim, self.H_dim, bias=True),
|
| 494 |
+
nn.GLU(), # Halves H_dim
|
| 495 |
+
nn.Linear(self.H_dim // 2, self.D // self.G),
|
| 496 |
+
nn.LayerNorm(self.D // self.G)
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
self.init_weights()
|
| 500 |
+
|
| 501 |
+
def init_weights(self):
|
| 502 |
+
nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
|
| 503 |
+
|
| 504 |
+
def forward(self, x):
|
| 505 |
+
"""
|
| 506 |
+
Computes positional encodings for the input feature map x.
|
| 507 |
+
|
| 508 |
+
Args:
|
| 509 |
+
x (torch.Tensor): Input feature map, shape (B, C, H, W).
|
| 510 |
+
|
| 511 |
+
Returns:
|
| 512 |
+
torch.Tensor: Positional encoding tensor, shape (B, D, H, W).
|
| 513 |
+
"""
|
| 514 |
+
B, C, H, W = x.shape
|
| 515 |
+
# Creates coordinates based on (H, W) and repeats for batch B.
|
| 516 |
+
# Takes x[:,0] assuming channel dim isn't needed for coords.
|
| 517 |
+
x_coord = add_coord_dim(x[:,0]) # Expects (B, H, W) -> (B, H, W, 2)
|
| 518 |
+
|
| 519 |
+
# Compute Fourier features
|
| 520 |
+
projected = self.Wr(x_coord) # (B, H, W, F_dim // 2)
|
| 521 |
+
cosines = torch.cos(projected)
|
| 522 |
+
sines = torch.sin(projected)
|
| 523 |
+
F = (1.0 / math.sqrt(self.F_dim)) * torch.cat([cosines, sines], dim=-1) # (B, H, W, F_dim)
|
| 524 |
+
|
| 525 |
+
# Project features through MLP
|
| 526 |
+
Y = self.mlp(F) # (B, H, W, D // G)
|
| 527 |
+
|
| 528 |
+
# Reshape to (B, D, H, W)
|
| 529 |
+
PEx = Y.permute(0, 3, 1, 2) # Assuming G=1
|
| 530 |
+
return PEx
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
class MultiLearnableFourierPositionalEncoding(nn.Module):
|
| 534 |
+
"""
|
| 535 |
+
Combines multiple LearnableFourierPositionalEncoding modules with different
|
| 536 |
+
initialization scales (gamma) via a learnable weighted sum.
|
| 537 |
+
|
| 538 |
+
Allows the model to learn an optimal combination of positional frequencies.
|
| 539 |
+
|
| 540 |
+
Args:
|
| 541 |
+
d_model (int): Output dimension of the encoding.
|
| 542 |
+
G, M, F_dim, H_dim: Parameters passed to underlying LearnableFourierPositionalEncoding.
|
| 543 |
+
gamma_range (list[float]): Min and max gamma values for the linspace.
|
| 544 |
+
N (int): Number of parallel embedding modules to create.
|
| 545 |
+
"""
|
| 546 |
+
def __init__(self, d_model,
|
| 547 |
+
G=1, M=2,
|
| 548 |
+
F_dim=256,
|
| 549 |
+
H_dim=128,
|
| 550 |
+
gamma_range=[1.0, 0.1], # Default range
|
| 551 |
+
N=10,
|
| 552 |
+
):
|
| 553 |
+
super().__init__()
|
| 554 |
+
self.embedders = nn.ModuleList()
|
| 555 |
+
for gamma in np.linspace(gamma_range[0], gamma_range[1], N):
|
| 556 |
+
self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma))
|
| 557 |
+
|
| 558 |
+
# Renamed parameter from 'combination' to 'combination_weights' for clarity only in comments
|
| 559 |
+
# Actual registered name remains 'combination' as in original code
|
| 560 |
+
self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
|
| 561 |
+
self.N = N
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def forward(self, x):
|
| 565 |
+
"""
|
| 566 |
+
Computes combined positional encoding.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
x (torch.Tensor): Input feature map, shape (B, C, H, W).
|
| 570 |
+
|
| 571 |
+
Returns:
|
| 572 |
+
torch.Tensor: Combined positional encoding tensor, shape (B, D, H, W).
|
| 573 |
+
"""
|
| 574 |
+
# Compute embeddings from all modules and stack: (N, B, D, H, W)
|
| 575 |
+
pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)
|
| 576 |
+
|
| 577 |
+
# Compute combination weights using softmax
|
| 578 |
+
# Use registered parameter name 'combination'
|
| 579 |
+
# Reshape weights for broadcasting: (N,) -> (N, 1, 1, 1, 1)
|
| 580 |
+
weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)
|
| 581 |
+
|
| 582 |
+
# Compute weighted sum over the N dimension
|
| 583 |
+
combined_emb = (pos_embs * weights).sum(0) # (B, D, H, W)
|
| 584 |
+
return combined_emb
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
class CustomRotationalEmbedding(nn.Module):
|
| 588 |
+
"""
|
| 589 |
+
Custom Rotational Positional Embedding.
|
| 590 |
+
|
| 591 |
+
Generates 2D positional embeddings based on rotating a fixed start vector.
|
| 592 |
+
The rotation angle for each grid position is determined primarily by its
|
| 593 |
+
horizontal position (width dimension). The resulting rotated vectors are
|
| 594 |
+
concatenated and projected.
|
| 595 |
+
|
| 596 |
+
Note: The current implementation derives angles only from the width dimension (`x.size(-1)`).
|
| 597 |
+
|
| 598 |
+
Args:
|
| 599 |
+
d_model (int): Dimensionality of the output embeddings.
|
| 600 |
+
"""
|
| 601 |
+
def __init__(self, d_model):
|
| 602 |
+
super(CustomRotationalEmbedding, self).__init__()
|
| 603 |
+
# Learnable 2D start vector
|
| 604 |
+
self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True))
|
| 605 |
+
# Projects the 4D concatenated rotated vectors to d_model
|
| 606 |
+
# Input size 4 comes from concatenating two 2D rotated vectors
|
| 607 |
+
self.projection = nn.Sequential(nn.Linear(4, d_model))
|
| 608 |
+
|
| 609 |
+
def forward(self, x):
|
| 610 |
+
"""
|
| 611 |
+
Computes rotational positional embeddings based on input width.
|
| 612 |
+
|
| 613 |
+
Args:
|
| 614 |
+
x (torch.Tensor): Input tensor (used for shape and device),
|
| 615 |
+
shape (batch_size, channels, height, width).
|
| 616 |
+
Returns:
|
| 617 |
+
Output tensor containing positional embeddings,
|
| 618 |
+
shape (1, d_model, height, width) - Batch dim is 1 as PE is same for all.
|
| 619 |
+
"""
|
| 620 |
+
B, C, H, W = x.shape
|
| 621 |
+
device = x.device
|
| 622 |
+
|
| 623 |
+
# --- Generate rotations based only on Width ---
|
| 624 |
+
# Angles derived from width dimension
|
| 625 |
+
theta_rad = torch.deg2rad(torch.linspace(0, 180, W, device=device)) # Angle per column
|
| 626 |
+
cos_theta = torch.cos(theta_rad)
|
| 627 |
+
sin_theta = torch.sin(theta_rad)
|
| 628 |
+
|
| 629 |
+
# Create rotation matrices: Shape (W, 2, 2)
|
| 630 |
+
# Use unsqueeze(1) to allow stacking along dim 1
|
| 631 |
+
rotation_matrices = torch.stack([
|
| 632 |
+
torch.stack([cos_theta, -sin_theta], dim=-1), # Shape (W, 2)
|
| 633 |
+
torch.stack([sin_theta, cos_theta], dim=-1) # Shape (W, 2)
|
| 634 |
+
], dim=1) # Stacks along dim 1 -> Shape (W, 2, 2)
|
| 635 |
+
|
| 636 |
+
# Rotate the start vector by column angle: Shape (W, 2)
|
| 637 |
+
rotated_vectors = torch.einsum('wij,j->wi', rotation_matrices, self.start_vector)
|
| 638 |
+
|
| 639 |
+
# --- Create Grid Key ---
|
| 640 |
+
# Original code uses repeats based on rotated_vectors.shape[0] (which is W) for both dimensions.
|
| 641 |
+
# This creates a (W, W, 4) key tensor.
|
| 642 |
+
key = torch.cat((
|
| 643 |
+
torch.repeat_interleave(rotated_vectors.unsqueeze(1), W, dim=1), # (W, 1, 2) -> (W, W, 2)
|
| 644 |
+
torch.repeat_interleave(rotated_vectors.unsqueeze(0), W, dim=0) # (1, W, 2) -> (W, W, 2)
|
| 645 |
+
), dim=-1) # Shape (W, W, 4)
|
| 646 |
+
|
| 647 |
+
# Project the 4D key vector to d_model: Shape (W, W, d_model)
|
| 648 |
+
pe_grid = self.projection(key)
|
| 649 |
+
|
| 650 |
+
# Reshape to (1, d_model, W, W) and then select/resize to target H, W?
|
| 651 |
+
# Original code permutes to (d_model, W, W) and unsqueezes to (1, d_model, W, W)
|
| 652 |
+
pe = pe_grid.permute(2, 0, 1).unsqueeze(0)
|
| 653 |
+
|
| 654 |
+
# If H != W, this needs adjustment. Assuming H=W or cropping/padding happens later.
|
| 655 |
+
# Let's return the (1, d_model, W, W) tensor as generated by the original logic.
|
| 656 |
+
# If H != W, downstream code must handle the mismatch or this PE needs modification.
|
| 657 |
+
if H != W:
|
| 658 |
+
# Simple interpolation/cropping could be added, but sticking to original logic:
|
| 659 |
+
# Option 1: Interpolate
|
| 660 |
+
# pe = F.interpolate(pe, size=(H, W), mode='bilinear', align_corners=False)
|
| 661 |
+
# Option 2: Crop/Pad (e.g., crop if W > W_target, pad if W < W_target)
|
| 662 |
+
# Sticking to original: return shape (1, d_model, W, W)
|
| 663 |
+
pass
|
| 664 |
+
|
| 665 |
+
return pe
|
| 666 |
+
|
| 667 |
+
class CustomRotationalEmbedding1D(nn.Module):
|
| 668 |
+
def __init__(self, d_model):
|
| 669 |
+
super(CustomRotationalEmbedding1D, self).__init__()
|
| 670 |
+
self.projection = nn.Linear(2, d_model)
|
| 671 |
+
|
| 672 |
+
def forward(self, x):
|
| 673 |
+
start_vector = torch.tensor([0., 1.], device=x.device, dtype=torch.float)
|
| 674 |
+
theta_rad = torch.deg2rad(torch.linspace(0, 180, x.size(2), device=x.device))
|
| 675 |
+
cos_theta = torch.cos(theta_rad)
|
| 676 |
+
sin_theta = torch.sin(theta_rad)
|
| 677 |
+
cos_theta = cos_theta.unsqueeze(1) # Shape: (height, 1)
|
| 678 |
+
sin_theta = sin_theta.unsqueeze(1) # Shape: (height, 1)
|
| 679 |
+
|
| 680 |
+
# Create rotation matrices
|
| 681 |
+
rotation_matrices = torch.stack([
|
| 682 |
+
torch.cat([cos_theta, -sin_theta], dim=1),
|
| 683 |
+
torch.cat([sin_theta, cos_theta], dim=1)
|
| 684 |
+
], dim=1) # Shape: (height, 2, 2)
|
| 685 |
+
|
| 686 |
+
# Rotate the start vector
|
| 687 |
+
rotated_vectors = torch.einsum('bij,j->bi', rotation_matrices, start_vector)
|
| 688 |
+
|
| 689 |
+
pe = self.projection(rotated_vectors)
|
| 690 |
+
pe = torch.repeat_interleave(pe.unsqueeze(0), x.size(0), 0)
|
| 691 |
+
return pe.transpose(1, 2) # Transpose for compatibility with other backbones
|
| 692 |
+
|
models/resnet.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import os
|
| 4 |
+
from models.modules import Identity
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"ResNet",
|
| 8 |
+
"resnet18",
|
| 9 |
+
"resnet34",
|
| 10 |
+
"resnet50",
|
| 11 |
+
"resnet101",
|
| 12 |
+
"resnet152",
|
| 13 |
+
]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
| 17 |
+
"""3x3 convolution with padding"""
|
| 18 |
+
return nn.Conv2d(
|
| 19 |
+
in_planes,
|
| 20 |
+
out_planes,
|
| 21 |
+
kernel_size=3,
|
| 22 |
+
stride=stride,
|
| 23 |
+
padding=dilation,
|
| 24 |
+
groups=groups,
|
| 25 |
+
bias=False,
|
| 26 |
+
dilation=dilation,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
| 31 |
+
"""1x1 convolution"""
|
| 32 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class BasicBlock(nn.Module):
|
| 36 |
+
expansion = 1
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
inplanes,
|
| 41 |
+
planes,
|
| 42 |
+
stride=1,
|
| 43 |
+
downsample=None,
|
| 44 |
+
groups=1,
|
| 45 |
+
base_width=64,
|
| 46 |
+
dilation=1,
|
| 47 |
+
norm_layer=None,
|
| 48 |
+
):
|
| 49 |
+
super(BasicBlock, self).__init__()
|
| 50 |
+
if norm_layer is None:
|
| 51 |
+
norm_layer = nn.BatchNorm2d
|
| 52 |
+
if groups != 1 or base_width != 64:
|
| 53 |
+
raise ValueError("BasicBlock only supports groups=1 and base_width=64")
|
| 54 |
+
if dilation > 1:
|
| 55 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
| 56 |
+
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
| 57 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 58 |
+
self.bn1 = norm_layer(planes)
|
| 59 |
+
self.relu = nn.ReLU(inplace=True)
|
| 60 |
+
self.conv2 = conv3x3(planes, planes)
|
| 61 |
+
self.bn2 = norm_layer(planes)
|
| 62 |
+
self.downsample = downsample
|
| 63 |
+
self.stride = stride
|
| 64 |
+
|
| 65 |
+
def forward(self, x):
|
| 66 |
+
identity = x
|
| 67 |
+
|
| 68 |
+
out = self.conv1(x)
|
| 69 |
+
out = self.bn1(out)
|
| 70 |
+
out = self.relu(out)
|
| 71 |
+
|
| 72 |
+
out = self.conv2(out)
|
| 73 |
+
out = self.bn2(out)
|
| 74 |
+
|
| 75 |
+
if self.downsample is not None:
|
| 76 |
+
identity = self.downsample(x)
|
| 77 |
+
|
| 78 |
+
out += identity
|
| 79 |
+
|
| 80 |
+
out = self.relu(out)
|
| 81 |
+
return out
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class Bottleneck(nn.Module):
|
| 85 |
+
expansion = 4
|
| 86 |
+
|
| 87 |
+
def __init__(
|
| 88 |
+
self,
|
| 89 |
+
inplanes,
|
| 90 |
+
planes,
|
| 91 |
+
stride=1,
|
| 92 |
+
downsample=None,
|
| 93 |
+
groups=1,
|
| 94 |
+
base_width=64,
|
| 95 |
+
dilation=1,
|
| 96 |
+
norm_layer=None,
|
| 97 |
+
):
|
| 98 |
+
super(Bottleneck, self).__init__()
|
| 99 |
+
if norm_layer is None:
|
| 100 |
+
norm_layer = nn.BatchNorm2d
|
| 101 |
+
width = int(planes * (base_width / 64.0)) * groups
|
| 102 |
+
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
| 103 |
+
self.conv1 = conv1x1(inplanes, width)
|
| 104 |
+
self.bn1 = norm_layer(width)
|
| 105 |
+
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
| 106 |
+
self.bn2 = norm_layer(width)
|
| 107 |
+
self.conv3 = conv1x1(width, planes * self.expansion)
|
| 108 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
| 109 |
+
self.relu = nn.ReLU(inplace=True)
|
| 110 |
+
self.downsample = downsample
|
| 111 |
+
self.stride = stride
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
identity = x
|
| 115 |
+
|
| 116 |
+
out = self.conv1(x)
|
| 117 |
+
out = self.bn1(out)
|
| 118 |
+
out = self.relu(out)
|
| 119 |
+
|
| 120 |
+
out = self.conv2(out)
|
| 121 |
+
out = self.bn2(out)
|
| 122 |
+
out = self.relu(out)
|
| 123 |
+
|
| 124 |
+
out = self.conv3(out)
|
| 125 |
+
out = self.bn3(out)
|
| 126 |
+
|
| 127 |
+
if self.downsample is not None:
|
| 128 |
+
identity = self.downsample(x)
|
| 129 |
+
|
| 130 |
+
out += identity
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# activation = None
|
| 134 |
+
# activation = out.detach().cpu().numpy()
|
| 135 |
+
out = self.relu(out)
|
| 136 |
+
# return out, activation
|
| 137 |
+
|
| 138 |
+
return out
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class ResNet(nn.Module):
|
| 142 |
+
def __init__(
|
| 143 |
+
self,
|
| 144 |
+
in_channels,
|
| 145 |
+
feature_scales,
|
| 146 |
+
stride,
|
| 147 |
+
block,
|
| 148 |
+
layers,
|
| 149 |
+
num_classes=10,
|
| 150 |
+
zero_init_residual=False,
|
| 151 |
+
groups=1,
|
| 152 |
+
width_per_group=64,
|
| 153 |
+
replace_stride_with_dilation=None,
|
| 154 |
+
norm_layer=None,
|
| 155 |
+
do_initial_max_pool=True,
|
| 156 |
+
):
|
| 157 |
+
super(ResNet, self).__init__()
|
| 158 |
+
if norm_layer is None:
|
| 159 |
+
norm_layer = nn.BatchNorm2d
|
| 160 |
+
self._norm_layer = norm_layer
|
| 161 |
+
|
| 162 |
+
self.inplanes = 64
|
| 163 |
+
self.dilation = 1
|
| 164 |
+
if replace_stride_with_dilation is None:
|
| 165 |
+
# each element in the tuple indicates if we should replace
|
| 166 |
+
# the 2x2 stride with a dilated convolution instead
|
| 167 |
+
replace_stride_with_dilation = [False, False, False]
|
| 168 |
+
if len(replace_stride_with_dilation) != 3:
|
| 169 |
+
raise ValueError(
|
| 170 |
+
"replace_stride_with_dilation should be None "
|
| 171 |
+
"or a 3-element tuple, got {}".format(replace_stride_with_dilation)
|
| 172 |
+
)
|
| 173 |
+
self.groups = groups
|
| 174 |
+
self.base_width = width_per_group
|
| 175 |
+
|
| 176 |
+
# NOTE: Important!
|
| 177 |
+
# This has changed from a kernel size of 7 (padding=3) to a kernel of 3 (padding=1)
|
| 178 |
+
# The reason for this was to limit the receptive field to constrain models to
|
| 179 |
+
# "Looking around" to gather information.
|
| 180 |
+
|
| 181 |
+
self.conv1 = nn.Conv2d(
|
| 182 |
+
in_channels, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
|
| 183 |
+
) if in_channels in [1, 3] else nn.LazyConv2d(
|
| 184 |
+
self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
|
| 185 |
+
)
|
| 186 |
+
# END
|
| 187 |
+
|
| 188 |
+
self.bn1 = norm_layer(self.inplanes)
|
| 189 |
+
self.relu = nn.ReLU(inplace=True)
|
| 190 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) if do_initial_max_pool else Identity()
|
| 191 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 192 |
+
self.feature_scales = feature_scales
|
| 193 |
+
if 2 in feature_scales:
|
| 194 |
+
self.layer2 = self._make_layer(
|
| 195 |
+
block, 128, layers[1], stride=stride, dilate=replace_stride_with_dilation[0]
|
| 196 |
+
)
|
| 197 |
+
if 3 in feature_scales:
|
| 198 |
+
self.layer3 = self._make_layer(
|
| 199 |
+
block, 256, layers[2], stride=stride, dilate=replace_stride_with_dilation[1]
|
| 200 |
+
)
|
| 201 |
+
if 4 in feature_scales:
|
| 202 |
+
self.layer4 = self._make_layer(
|
| 203 |
+
block, 512, layers[3], stride=stride, dilate=replace_stride_with_dilation[2]
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# NOTE: Commented this out as it is not used anymore for this work, kept it for reference
|
| 207 |
+
# self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
| 208 |
+
# self.fc = nn.Linear(512 * block.expansion, num_classes)
|
| 209 |
+
|
| 210 |
+
# for m in self.modules():
|
| 211 |
+
# if isinstance(m, nn.Conv2d):
|
| 212 |
+
# nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
| 213 |
+
# elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
| 214 |
+
# nn.init.constant_(m.weight, 1)
|
| 215 |
+
# nn.init.constant_(m.bias, 0)
|
| 216 |
+
|
| 217 |
+
# Zero-initialize the last BN in each residual branch,
|
| 218 |
+
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
| 219 |
+
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
| 220 |
+
if zero_init_residual:
|
| 221 |
+
for m in self.modules():
|
| 222 |
+
if isinstance(m, Bottleneck):
|
| 223 |
+
nn.init.constant_(m.bn3.weight, 0)
|
| 224 |
+
elif isinstance(m, BasicBlock):
|
| 225 |
+
nn.init.constant_(m.bn2.weight, 0)
|
| 226 |
+
|
| 227 |
+
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
| 228 |
+
norm_layer = self._norm_layer
|
| 229 |
+
downsample = None
|
| 230 |
+
previous_dilation = self.dilation
|
| 231 |
+
if dilate:
|
| 232 |
+
self.dilation *= stride
|
| 233 |
+
stride = 1
|
| 234 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 235 |
+
downsample = nn.Sequential(
|
| 236 |
+
conv1x1(self.inplanes, planes * block.expansion, stride),
|
| 237 |
+
norm_layer(planes * block.expansion),
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
layers = []
|
| 241 |
+
layers.append(
|
| 242 |
+
block(
|
| 243 |
+
self.inplanes,
|
| 244 |
+
planes,
|
| 245 |
+
stride,
|
| 246 |
+
downsample,
|
| 247 |
+
self.groups,
|
| 248 |
+
self.base_width,
|
| 249 |
+
previous_dilation,
|
| 250 |
+
norm_layer,
|
| 251 |
+
)
|
| 252 |
+
)
|
| 253 |
+
self.inplanes = planes * block.expansion
|
| 254 |
+
for _ in range(1, blocks):
|
| 255 |
+
layers.append(
|
| 256 |
+
block(
|
| 257 |
+
self.inplanes,
|
| 258 |
+
planes,
|
| 259 |
+
groups=self.groups,
|
| 260 |
+
base_width=self.base_width,
|
| 261 |
+
dilation=self.dilation,
|
| 262 |
+
norm_layer=norm_layer,
|
| 263 |
+
)
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
return nn.Sequential(*layers)
|
| 267 |
+
|
| 268 |
+
def forward(self, x):
|
| 269 |
+
activations = []
|
| 270 |
+
x = self.conv1(x)
|
| 271 |
+
x = self.bn1(x)
|
| 272 |
+
x = self.relu(x)
|
| 273 |
+
x = self.maxpool(x)
|
| 274 |
+
# if return_activations: activations.append(torch.clone(x))
|
| 275 |
+
x = self.layer1(x)
|
| 276 |
+
|
| 277 |
+
if 2 in self.feature_scales:
|
| 278 |
+
x = self.layer2(x)
|
| 279 |
+
if 3 in self.feature_scales:
|
| 280 |
+
x = self.layer3(x)
|
| 281 |
+
if 4 in self.feature_scales:
|
| 282 |
+
x = self.layer4(x)
|
| 283 |
+
return x
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _resnet(in_channels, feature_scales, stride, arch, block, layers, pretrained, progress, device, do_initial_max_pool, **kwargs):
|
| 287 |
+
model = ResNet(in_channels, feature_scales, stride, block, layers, do_initial_max_pool=do_initial_max_pool, **kwargs)
|
| 288 |
+
if pretrained:
|
| 289 |
+
assert in_channels==3
|
| 290 |
+
script_dir = os.path.dirname(__file__)
|
| 291 |
+
state_dict = torch.load(
|
| 292 |
+
script_dir + '/state_dicts/' + arch + ".pt", map_location=device
|
| 293 |
+
)
|
| 294 |
+
model.load_state_dict(state_dict, strict=False)
|
| 295 |
+
return model
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def resnet18(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 299 |
+
"""Constructs a ResNet-18 model.
|
| 300 |
+
Args:
|
| 301 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 302 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 303 |
+
"""
|
| 304 |
+
return _resnet(in_channels,
|
| 305 |
+
feature_scales, stride, "resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def resnet34(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 310 |
+
"""Constructs a ResNet-34 model.
|
| 311 |
+
Args:
|
| 312 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 313 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 314 |
+
"""
|
| 315 |
+
return _resnet(in_channels,
|
| 316 |
+
feature_scales, stride, "resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
def resnet50(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 321 |
+
"""Constructs a ResNet-50 model.
|
| 322 |
+
Args:
|
| 323 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 324 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 325 |
+
"""
|
| 326 |
+
return _resnet(in_channels,
|
| 327 |
+
feature_scales, stride, "resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def resnet101(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 332 |
+
"""Constructs a ResNet-50 model.
|
| 333 |
+
Args:
|
| 334 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 335 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 336 |
+
"""
|
| 337 |
+
return _resnet(in_channels,
|
| 338 |
+
feature_scales, stride, "resnet101", Bottleneck, [3, 4, 23, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def resnet152(in_channels, feature_scales, stride=2, pretrained=False, progress=True, device="cpu", do_initial_max_pool=True, **kwargs):
|
| 343 |
+
"""Constructs a ResNet-50 model.
|
| 344 |
+
Args:
|
| 345 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 346 |
+
progress (bool): If True, displays a progress bar of the download to stderr
|
| 347 |
+
"""
|
| 348 |
+
return _resnet(in_channels,
|
| 349 |
+
feature_scales, stride, "resnet152", Bottleneck, [3, 4, 36, 3], pretrained, progress, device, do_initial_max_pool, **kwargs
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
def prepare_resnet_backbone(backbone_type):
|
| 353 |
+
|
| 354 |
+
resnet_family = resnet18 # Default
|
| 355 |
+
if '34' in backbone_type: resnet_family = resnet34
|
| 356 |
+
if '50' in backbone_type: resnet_family = resnet50
|
| 357 |
+
if '101' in backbone_type: resnet_family = resnet101
|
| 358 |
+
if '152' in backbone_type: resnet_family = resnet152
|
| 359 |
+
|
| 360 |
+
# Determine which ResNet blocks to keep
|
| 361 |
+
block_num_str = backbone_type.split('-')[-1]
|
| 362 |
+
hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]
|
| 363 |
+
|
| 364 |
+
backbone = resnet_family(
|
| 365 |
+
3,
|
| 366 |
+
hyper_blocks_to_keep,
|
| 367 |
+
stride=2,
|
| 368 |
+
pretrained=False,
|
| 369 |
+
progress=True,
|
| 370 |
+
device="cpu",
|
| 371 |
+
do_initial_max_pool=True,
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
return backbone
|
models/utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import re
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
def compute_decay(T, params, clamp_lims=(0, 15)):
|
| 7 |
+
"""
|
| 8 |
+
This function computes exponential decays for learnable synchronisation
|
| 9 |
+
interactions between pairs of neurons.
|
| 10 |
+
"""
|
| 11 |
+
assert len(clamp_lims), 'Clamp lims should be length 2'
|
| 12 |
+
assert type(clamp_lims) == tuple, 'Clamp lims should be tuple'
|
| 13 |
+
|
| 14 |
+
indices = torch.arange(T-1, -1, -1, device=params.device).reshape(T, 1).expand(T, params.shape[0])
|
| 15 |
+
out = torch.exp(-indices * torch.clamp(params, clamp_lims[0], clamp_lims[1]).unsqueeze(0))
|
| 16 |
+
return out
|
| 17 |
+
|
| 18 |
+
def add_coord_dim(x, scaled=True):
|
| 19 |
+
"""
|
| 20 |
+
Adds a final dimension to the tensor representing 2D coordinates.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
tensor: A PyTorch tensor of shape (B, D, H, W).
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
A PyTorch tensor of shape (B, D, H, W, 2) with the last dimension
|
| 27 |
+
representing the 2D coordinates within the HW dimensions.
|
| 28 |
+
"""
|
| 29 |
+
B, H, W = x.shape
|
| 30 |
+
# Create coordinate grids
|
| 31 |
+
x_coords = torch.arange(W, device=x.device, dtype=x.dtype).repeat(H, 1) # Shape (H, W)
|
| 32 |
+
y_coords = torch.arange(H, device=x.device, dtype=x.dtype).unsqueeze(-1).repeat(1, W) # Shape (H, W)
|
| 33 |
+
if scaled:
|
| 34 |
+
x_coords /= (W-1)
|
| 35 |
+
y_coords /= (H-1)
|
| 36 |
+
# Stack coordinates and expand dimensions
|
| 37 |
+
coords = torch.stack((x_coords, y_coords), dim=-1) # Shape (H, W, 2)
|
| 38 |
+
coords = coords.unsqueeze(0) # Shape (1, 1, H, W, 2)
|
| 39 |
+
coords = coords.repeat(B, 1, 1, 1) # Shape (B, D, H, W, 2)
|
| 40 |
+
return coords
|
| 41 |
+
|
| 42 |
+
def compute_normalized_entropy(logits, reduction='mean'):
|
| 43 |
+
"""
|
| 44 |
+
Calculates the normalized entropy of a PyTorch tensor of logits along the
|
| 45 |
+
final dimension.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
logits: A PyTorch tensor of logits.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
A PyTorch tensor containing the normalized entropy values.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
# Apply softmax to get probabilities
|
| 55 |
+
preds = F.softmax(logits, dim=-1)
|
| 56 |
+
|
| 57 |
+
# Calculate the log probabilities
|
| 58 |
+
log_preds = torch.log_softmax(logits, dim=-1)
|
| 59 |
+
|
| 60 |
+
# Calculate the entropy
|
| 61 |
+
entropy = -torch.sum(preds * log_preds, dim=-1)
|
| 62 |
+
|
| 63 |
+
# Calculate the maximum possible entropy
|
| 64 |
+
num_classes = preds.shape[-1]
|
| 65 |
+
max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32))
|
| 66 |
+
|
| 67 |
+
# Normalize the entropy
|
| 68 |
+
normalized_entropy = entropy / max_entropy
|
| 69 |
+
if len(logits.shape)>2 and reduction == 'mean':
|
| 70 |
+
normalized_entropy = normalized_entropy.flatten(1).mean(-1)
|
| 71 |
+
|
| 72 |
+
return normalized_entropy
|
| 73 |
+
|
| 74 |
+
def reshape_predictions(predictions, prediction_reshaper):
|
| 75 |
+
B, T = predictions.size(0), predictions.size(-1)
|
| 76 |
+
new_shape = [B] + prediction_reshaper + [T]
|
| 77 |
+
rehaped_predictions = predictions.reshape(new_shape)
|
| 78 |
+
return rehaped_predictions
|
| 79 |
+
|
| 80 |
+
def get_all_log_dirs(root_dir):
|
| 81 |
+
folders = []
|
| 82 |
+
for dirpath, dirnames, filenames in os.walk(root_dir):
|
| 83 |
+
if any(f.endswith(".pt") for f in filenames):
|
| 84 |
+
folders.append(dirpath)
|
| 85 |
+
return folders
|
| 86 |
+
|
| 87 |
+
def get_latest_checkpoint(log_dir):
|
| 88 |
+
files = [f for f in os.listdir(log_dir) if re.match(r'checkpoint_\d+\.pt', f)]
|
| 89 |
+
return os.path.join(log_dir, max(files, key=lambda f: int(re.search(r'\d+', f).group()))) if files else None
|
| 90 |
+
|
| 91 |
+
def get_latest_checkpoint_file(filepath, limit=300000):
|
| 92 |
+
checkpoint_files = get_checkpoint_files(filepath)
|
| 93 |
+
checkpoint_files = [
|
| 94 |
+
f for f in checkpoint_files if int(re.search(r'checkpoint_(\d+)\.pt', f).group(1)) <= limit
|
| 95 |
+
]
|
| 96 |
+
if not checkpoint_files:
|
| 97 |
+
return None
|
| 98 |
+
return checkpoint_files[-1]
|
| 99 |
+
|
| 100 |
+
def get_checkpoint_files(filepath):
|
| 101 |
+
regex = r'checkpoint_(\d+)\.pt'
|
| 102 |
+
files = [f for f in os.listdir(filepath) if re.match(regex, f)]
|
| 103 |
+
files = sorted(files, key=lambda f: int(re.search(regex, f).group(1)))
|
| 104 |
+
return [os.path.join(filepath, f) for f in files]
|
| 105 |
+
|
| 106 |
+
def load_checkpoint(checkpoint_path, device):
|
| 107 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 108 |
+
return checkpoint
|
| 109 |
+
|
| 110 |
+
def get_model_args_from_checkpoint(checkpoint):
|
| 111 |
+
if "args" in checkpoint:
|
| 112 |
+
return(checkpoint["args"])
|
| 113 |
+
else:
|
| 114 |
+
raise ValueError("Checkpoint does not contain saved args.")
|
| 115 |
+
|
| 116 |
+
def get_accuracy_and_loss_from_checkpoint(checkpoint, device="cpu"):
|
| 117 |
+
training_iteration = checkpoint.get('training_iteration', 0)
|
| 118 |
+
train_losses = checkpoint.get('train_losses', [])
|
| 119 |
+
test_losses = checkpoint.get('test_losses', [])
|
| 120 |
+
train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
|
| 121 |
+
test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
|
| 122 |
+
return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
matplotlib
|
| 5 |
+
seaborn
|
| 6 |
+
tdqm
|
| 7 |
+
opencv-python
|
| 8 |
+
imageio
|
| 9 |
+
scikit-learn
|
| 10 |
+
umap-learn
|
| 11 |
+
python-dotenv
|
| 12 |
+
gymnasium
|
| 13 |
+
minigrid
|
| 14 |
+
datasets
|
| 15 |
+
autoclip
|
tasks/image_classification/README.md
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Image classification
|
| 2 |
+
|
| 3 |
+
This folder contains code for training and analysing imagenet and cifar related experiments.
|
| 4 |
+
|
| 5 |
+
## Accessing and loading imagenet
|
| 6 |
+
|
| 7 |
+
We use the [ILSRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset in our paper.
|
| 8 |
+
|
| 9 |
+
To get this to work for you, you will need to do the following:
|
| 10 |
+
1. Login to huggingface (make an account) to agree to TCs of this dataset,
|
| 11 |
+
2. Make a new access token.
|
| 12 |
+
3. Install huggingface_hub on the target machine with ```pip install huggingface_hub```
|
| 13 |
+
4. Run ```huggingface-cli login``` and use your token. This will authenticate you on the backend and allow the code to run.
|
| 14 |
+
5. Simply run an imagenet experiment. It will auto download and do all that magic.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Training
|
| 18 |
+
There are two training files: `train.py` and `train_distributed.py`. The training code uses mixed precision. For the settings in the paper, the following command was used for distributed training:
|
| 19 |
+
|
| 20 |
+
```
|
| 21 |
+
torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
You can run the same setup on a single GPU with:
|
| 25 |
+
```
|
| 26 |
+
python -m tasks.image_classification.train tasks.image_classification.train --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp --device 0
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
|
tasks/image_classification/analysis/README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Analysis
|
| 2 |
+
|
| 3 |
+
This folder contains analysis code for image classifcation experiments. To build GIFs for imagenet run (from the base directory):
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
python -m tasks.image_classification.analysis.build_imagenet_viz
|
| 7 |
+
```
|
| 8 |
+
|
| 9 |
+
To build the plots in the paper run:
|
| 10 |
+
```
|
| 11 |
+
python -m tasks.image_classification.analysis.imagenet_evaluate_and_plot
|
| 12 |
+
```
|
tasks/image_classification/analysis/run_imagenet_analysis.py
ADDED
|
@@ -0,0 +1,972 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --- Core Libraries ---
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import argparse
|
| 6 |
+
from tqdm.auto import tqdm
|
| 7 |
+
import torch.nn.functional as F # Used for interpolate
|
| 8 |
+
|
| 9 |
+
# --- Plotting & Visualization ---
|
| 10 |
+
import matplotlib.pyplot as plt
|
| 11 |
+
import matplotlib as mpl
|
| 12 |
+
mpl.use('Agg')
|
| 13 |
+
import seaborn as sns
|
| 14 |
+
sns.set_style('darkgrid')
|
| 15 |
+
from matplotlib import patheffects
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
import imageio
|
| 18 |
+
import cv2
|
| 19 |
+
from scipy.special import softmax
|
| 20 |
+
from tasks.image_classification.plotting import save_frames_to_mp4
|
| 21 |
+
|
| 22 |
+
# --- Data Handling & Model ---
|
| 23 |
+
from torchvision import transforms
|
| 24 |
+
from torchvision import datasets # Only used for CIFAR100 in debug mode
|
| 25 |
+
from scipy import ndimage # Used in find_island_centers
|
| 26 |
+
from data.custom_datasets import ImageNet
|
| 27 |
+
from models.ctm import ContinuousThoughtMachine
|
| 28 |
+
from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
|
| 29 |
+
from tasks.image_classification.plotting import plot_neural_dynamics
|
| 30 |
+
|
| 31 |
+
# --- Global Settings ---
|
| 32 |
+
np.seterr(divide='ignore')
|
| 33 |
+
mpl.use('Agg')
|
| 34 |
+
sns.set_style('darkgrid')
|
| 35 |
+
|
| 36 |
+
# --- Helper Functions ---
|
| 37 |
+
|
| 38 |
+
def find_island_centers(array_2d, threshold):
|
| 39 |
+
"""
|
| 40 |
+
Finds the center of mass of each island (connected component > threshold)
|
| 41 |
+
in a 2D array, weighted by the array's values.
|
| 42 |
+
Returns list of (y, x) centers and list of areas.
|
| 43 |
+
"""
|
| 44 |
+
binary_image = array_2d > threshold
|
| 45 |
+
labeled_image, num_labels = ndimage.label(binary_image)
|
| 46 |
+
centers = []
|
| 47 |
+
areas = []
|
| 48 |
+
# Calculate center of mass for each labeled island (label 0 is background)
|
| 49 |
+
for i in range(1, num_labels + 1):
|
| 50 |
+
island_mask = (labeled_image == i)
|
| 51 |
+
total_mass = np.sum(array_2d[island_mask])
|
| 52 |
+
if total_mass > 0:
|
| 53 |
+
# Get coordinates for this island
|
| 54 |
+
y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
|
| 55 |
+
# Calculate weighted average for center
|
| 56 |
+
x_center = np.average(x_coords[island_mask], weights=array_2d[island_mask])
|
| 57 |
+
y_center = np.average(y_coords[island_mask], weights=array_2d[island_mask])
|
| 58 |
+
centers.append((round(y_center, 4), round(x_center, 4)))
|
| 59 |
+
areas.append(np.sum(island_mask)) # Area is the count of pixels in the island
|
| 60 |
+
return centers, areas
|
| 61 |
+
|
| 62 |
+
def parse_args():
|
| 63 |
+
"""Parses command-line arguments."""
|
| 64 |
+
# Note: Original had two ArgumentParser instances, using the second one.
|
| 65 |
+
parser = argparse.ArgumentParser(description="Visualize Continuous Thought Machine Attention")
|
| 66 |
+
parser.add_argument('--actions', type=str, nargs='+', default=['videos'], choices=['plots', 'videos', 'demo'], help="Actions to take. Plots=results plots; videos=gifs/mp4s to watch attention; demo: last frame of internal ticks")
|
| 67 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
|
| 68 |
+
|
| 69 |
+
parser.add_argument('--checkpoint', type=str, default='checkpoints/imagenet/ctm_clean.pt', help="Path to ATM checkpoint")
|
| 70 |
+
parser.add_argument('--output_dir', type=str, default='tasks/image_classification/analysis/outputs/imagenet_viz', help="Directory for visualization outputs")
|
| 71 |
+
parser.add_argument('--debug', action=argparse.BooleanOptionalAction, default=True, help='Debug mode: use CIFAR100 instead of ImageNet for debugging.')
|
| 72 |
+
parser.add_argument('--plot_every', type=int, default=10, help="How often to plot.")
|
| 73 |
+
|
| 74 |
+
parser.add_argument('--inference_iterations', type=int, default=50, help="Iterations to use during inference.")
|
| 75 |
+
parser.add_argument('--data_indices', type=int, nargs='+', default=[], help="Use specific indices in validation data for demos, otherwise random.")
|
| 76 |
+
parser.add_argument('--N_to_viz', type=int, default=5, help="When not supplying data_indices.")
|
| 77 |
+
|
| 78 |
+
return parser.parse_args()
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# --- Main Execution Block ---
|
| 82 |
+
if __name__=='__main__':
|
| 83 |
+
|
| 84 |
+
# --- Setup ---
|
| 85 |
+
args = parse_args()
|
| 86 |
+
if args.device[0] != -1 and torch.cuda.is_available():
|
| 87 |
+
device = f'cuda:{args.device[0]}'
|
| 88 |
+
else:
|
| 89 |
+
device = 'cpu'
|
| 90 |
+
print(f"Using device: {device}")
|
| 91 |
+
|
| 92 |
+
# --- Load Checkpoint & Model ---
|
| 93 |
+
print(f"Loading checkpoint: {args.checkpoint}")
|
| 94 |
+
checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False) # removed weights_only=False
|
| 95 |
+
model_args = checkpoint['args']
|
| 96 |
+
|
| 97 |
+
# Handle legacy arguments from checkpoint if necessary
|
| 98 |
+
if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
|
| 99 |
+
model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
|
| 100 |
+
if not hasattr(model_args, 'neuron_select_type'):
|
| 101 |
+
model_args.neuron_select_type = 'first-last'
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Instantiate Model based on checkpoint args
|
| 105 |
+
print("Instantiating CTM model...")
|
| 106 |
+
model = ContinuousThoughtMachine(
|
| 107 |
+
iterations=model_args.iterations,
|
| 108 |
+
d_model=model_args.d_model,
|
| 109 |
+
d_input=model_args.d_input,
|
| 110 |
+
heads=model_args.heads,
|
| 111 |
+
n_synch_out=model_args.n_synch_out,
|
| 112 |
+
n_synch_action=model_args.n_synch_action,
|
| 113 |
+
synapse_depth=model_args.synapse_depth,
|
| 114 |
+
memory_length=model_args.memory_length,
|
| 115 |
+
deep_nlms=model_args.deep_memory,
|
| 116 |
+
memory_hidden_dims=model_args.memory_hidden_dims,
|
| 117 |
+
do_layernorm_nlm=model_args.do_normalisation,
|
| 118 |
+
backbone_type=model_args.backbone_type,
|
| 119 |
+
positional_embedding_type=model_args.positional_embedding_type,
|
| 120 |
+
out_dims=model_args.out_dims,
|
| 121 |
+
prediction_reshaper=[-1], # Kept fixed value from original code
|
| 122 |
+
dropout=0, # No dropout for eval
|
| 123 |
+
neuron_select_type=model_args.neuron_select_type,
|
| 124 |
+
n_random_pairing_self=model_args.n_random_pairing_self,
|
| 125 |
+
).to(device)
|
| 126 |
+
|
| 127 |
+
# Load weights into model
|
| 128 |
+
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
|
| 129 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 130 |
+
model.eval() # Set model to evaluation mode
|
| 131 |
+
|
| 132 |
+
# --- Prepare Dataset ---
|
| 133 |
+
if args.debug:
|
| 134 |
+
print("Debug mode: Using CIFAR100")
|
| 135 |
+
# CIFAR100 specific normalization constants
|
| 136 |
+
dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
|
| 137 |
+
dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
|
| 138 |
+
img_size = 256 # Resize CIFAR images for consistency
|
| 139 |
+
transform = transforms.Compose([
|
| 140 |
+
transforms.Resize(img_size),
|
| 141 |
+
transforms.ToTensor(),
|
| 142 |
+
transforms.Normalize(mean=dataset_mean, std=dataset_std), # Normalize
|
| 143 |
+
])
|
| 144 |
+
validation_dataset = datasets.CIFAR100('data/', train=False, transform=transform, download=True)
|
| 145 |
+
validation_dataset_centercrop = datasets.CIFAR100('data/', train=True, transform=transform, download=True)
|
| 146 |
+
else:
|
| 147 |
+
print("Using ImageNet")
|
| 148 |
+
# ImageNet specific normalization constants
|
| 149 |
+
dataset_mean = [0.485, 0.456, 0.406]
|
| 150 |
+
dataset_std = [0.229, 0.224, 0.225]
|
| 151 |
+
img_size = 256 # Resize ImageNet images
|
| 152 |
+
# Note: Original comment mentioned no CenterCrop, this transform reflects that.
|
| 153 |
+
transform = transforms.Compose([
|
| 154 |
+
transforms.Resize(img_size),
|
| 155 |
+
transforms.ToTensor(),
|
| 156 |
+
transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
|
| 157 |
+
])
|
| 158 |
+
validation_dataset = ImageNet(which_split='validation', transform=transform)
|
| 159 |
+
validation_dataset_centercrop = ImageNet(which_split='train', transform=transforms.Compose([
|
| 160 |
+
transforms.Resize(img_size),
|
| 161 |
+
transforms.RandomCrop(img_size),
|
| 162 |
+
transforms.ToTensor(),
|
| 163 |
+
transforms.Normalize(mean=dataset_mean, std=dataset_std) # Normalize
|
| 164 |
+
]))
|
| 165 |
+
class_labels = list(IMAGENET2012_CLASSES.values()) # Load actual class names
|
| 166 |
+
|
| 167 |
+
os.makedirs(f'{args.output_dir}', exist_ok=True)
|
| 168 |
+
|
| 169 |
+
interp_mode = 'nearest'
|
| 170 |
+
cmap_calib = sns.color_palette('viridis', as_cmap=True)
|
| 171 |
+
loader = torch.utils.data.DataLoader(validation_dataset, batch_size=1, shuffle=False, num_workers=0, drop_last=False)
|
| 172 |
+
loader_crop = torch.utils.data.DataLoader(validation_dataset_centercrop, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
|
| 173 |
+
|
| 174 |
+
model.eval()
|
| 175 |
+
|
| 176 |
+
figscale = 0.85
|
| 177 |
+
topk = 5
|
| 178 |
+
mean_certainties_correct, mean_certainties_incorrect = [],[]
|
| 179 |
+
tracked_certainties = []
|
| 180 |
+
tracked_targets = []
|
| 181 |
+
tracked_predictions = []
|
| 182 |
+
|
| 183 |
+
if model.iterations != args.inference_iterations:
|
| 184 |
+
print('WARNING: you are setting inference iterations to a value not used during training!')
|
| 185 |
+
|
| 186 |
+
model.iterations = args.inference_iterations
|
| 187 |
+
|
| 188 |
+
if 'plots' in args.actions:
|
| 189 |
+
|
| 190 |
+
with torch.inference_mode(): # Disable gradient calculations
|
| 191 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=0, dynamic_ncols=True) as pbar:
|
| 192 |
+
imgi = 0
|
| 193 |
+
for bi, (inputs, targets) in enumerate(loader):
|
| 194 |
+
inputs = inputs.to(device)
|
| 195 |
+
targets = targets.to(device)
|
| 196 |
+
if bi==0:
|
| 197 |
+
dynamics_inputs, _ = next(iter(loader_crop)) # Use this because of batching
|
| 198 |
+
_, _, _, _, post_activations_viz, _ = model(inputs, track=True)
|
| 199 |
+
plot_neural_dynamics(post_activations_viz, 15*10, args.output_dir, axis_snap=True, N_per_row=15)
|
| 200 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 201 |
+
|
| 202 |
+
tracked_predictions.append(predictions.detach().cpu().numpy())
|
| 203 |
+
tracked_targets.append(targets.detach().cpu().numpy())
|
| 204 |
+
tracked_certainties.append(certainties.detach().cpu().numpy())
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
pbar.set_description(f'Processing base image of size {inputs.shape}')
|
| 210 |
+
pbar.update(1)
|
| 211 |
+
if ((bi % args.plot_every == 0) or bi == len(loader)-1) and bi!=0: #
|
| 212 |
+
|
| 213 |
+
concatenated_certainties = np.concatenate(tracked_certainties, axis=0)
|
| 214 |
+
concatenated_targets = np.concatenate(tracked_targets, axis=0)
|
| 215 |
+
concatenated_predictions = np.concatenate(tracked_predictions, axis=0)
|
| 216 |
+
concatenated_predictions_argsorted = np.argsort(concatenated_predictions, 1)[:,::-1]
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
for topk in [1, 5]:
|
| 221 |
+
concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
|
| 222 |
+
|
| 223 |
+
accs_instant, accs_avg, accs_certain = [], [], []
|
| 224 |
+
accs_avg_logits, accs_weighted_logits = [],[]
|
| 225 |
+
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
|
| 226 |
+
pbarinner.set_description('Acc types')
|
| 227 |
+
for stepi in np.arange(concatenated_predictions.shape[-1]):
|
| 228 |
+
pred_avg = softmax(concatenated_predictions, 1)[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
|
| 229 |
+
pred_instant = concatenated_predictions_argsorted_topk[:,:,stepi]
|
| 230 |
+
pred_certain = concatenated_predictions_argsorted_topk[np.arange(concatenated_predictions.shape[0]),:, concatenated_certainties[:,1,:stepi+1].argmax(1)]
|
| 231 |
+
pred_avg_logits = concatenated_predictions[:,:,:stepi+1].mean(-1).argsort(1)[:,-topk:]
|
| 232 |
+
pred_weighted_logits = (concatenated_predictions[:,:,:stepi+1] * concatenated_certainties[:,1:,:stepi+1]).sum(-1).argsort(1)[:, -topk:]
|
| 233 |
+
pbarinner.update(1)
|
| 234 |
+
accs_instant.append(np.any(pred_instant==concatenated_targets[...,np.newaxis], -1).mean())
|
| 235 |
+
accs_avg.append(np.any(pred_avg==concatenated_targets[...,np.newaxis], -1).mean())
|
| 236 |
+
accs_avg_logits.append(np.any(pred_avg==concatenated_targets[...,np.newaxis], -1).mean())
|
| 237 |
+
accs_weighted_logits.append(np.any(pred_weighted_logits==concatenated_targets[...,np.newaxis], -1).mean())
|
| 238 |
+
accs_certain.append(np.any(pred_avg_logits==concatenated_targets[...,np.newaxis], -1).mean())
|
| 239 |
+
fig = plt.figure(figsize=(10*figscale, 4*figscale))
|
| 240 |
+
ax = fig.add_subplot(111)
|
| 241 |
+
cp = sns.color_palette("bright")
|
| 242 |
+
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_instant), linestyle='-', color=cp[0], label='Instant')
|
| 243 |
+
# ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg), linestyle='--', color=cp[1], label='Based on average probability up to this step')
|
| 244 |
+
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_certain), linestyle=':', color=cp[2], label='Most certain')
|
| 245 |
+
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_avg_logits), linestyle='-.', color=cp[3], label='Average logits')
|
| 246 |
+
ax.plot(np.arange(concatenated_predictions.shape[-1])+1, 100*np.array(accs_weighted_logits), linestyle='--', color=cp[4], label='Logits weighted by certainty')
|
| 247 |
+
ax.set_xlim([0, concatenated_predictions.shape[-1]+1])
|
| 248 |
+
ax.set_ylim([75, 92])
|
| 249 |
+
ax.set_xlabel('Internal ticks')
|
| 250 |
+
ax.set_ylabel(f'Top-k={topk} accuracy')
|
| 251 |
+
ax.legend(loc='lower right')
|
| 252 |
+
fig.tight_layout(pad=0.1)
|
| 253 |
+
fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.png', dpi=200)
|
| 254 |
+
fig.savefig(f'{args.output_dir}/accuracy_types_{topk}.pdf', dpi=200)
|
| 255 |
+
plt.close(fig)
|
| 256 |
+
print(f'k={topk}. Accuracy most certain at last internal tick={100*np.array(accs_certain)[-1]:0.4f}') # Using certainty based approach
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
indices_over_80 = []
|
| 260 |
+
classes_80 = {}
|
| 261 |
+
corrects_80 = {}
|
| 262 |
+
|
| 263 |
+
topk = 5
|
| 264 |
+
concatenated_predictions_argsorted_topk = concatenated_predictions_argsorted[:,:topk]
|
| 265 |
+
for certainty_threshold in [0.5, 0.8, 0.9]:
|
| 266 |
+
# certainty_threshold = 0.6
|
| 267 |
+
percentage_corrects = []
|
| 268 |
+
percentage_incorrects = []
|
| 269 |
+
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
|
| 270 |
+
pbarinner.set_description(f'Certainty threshold={certainty_threshold}')
|
| 271 |
+
for stepi in np.arange(concatenated_predictions.shape[-1]):
|
| 272 |
+
certainty_here = concatenated_certainties[:,1,stepi]
|
| 273 |
+
certainty_mask = certainty_here>=certainty_threshold
|
| 274 |
+
predictions_here = concatenated_predictions_argsorted_topk[:,:,stepi]
|
| 275 |
+
is_correct_here = np.any(predictions_here==concatenated_targets[...,np.newaxis], axis=-1)
|
| 276 |
+
percentage_corrects.append(is_correct_here[certainty_mask].sum()/predictions_here.shape[0])
|
| 277 |
+
percentage_incorrects.append((~is_correct_here)[certainty_mask].sum()/predictions_here.shape[0])
|
| 278 |
+
|
| 279 |
+
if certainty_threshold==0.8:
|
| 280 |
+
indices_certain = np.where(certainty_mask)[0]
|
| 281 |
+
for index in indices_certain:
|
| 282 |
+
if index not in indices_over_80:
|
| 283 |
+
indices_over_80.append(index)
|
| 284 |
+
if concatenated_targets[index] not in classes_80:
|
| 285 |
+
classes_80[concatenated_targets[index]] = [stepi]
|
| 286 |
+
corrects_80[concatenated_targets[index]] = [is_correct_here[index]]
|
| 287 |
+
else:
|
| 288 |
+
classes_80[concatenated_targets[index]] = classes_80[concatenated_targets[index]]+[stepi]
|
| 289 |
+
corrects_80[concatenated_targets[index]] = corrects_80[concatenated_targets[index]]+[is_correct_here[index]]
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
pbarinner.update(1)
|
| 293 |
+
fig = plt.figure(figsize=(6.5*figscale, 4*figscale))
|
| 294 |
+
ax = fig.add_subplot(111)
|
| 295 |
+
ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
|
| 296 |
+
percentage_corrects,
|
| 297 |
+
color='forestgreen',
|
| 298 |
+
hatch='OO',
|
| 299 |
+
width=0.9,
|
| 300 |
+
label='Positive',
|
| 301 |
+
alpha=0.9,
|
| 302 |
+
linewidth=1.0*figscale)
|
| 303 |
+
|
| 304 |
+
ax.bar(np.arange(concatenated_predictions.shape[-1])+1,
|
| 305 |
+
percentage_incorrects,
|
| 306 |
+
bottom=percentage_corrects,
|
| 307 |
+
color='crimson',
|
| 308 |
+
hatch='xx',
|
| 309 |
+
width=0.9,
|
| 310 |
+
label='Negative',
|
| 311 |
+
alpha=0.9,
|
| 312 |
+
linewidth=1.0*figscale)
|
| 313 |
+
ax.set_xlim(-1, concatenated_predictions.shape[-1]+1)
|
| 314 |
+
ax.set_xlabel('Internal tick')
|
| 315 |
+
ax.set_ylabel('% of data')
|
| 316 |
+
ax.legend(loc='lower right')
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
fig.tight_layout(pad=0.1)
|
| 320 |
+
fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.png', dpi=200)
|
| 321 |
+
fig.savefig(f'{args.output_dir}/steps_versus_correct_{certainty_threshold}.pdf', dpi=200)
|
| 322 |
+
plt.close(fig)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class_list = list(classes_80.keys())
|
| 326 |
+
mean_steps = [np.mean(classes_80[cls]) for cls in class_list]
|
| 327 |
+
std_steps = [np.std(classes_80[cls]) for cls in class_list]
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
# Following code plots the class distribution over internal ticks
|
| 331 |
+
indices_to_show = np.arange(1000)
|
| 332 |
+
|
| 333 |
+
colours = cmap_diverse = plt.get_cmap('rainbow')(np.linspace(0, 1, 1000))
|
| 334 |
+
# np.random.shuffle(colours)
|
| 335 |
+
bottom = np.zeros(concatenated_predictions.shape[-1])
|
| 336 |
+
|
| 337 |
+
fig = plt.figure(figsize=(7*figscale, 4*figscale))
|
| 338 |
+
ax = fig.add_subplot(111)
|
| 339 |
+
for iii, idx in enumerate(indices_to_show):
|
| 340 |
+
if idx in classes_80:
|
| 341 |
+
steps = classes_80[idx]
|
| 342 |
+
colour = colours[iii]
|
| 343 |
+
vs, cts = np.unique(steps, return_counts=True)
|
| 344 |
+
|
| 345 |
+
bar = np.zeros(concatenated_predictions.shape[-1])
|
| 346 |
+
bar[vs] = cts
|
| 347 |
+
ax.bar(np.arange(concatenated_predictions.shape[-1])+1, bar, bottom=bottom, color=colour, width=1, edgecolor='none')
|
| 348 |
+
bottom += bar
|
| 349 |
+
ax.set_xlabel('Internal ticks')
|
| 350 |
+
ax.set_ylabel('Counts over 0.8 certainty')
|
| 351 |
+
fig.tight_layout(pad=0.1)
|
| 352 |
+
fig.savefig(f'{args.output_dir}/class_counts.png', dpi=200)
|
| 353 |
+
fig.savefig(f'{args.output_dir}/class_counts.pdf', dpi=200)
|
| 354 |
+
plt.close(fig)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# The following code plots calibration
|
| 361 |
+
probability_space = np.linspace(0, 1, 10)
|
| 362 |
+
fig = plt.figure(figsize=(6*figscale, 4*figscale))
|
| 363 |
+
ax = fig.add_subplot(111)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
color_linspace = np.linspace(0, 1, concatenated_predictions.shape[-1])
|
| 367 |
+
with tqdm(total=(concatenated_predictions.shape[-1]), initial=0, leave=False, position=1, dynamic_ncols=True) as pbarinner:
|
| 368 |
+
pbarinner.set_description(f'Calibration')
|
| 369 |
+
for stepi in np.arange(concatenated_predictions.shape[-1]):
|
| 370 |
+
color = cmap_calib(color_linspace[stepi])
|
| 371 |
+
pred = concatenated_predictions[:,:,stepi].argmax(1)
|
| 372 |
+
is_correct = pred == concatenated_targets # BxT
|
| 373 |
+
probabilities = softmax(concatenated_predictions[:,:,:stepi+1], axis=1)[np.arange(concatenated_predictions.shape[0]),pred].mean(-1)#softmax(concatenated_predictions[:,:,stepi], axis=1).max(1)
|
| 374 |
+
probability_space = np.linspace(0, 1, 10)
|
| 375 |
+
accuracies_per_bin = []
|
| 376 |
+
bin_centers = []
|
| 377 |
+
for pi in range(len(probability_space)-1):
|
| 378 |
+
bin_low = probability_space[pi]
|
| 379 |
+
bin_high = probability_space[pi+1]
|
| 380 |
+
mask = ((probabilities >=bin_low) & (probabilities < bin_high)) if pi !=len(probability_space)-2 else ((probabilities >=bin_low) & (probabilities <= bin_high))
|
| 381 |
+
accuracies_per_bin.append(is_correct[mask].mean())
|
| 382 |
+
bin_centers.append(probabilities[mask].mean())
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
if stepi==concatenated_predictions.shape[-1]-1:
|
| 386 |
+
ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color='#4050f7', alpha=1, label='After all ticks')
|
| 387 |
+
else: ax.plot(bin_centers, accuracies_per_bin, linestyle='-', marker='.', color=color, alpha=0.65)
|
| 388 |
+
pbarinner.update(1)
|
| 389 |
+
ax.plot(probability_space, np.linspace(0, 1, len(probability_space)), 'k--')
|
| 390 |
+
|
| 391 |
+
ax.legend(loc='upper left')
|
| 392 |
+
ax.set_xlim([-0.01, 1.01])
|
| 393 |
+
ax.set_ylim([-0.01, 1.01])
|
| 394 |
+
|
| 395 |
+
sm = plt.cm.ScalarMappable(cmap=cmap_calib, norm=plt.Normalize(vmin=0, vmax=concatenated_predictions.shape[-1] - 1))
|
| 396 |
+
sm.set_array([]) # Empty array for colormap
|
| 397 |
+
cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.02)
|
| 398 |
+
cbar.set_label('Internal ticks')
|
| 399 |
+
|
| 400 |
+
ax.set_xlabel('Mean predicted probabilities')
|
| 401 |
+
ax.set_ylabel('Ratio of positives')
|
| 402 |
+
fig.tight_layout(pad=0.1)
|
| 403 |
+
fig.savefig(f'{args.output_dir}/imagenet_calibration.png', dpi=200)
|
| 404 |
+
fig.savefig(f'{args.output_dir}/imagenet_calibration.pdf', dpi=200)
|
| 405 |
+
plt.close(fig)
|
| 406 |
+
if 'videos' in args.actions:
|
| 407 |
+
if not args.data_indices: # If list is empty
|
| 408 |
+
n_samples = len(validation_dataset)
|
| 409 |
+
num_to_sample = min(args.N_to_viz, n_samples)
|
| 410 |
+
replace = n_samples < num_to_sample
|
| 411 |
+
data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
|
| 412 |
+
print(f"Selected random indices: {data_indices}")
|
| 413 |
+
else:
|
| 414 |
+
data_indices = args.data_indices
|
| 415 |
+
print(f"Using specified indices: {data_indices}")
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
for di in data_indices:
|
| 419 |
+
print(f'\nBuilding viz for dataset index {di}.')
|
| 420 |
+
|
| 421 |
+
# --- Get Data & Run Inference ---
|
| 422 |
+
# inputs_norm is already normalized by the transform
|
| 423 |
+
inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
|
| 424 |
+
|
| 425 |
+
# Add batch dimension and send to device
|
| 426 |
+
inputs = inputs.to(device).unsqueeze(0)
|
| 427 |
+
|
| 428 |
+
# Run model inference
|
| 429 |
+
predictions, certainties, synchronisation, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
|
| 430 |
+
# predictions: (B, Classes, Steps), attention_tracking: (Steps*B*Heads, SeqLen)
|
| 431 |
+
n_steps = predictions.size(-1)
|
| 432 |
+
|
| 433 |
+
# --- Reshape Attention ---
|
| 434 |
+
# Infer feature map size from model internals (assuming B=1)
|
| 435 |
+
h_feat, w_feat = model.kv_features.shape[-2:]
|
| 436 |
+
|
| 437 |
+
n_heads = attention_tracking.shape[2]
|
| 438 |
+
# Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
|
| 439 |
+
attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
|
| 440 |
+
|
| 441 |
+
# --- Setup for Plotting ---
|
| 442 |
+
step_linspace = np.linspace(0, 1, n_steps) # For step colors
|
| 443 |
+
# Define color maps
|
| 444 |
+
cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
|
| 445 |
+
cmap_attention = sns.color_palette('viridis', as_cmap=True)
|
| 446 |
+
|
| 447 |
+
# Create output directory for this index
|
| 448 |
+
index_output_dir = os.path.join(args.output_dir, str(di))
|
| 449 |
+
os.makedirs(index_output_dir, exist_ok=True)
|
| 450 |
+
|
| 451 |
+
frames = [] # Store frames for GIF
|
| 452 |
+
head_routes = {h: [] for h in range(n_heads)} # Store (y,x) path points per head
|
| 453 |
+
head_routes[-1] = []
|
| 454 |
+
route_colours_step = [] # Store colors for each step's path segments
|
| 455 |
+
|
| 456 |
+
# --- Loop Through Each Step ---
|
| 457 |
+
for step_i in range(n_steps):
|
| 458 |
+
|
| 459 |
+
# --- Prepare Image for Display ---
|
| 460 |
+
# Denormalize the input tensor for visualization
|
| 461 |
+
data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
|
| 462 |
+
mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
|
| 463 |
+
std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
|
| 464 |
+
data_img_denorm = data_img_tensor * std_tensor + mean_tensor
|
| 465 |
+
# Permute to (H, W, C) and convert to numpy, clip to [0, 1]
|
| 466 |
+
data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
|
| 467 |
+
data_img_np = np.clip(data_img_np, 0, 1)
|
| 468 |
+
img_h, img_w = data_img_np.shape[:2]
|
| 469 |
+
|
| 470 |
+
# --- Process Attention & Certainty ---
|
| 471 |
+
# Average attention over last few steps (from original code)
|
| 472 |
+
start_step = max(0, step_i - 5)
|
| 473 |
+
attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
|
| 474 |
+
# Get certainties up to current step
|
| 475 |
+
certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
|
| 476 |
+
|
| 477 |
+
# --- Calculate Attention Paths (using bilinear interp) ---
|
| 478 |
+
# Interpolate attention to image size using bilinear for center finding
|
| 479 |
+
attention_interp_bilinear = F.interpolate(
|
| 480 |
+
torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
|
| 481 |
+
size=(img_h, img_w),
|
| 482 |
+
mode=interp_mode,
|
| 483 |
+
# align_corners=False
|
| 484 |
+
).squeeze(0) # Remove batch dim -> (Heads, H, W)
|
| 485 |
+
|
| 486 |
+
# Normalize each head's map to [0, 1]
|
| 487 |
+
# Deal with mean
|
| 488 |
+
attn_mean = attention_interp_bilinear.mean(0)
|
| 489 |
+
attn_mean_min = attn_mean.min()
|
| 490 |
+
attn_mean_max = attn_mean.max()
|
| 491 |
+
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
|
| 492 |
+
centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
|
| 493 |
+
|
| 494 |
+
if centers: # If islands found
|
| 495 |
+
largest_island_idx = np.argmax(areas)
|
| 496 |
+
current_center = centers[largest_island_idx] # (y, x)
|
| 497 |
+
head_routes[-1].append(current_center)
|
| 498 |
+
elif head_routes[-1]: # If no center now, repeat last known center if history exists
|
| 499 |
+
head_routes[-1].append(head_routes[-1][-1])
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
attn_min = attention_interp_bilinear.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
|
| 503 |
+
attn_max = attention_interp_bilinear.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
|
| 504 |
+
attention_interp_bilinear = (attention_interp_bilinear - attn_min) / (attn_max - attn_min + 1e-6)
|
| 505 |
+
|
| 506 |
+
# Store step color
|
| 507 |
+
current_colour = list(cmap_spectral(step_linspace[step_i]))
|
| 508 |
+
route_colours_step.append(current_colour)
|
| 509 |
+
|
| 510 |
+
# Find island center for each head
|
| 511 |
+
for head_i in range(n_heads):
|
| 512 |
+
attn_head_np = attention_interp_bilinear[head_i].detach().cpu().numpy()
|
| 513 |
+
# Keep threshold=0.7 based on original call
|
| 514 |
+
centers, areas = find_island_centers(attn_head_np, threshold=0.7)
|
| 515 |
+
|
| 516 |
+
if centers: # If islands found
|
| 517 |
+
largest_island_idx = np.argmax(areas)
|
| 518 |
+
current_center = centers[largest_island_idx] # (y, x)
|
| 519 |
+
head_routes[head_i].append(current_center)
|
| 520 |
+
elif head_routes[head_i]: # If no center now, repeat last known center if history exists
|
| 521 |
+
head_routes[head_i].append(head_routes[head_i][-1])
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
# --- Plotting Setup ---
|
| 526 |
+
mosaic = [['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
|
| 527 |
+
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
|
| 528 |
+
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
|
| 529 |
+
['head_mean', 'head_mean', 'head_mean', 'head_mean', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay', 'head_mean_overlay'],
|
| 530 |
+
['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay', 'head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
|
| 531 |
+
['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay','head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
|
| 532 |
+
['head_8', 'head_8_overlay', 'head_9', 'head_9_overlay','head_10', 'head_10_overlay', 'head_11', 'head_11_overlay'],
|
| 533 |
+
['head_12', 'head_12_overlay', 'head_13', 'head_13_overlay','head_14', 'head_14_overlay', 'head_15', 'head_15_overlay'],
|
| 534 |
+
['probabilities', 'probabilities','probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty'],
|
| 535 |
+
]
|
| 536 |
+
|
| 537 |
+
img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
|
| 538 |
+
aspect_ratio = (8 * figscale, 9 * figscale * img_aspect) # W, H
|
| 539 |
+
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 540 |
+
|
| 541 |
+
for ax in axes.values():
|
| 542 |
+
ax.axis('off')
|
| 543 |
+
|
| 544 |
+
# --- Plot Certainty ---
|
| 545 |
+
ax_cert = axes['certainty']
|
| 546 |
+
ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
|
| 547 |
+
# Add background color based on prediction correctness at each step
|
| 548 |
+
for ii in range(len(certainties_now)):
|
| 549 |
+
is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
|
| 550 |
+
facecolor = 'limegreen' if is_correct else 'orchid'
|
| 551 |
+
ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
|
| 552 |
+
# Mark the last point
|
| 553 |
+
ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
|
| 554 |
+
ax_cert.axis('off')
|
| 555 |
+
ax_cert.set_ylim([0.05, 1.05])
|
| 556 |
+
ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
|
| 557 |
+
|
| 558 |
+
# --- Plot Probabilities ---
|
| 559 |
+
ax_prob = axes['probabilities']
|
| 560 |
+
# Get probabilities for the current step
|
| 561 |
+
ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
|
| 562 |
+
k = 15 # Top k predictions
|
| 563 |
+
topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
|
| 564 |
+
topk_indices = topk_indices.numpy()
|
| 565 |
+
topk_probs = topk_probs.numpy()
|
| 566 |
+
|
| 567 |
+
top_classes = np.array(class_labels)[topk_indices]
|
| 568 |
+
true_class_idx = ground_truth_target # Ground truth index
|
| 569 |
+
|
| 570 |
+
# Determine bar colors (green if correct, blue otherwise - consistent with original)
|
| 571 |
+
colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
|
| 572 |
+
|
| 573 |
+
# Plot horizontal bars (inverted range for top-down display)
|
| 574 |
+
ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
|
| 575 |
+
ax_prob.set_xlim([0, 1])
|
| 576 |
+
ax_prob.axis('off')
|
| 577 |
+
|
| 578 |
+
# Add text labels for top classes
|
| 579 |
+
for i, name_idx in enumerate(topk_indices):
|
| 580 |
+
name = class_labels[name_idx] # Get name from index
|
| 581 |
+
is_correct = name_idx == true_class_idx
|
| 582 |
+
fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
|
| 583 |
+
text_str = f'{name[:40]}' # Truncate long names
|
| 584 |
+
# Position text on the left side of the horizontal bars
|
| 585 |
+
ax_prob.text(
|
| 586 |
+
0.01, # Small offset from left edge
|
| 587 |
+
k - 1 - i, # Y-position corresponding to the bar
|
| 588 |
+
text_str,
|
| 589 |
+
#transform=ax_prob.transAxes, # Use data coordinates for Y
|
| 590 |
+
verticalalignment='center',
|
| 591 |
+
horizontalalignment='left',
|
| 592 |
+
fontsize=8,
|
| 593 |
+
color=fg_color,
|
| 594 |
+
alpha=0.9, # Slightly more visible than 0.5
|
| 595 |
+
path_effects=[
|
| 596 |
+
patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
|
| 597 |
+
patheffects.Normal()
|
| 598 |
+
])
|
| 599 |
+
|
| 600 |
+
|
| 601 |
+
# --- Plot Attention Heads & Overlays (using nearest interp) ---
|
| 602 |
+
# Re-interpolate attention using nearest neighbor for visual plotting
|
| 603 |
+
attention_interp_plot = F.interpolate(
|
| 604 |
+
torch.from_numpy(attention_now).unsqueeze(0).float(),
|
| 605 |
+
size=(img_h, img_w),
|
| 606 |
+
mode=interp_mode, # 'nearest'
|
| 607 |
+
).squeeze(0)
|
| 608 |
+
|
| 609 |
+
attn_mean = attention_interp_plot.mean(0)
|
| 610 |
+
attn_mean_min = attn_mean.min()
|
| 611 |
+
attn_mean_max = attn_mean.max()
|
| 612 |
+
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
# Normalize each head's map to [0, 1]
|
| 616 |
+
attn_min_plot = attention_interp_plot.view(n_heads, -1).min(dim=-1, keepdim=True)[0].unsqueeze(-1)
|
| 617 |
+
attn_max_plot = attention_interp_plot.view(n_heads, -1).max(dim=-1, keepdim=True)[0].unsqueeze(-1)
|
| 618 |
+
attention_interp_plot = (attention_interp_plot - attn_min_plot) / (attn_max_plot - attn_min_plot + 1e-6)
|
| 619 |
+
attention_interp_plot_np = attention_interp_plot.detach().cpu().numpy()
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
for head_i in list(range(n_heads)) + [-1]:
|
| 627 |
+
axname = f'head_{head_i}' if head_i != -1 else 'head_mean'
|
| 628 |
+
if axname not in axes: continue # Skip if mosaic doesn't have this head
|
| 629 |
+
|
| 630 |
+
ax = axes[axname]
|
| 631 |
+
ax_overlay = axes[f'{axname}_overlay']
|
| 632 |
+
|
| 633 |
+
# Plot attention heatmap
|
| 634 |
+
this_attn = attention_interp_plot_np[head_i] if head_i != -1 else attn_mean
|
| 635 |
+
img_to_plot = cmap_attention(this_attn)
|
| 636 |
+
ax.imshow(img_to_plot)
|
| 637 |
+
ax.axis('off')
|
| 638 |
+
|
| 639 |
+
# Plot overlay: image + paths
|
| 640 |
+
these_route_steps = head_routes[head_i]
|
| 641 |
+
arrow_scale = 1.5 if head_i != -1 else 3
|
| 642 |
+
|
| 643 |
+
if these_route_steps: # Only plot if path exists
|
| 644 |
+
# Separate y and x coordinates
|
| 645 |
+
y_coords, x_coords = zip(*these_route_steps)
|
| 646 |
+
y_coords = np.array(y_coords)
|
| 647 |
+
x_coords = np.array(x_coords)
|
| 648 |
+
|
| 649 |
+
# Flip y-coordinates for correct plotting (imshow origin is top-left)
|
| 650 |
+
# NOTE: Original flip seemed complex, simplifying to standard flip
|
| 651 |
+
y_coords_flipped = img_h - 1 - y_coords
|
| 652 |
+
|
| 653 |
+
# Show original image flipped vertically to match coordinate system
|
| 654 |
+
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
|
| 655 |
+
|
| 656 |
+
# Draw arrows for path segments
|
| 657 |
+
# Arrow size scaling from original
|
| 658 |
+
for i in range(len(these_route_steps) - 1):
|
| 659 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 660 |
+
dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
|
| 661 |
+
|
| 662 |
+
# Draw white background arrow (thicker)
|
| 663 |
+
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
|
| 664 |
+
linewidth=1.6 * arrow_scale * 1.3,
|
| 665 |
+
head_width=1.9 * arrow_scale * 1.3,
|
| 666 |
+
head_length=1.4 * arrow_scale * 1.45,
|
| 667 |
+
fc='white', ec='white', length_includes_head=True, alpha=1)
|
| 668 |
+
# Draw colored foreground arrow
|
| 669 |
+
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
|
| 670 |
+
linewidth=1.6 * arrow_scale,
|
| 671 |
+
head_width=1.9 * arrow_scale,
|
| 672 |
+
head_length=1.4 * arrow_scale,
|
| 673 |
+
fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
|
| 674 |
+
length_includes_head=True)
|
| 675 |
+
|
| 676 |
+
else: # If no path yet, just show the image
|
| 677 |
+
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
# Set limits and turn off axes for overlay
|
| 681 |
+
ax_overlay.set_xlim([0, img_w - 1])
|
| 682 |
+
ax_overlay.set_ylim([0, img_h - 1])
|
| 683 |
+
ax_overlay.axis('off')
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
# --- Finalize and Save Frame ---
|
| 687 |
+
fig.tight_layout(pad=0.1) # Adjust spacing
|
| 688 |
+
|
| 689 |
+
# Render the plot to a numpy array
|
| 690 |
+
canvas = fig.canvas
|
| 691 |
+
canvas.draw()
|
| 692 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 693 |
+
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
|
| 694 |
+
|
| 695 |
+
frames.append(image_numpy) # Add to list for GIF
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
plt.close(fig) # Close figure to free memory
|
| 700 |
+
|
| 701 |
+
# --- Save GIF ---
|
| 702 |
+
gif_path = os.path.join(index_output_dir, f'{str(di)}_viz.gif')
|
| 703 |
+
print(f"Saving GIF to {gif_path}...")
|
| 704 |
+
imageio.mimsave(gif_path, frames, fps=15, loop=0) # loop=0 means infinite loop
|
| 705 |
+
save_frames_to_mp4([fm[:,:,::-1] for fm in frames], os.path.join(index_output_dir, f'{str(di)}_viz.mp4'), fps=15, gop_size=1, preset='veryslow')
|
| 706 |
+
if 'demo' in args.actions:
|
| 707 |
+
|
| 708 |
+
|
| 709 |
+
|
| 710 |
+
# --- Select Data Indices ---
|
| 711 |
+
if not args.data_indices: # If list is empty
|
| 712 |
+
n_samples = len(validation_dataset)
|
| 713 |
+
num_to_sample = min(args.N_to_viz, n_samples)
|
| 714 |
+
replace = n_samples < num_to_sample
|
| 715 |
+
data_indices = np.random.choice(np.arange(n_samples), size=num_to_sample, replace=replace)
|
| 716 |
+
print(f"Selected random indices: {data_indices}")
|
| 717 |
+
else:
|
| 718 |
+
data_indices = args.data_indices
|
| 719 |
+
print(f"Using specified indices: {data_indices}")
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
for di in data_indices:
|
| 723 |
+
|
| 724 |
+
index_output_dir = os.path.join(args.output_dir, str(di))
|
| 725 |
+
os.makedirs(index_output_dir, exist_ok=True)
|
| 726 |
+
|
| 727 |
+
print(f'\nBuilding viz for dataset index {di}.')
|
| 728 |
+
|
| 729 |
+
inputs, ground_truth_target = validation_dataset.__getitem__(int(di))
|
| 730 |
+
|
| 731 |
+
# Add batch dimension and send to device
|
| 732 |
+
inputs = inputs.to(device).unsqueeze(0)
|
| 733 |
+
predictions, certainties, synchronisations_over_time, pre_activations, post_activations, attention_tracking = model(inputs, track=True)
|
| 734 |
+
|
| 735 |
+
# --- Reshape Attention ---
|
| 736 |
+
# Infer feature map size from model internals (assuming B=1)
|
| 737 |
+
h_feat, w_feat = model.kv_features.shape[-2:]
|
| 738 |
+
n_steps = predictions.size(-1)
|
| 739 |
+
n_heads = attention_tracking.shape[2]
|
| 740 |
+
# Reshape to (Steps, Heads, H_feat, W_feat) assuming B=1
|
| 741 |
+
attention_tracking = attention_tracking.reshape(n_steps, n_heads, h_feat, w_feat)
|
| 742 |
+
|
| 743 |
+
# --- Setup for Plotting ---
|
| 744 |
+
step_linspace = np.linspace(0, 1, n_steps) # For step colors
|
| 745 |
+
# Define color maps
|
| 746 |
+
cmap_steps = sns.color_palette("Spectral", as_cmap=True)
|
| 747 |
+
cmap_attention = sns.color_palette('viridis', as_cmap=True)
|
| 748 |
+
|
| 749 |
+
# Create output directory for this index
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
frames = [] # Store frames for GIF
|
| 753 |
+
head_routes = [] # Store (y,x) path points per head
|
| 754 |
+
route_colours_step = [] # Store colors for each step's path segments
|
| 755 |
+
|
| 756 |
+
# --- Loop Through Each Step ---
|
| 757 |
+
for step_i in range(n_steps):
|
| 758 |
+
|
| 759 |
+
# Store step color
|
| 760 |
+
current_colour = list(cmap_steps(step_linspace[step_i]))
|
| 761 |
+
route_colours_step.append(current_colour)
|
| 762 |
+
|
| 763 |
+
# --- Prepare Image for Display ---
|
| 764 |
+
# Denormalize the input tensor for visualization
|
| 765 |
+
data_img_tensor = inputs[0].cpu() # Get first item in batch, move to CPU
|
| 766 |
+
mean_tensor = torch.tensor(dataset_mean).view(3, 1, 1)
|
| 767 |
+
std_tensor = torch.tensor(dataset_std).view(3, 1, 1)
|
| 768 |
+
data_img_denorm = data_img_tensor * std_tensor + mean_tensor
|
| 769 |
+
# Permute to (H, W, C) and convert to numpy, clip to [0, 1]
|
| 770 |
+
data_img_np = data_img_denorm.permute(1, 2, 0).detach().numpy()
|
| 771 |
+
data_img_np = np.clip(data_img_np, 0, 1)
|
| 772 |
+
img_h, img_w = data_img_np.shape[:2]
|
| 773 |
+
|
| 774 |
+
# --- Process Attention & Certainty ---
|
| 775 |
+
# Average attention over last few steps (from original code)
|
| 776 |
+
start_step = max(0, step_i - 5)
|
| 777 |
+
attention_now = attention_tracking[start_step : step_i + 1].mean(0) # Avg over steps -> (Heads, H_feat, W_feat)
|
| 778 |
+
# Get certainties up to current step
|
| 779 |
+
certainties_now = certainties[0, 1, :step_i+1].detach().cpu().numpy() # Assuming index 1 holds relevant certainty
|
| 780 |
+
|
| 781 |
+
# --- Calculate Attention Paths (using bilinear interp) ---
|
| 782 |
+
# Interpolate attention to image size using bilinear for center finding
|
| 783 |
+
attention_interp_bilinear = F.interpolate(
|
| 784 |
+
torch.from_numpy(attention_now).unsqueeze(0).float(), # Add batch dim, ensure float
|
| 785 |
+
size=(img_h, img_w),
|
| 786 |
+
mode=interp_mode,
|
| 787 |
+
).squeeze(0) # Remove batch dim -> (Heads, H, W)
|
| 788 |
+
|
| 789 |
+
attn_mean = attention_interp_bilinear.mean(0)
|
| 790 |
+
attn_mean_min = attn_mean.min()
|
| 791 |
+
attn_mean_max = attn_mean.max()
|
| 792 |
+
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
|
| 793 |
+
centers, areas = find_island_centers(attn_mean.detach().cpu().numpy(), threshold=0.7)
|
| 794 |
+
|
| 795 |
+
if centers: # If islands found
|
| 796 |
+
largest_island_idx = np.argmax(areas)
|
| 797 |
+
current_center = centers[largest_island_idx] # (y, x)
|
| 798 |
+
head_routes.append(current_center)
|
| 799 |
+
elif head_routes: # If no center now, repeat last known center if history exists
|
| 800 |
+
head_routes.append(head_routes[-1])
|
| 801 |
+
|
| 802 |
+
# --- Plotting Setup ---
|
| 803 |
+
# if n_heads != 8: print(f"Warning: Plotting layout assumes 8 heads, found {n_heads}. Layout may be incorrect.")
|
| 804 |
+
mosaic = [['head_0', 'head_1', 'head_2', 'head_3', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
|
| 805 |
+
['head_4', 'head_5', 'head_6', 'head_7', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
|
| 806 |
+
['head_8', 'head_9', 'head_10', 'head_11', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
|
| 807 |
+
['head_12', 'head_13', 'head_14', 'head_15', 'head_mean', 'head_mean', 'head_mean', 'head_mean', 'overlay', 'overlay', 'overlay', 'overlay'],
|
| 808 |
+
['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
|
| 809 |
+
['probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'probabilities', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty'],
|
| 810 |
+
]
|
| 811 |
+
|
| 812 |
+
img_aspect = data_img_np.shape[0] / data_img_np.shape[1]
|
| 813 |
+
aspect_ratio = (12 * figscale, 6 * figscale * img_aspect) # W, H
|
| 814 |
+
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 815 |
+
for ax in axes.values():
|
| 816 |
+
ax.axis('off')
|
| 817 |
+
|
| 818 |
+
# --- Plot Certainty ---
|
| 819 |
+
ax_cert = axes['certainty']
|
| 820 |
+
ax_cert.plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1)
|
| 821 |
+
# Add background color based on prediction correctness at each step
|
| 822 |
+
for ii in range(len(certainties_now)):
|
| 823 |
+
is_correct = predictions[0, :, ii].argmax(-1).item() == ground_truth_target # .item() for scalar tensor
|
| 824 |
+
facecolor = 'limegreen' if is_correct else 'orchid'
|
| 825 |
+
ax_cert.axvspan(ii, ii + 1, facecolor=facecolor, edgecolor=None, lw=0, alpha=0.3)
|
| 826 |
+
# Mark the last point
|
| 827 |
+
ax_cert.plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
|
| 828 |
+
ax_cert.axis('off')
|
| 829 |
+
ax_cert.set_ylim([0.05, 1.05])
|
| 830 |
+
ax_cert.set_xlim([0, n_steps]) # Use n_steps for consistent x-axis limit
|
| 831 |
+
|
| 832 |
+
# --- Plot Probabilities ---
|
| 833 |
+
ax_prob = axes['probabilities']
|
| 834 |
+
# Get probabilities for the current step
|
| 835 |
+
ps = torch.softmax(predictions[0, :, step_i], -1).detach().cpu()
|
| 836 |
+
k = 15 # Top k predictions
|
| 837 |
+
topk_probs, topk_indices = torch.topk(ps, k, dim=0, largest=True)
|
| 838 |
+
topk_indices = topk_indices.numpy()
|
| 839 |
+
topk_probs = topk_probs.numpy()
|
| 840 |
+
|
| 841 |
+
top_classes = np.array(class_labels)[topk_indices]
|
| 842 |
+
true_class_idx = ground_truth_target # Ground truth index
|
| 843 |
+
|
| 844 |
+
# Determine bar colors (green if correct, blue otherwise - consistent with original)
|
| 845 |
+
colours = ['g' if idx == true_class_idx else 'b' for idx in topk_indices]
|
| 846 |
+
|
| 847 |
+
# Plot horizontal bars (inverted range for top-down display)
|
| 848 |
+
ax_prob.barh(np.arange(k)[::-1], topk_probs, color=colours, alpha=1) # Use barh and inverted range
|
| 849 |
+
ax_prob.set_xlim([0, 1])
|
| 850 |
+
ax_prob.axis('off')
|
| 851 |
+
|
| 852 |
+
# Add text labels for top classes
|
| 853 |
+
for i, name_idx in enumerate(topk_indices):
|
| 854 |
+
name = class_labels[name_idx] # Get name from index
|
| 855 |
+
is_correct = name_idx == true_class_idx
|
| 856 |
+
fg_color = 'darkgreen' if is_correct else 'crimson' # Text colors from original
|
| 857 |
+
text_str = f'{name[:40]}' # Truncate long names
|
| 858 |
+
# Position text on the left side of the horizontal bars
|
| 859 |
+
ax_prob.text(
|
| 860 |
+
0.01, # Small offset from left edge
|
| 861 |
+
k - 1 - i, # Y-position corresponding to the bar
|
| 862 |
+
text_str,
|
| 863 |
+
#transform=ax_prob.transAxes, # Use data coordinates for Y
|
| 864 |
+
verticalalignment='center',
|
| 865 |
+
horizontalalignment='left',
|
| 866 |
+
fontsize=8,
|
| 867 |
+
color=fg_color,
|
| 868 |
+
alpha=0.7, # Slightly more visible than 0.5
|
| 869 |
+
path_effects=[
|
| 870 |
+
patheffects.Stroke(linewidth=2, foreground='white'), # Adjusted stroke
|
| 871 |
+
patheffects.Normal()
|
| 872 |
+
])
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
# --- Plot Attention Heads & Overlays (using nearest interp) ---
|
| 876 |
+
# Re-interpolate attention using nearest neighbor for visual plotting
|
| 877 |
+
attention_interp_plot = F.interpolate(
|
| 878 |
+
torch.from_numpy(attention_now).unsqueeze(0).float(),
|
| 879 |
+
size=(img_h, img_w),
|
| 880 |
+
mode=interp_mode # 'nearest'
|
| 881 |
+
).squeeze(0)
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
attn_mean = attention_interp_plot.mean(0)
|
| 885 |
+
attn_mean_min = attn_mean.min()
|
| 886 |
+
attn_mean_max = attn_mean.max()
|
| 887 |
+
attn_mean = (attn_mean - attn_mean_min) / (attn_mean_max - attn_mean_min)
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
img_to_plot = cmap_attention(attn_mean)
|
| 891 |
+
axes['head_mean'].imshow(img_to_plot)
|
| 892 |
+
axes['head_mean'].axis('off')
|
| 893 |
+
|
| 894 |
+
|
| 895 |
+
these_route_steps = head_routes
|
| 896 |
+
ax_overlay = axes['overlay']
|
| 897 |
+
|
| 898 |
+
if these_route_steps: # Only plot if path exists
|
| 899 |
+
# Separate y and x coordinates
|
| 900 |
+
y_coords, x_coords = zip(*these_route_steps)
|
| 901 |
+
y_coords = np.array(y_coords)
|
| 902 |
+
x_coords = np.array(x_coords)
|
| 903 |
+
|
| 904 |
+
# Flip y-coordinates for correct plotting (imshow origin is top-left)
|
| 905 |
+
# NOTE: Original flip seemed complex, simplifying to standard flip
|
| 906 |
+
y_coords_flipped = img_h - 1 - y_coords
|
| 907 |
+
|
| 908 |
+
# Show original image flipped vertically to match coordinate system
|
| 909 |
+
ax_overlay.imshow(np.flipud(data_img_np), origin='lower')
|
| 910 |
+
|
| 911 |
+
# Draw arrows for path segments
|
| 912 |
+
arrow_scale = 2 # Arrow size scaling from original
|
| 913 |
+
for i in range(len(these_route_steps) - 1):
|
| 914 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 915 |
+
dy = y_coords_flipped[i+1] - y_coords_flipped[i] # Use flipped y for delta
|
| 916 |
+
|
| 917 |
+
# Draw white background arrow (thicker)
|
| 918 |
+
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
|
| 919 |
+
linewidth=1.6 * arrow_scale * 1.3,
|
| 920 |
+
head_width=1.9 * arrow_scale * 1.3,
|
| 921 |
+
head_length=1.4 * arrow_scale * 1.45,
|
| 922 |
+
fc='white', ec='white', length_includes_head=True, alpha=1)
|
| 923 |
+
# Draw colored foreground arrow
|
| 924 |
+
ax_overlay.arrow(x_coords[i], y_coords_flipped[i], dx, dy,
|
| 925 |
+
linewidth=1.6 * arrow_scale,
|
| 926 |
+
head_width=1.9 * arrow_scale,
|
| 927 |
+
head_length=1.4 * arrow_scale,
|
| 928 |
+
fc=route_colours_step[i], ec=route_colours_step[i], # Use step color
|
| 929 |
+
length_includes_head=True)
|
| 930 |
+
# Set limits and turn off axes for overlay
|
| 931 |
+
ax_overlay.set_xlim([0, img_w - 1])
|
| 932 |
+
ax_overlay.set_ylim([0, img_h - 1])
|
| 933 |
+
ax_overlay.axis('off')
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
for head_i in range(n_heads):
|
| 937 |
+
if f'head_{head_i}' not in axes: continue # Skip if mosaic doesn't have this head
|
| 938 |
+
|
| 939 |
+
ax = axes[f'head_{head_i}']
|
| 940 |
+
|
| 941 |
+
# Plot attention heatmap
|
| 942 |
+
attn_up_to_now = attention_tracking[:step_i + 1, head_i].mean(0)
|
| 943 |
+
attn_up_to_now = (attn_up_to_now - attn_up_to_now.min())/(attn_up_to_now.max() - attn_up_to_now.min())
|
| 944 |
+
img_to_plot = cmap_attention(attn_up_to_now)
|
| 945 |
+
ax.imshow(img_to_plot)
|
| 946 |
+
ax.axis('off')
|
| 947 |
+
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
# --- Finalize and Save Frame ---
|
| 954 |
+
fig.tight_layout(pad=0.1) # Adjust spacing
|
| 955 |
+
|
| 956 |
+
# Render the plot to a numpy array
|
| 957 |
+
canvas = fig.canvas
|
| 958 |
+
canvas.draw()
|
| 959 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 960 |
+
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
|
| 961 |
+
|
| 962 |
+
frames.append(image_numpy) # Add to list for GIF
|
| 963 |
+
|
| 964 |
+
# Save individual frame if requested
|
| 965 |
+
if step_i==model.iterations-1:
|
| 966 |
+
fig.savefig(os.path.join(index_output_dir, f'frame_{step_i}.png'), dpi=200)
|
| 967 |
+
|
| 968 |
+
plt.close(fig) # Close figure to free memory
|
| 969 |
+
outfilename = os.path.join(index_output_dir, f'{di}_demo.mp4')
|
| 970 |
+
save_frames_to_mp4([fm[:,:,::-1] for fm in frames], outfilename, fps=15, gop_size=1, preset='veryslow')
|
| 971 |
+
|
| 972 |
+
|
tasks/image_classification/imagenet_classes.py
ADDED
|
@@ -0,0 +1,1007 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
IMAGENET2012_CLASSES = OrderedDict(
|
| 5 |
+
{
|
| 6 |
+
"n01440764": "tench, Tinca tinca",
|
| 7 |
+
"n01443537": "goldfish, Carassius auratus",
|
| 8 |
+
"n01484850": "great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias",
|
| 9 |
+
"n01491361": "tiger shark, Galeocerdo cuvieri",
|
| 10 |
+
"n01494475": "hammerhead, hammerhead shark",
|
| 11 |
+
"n01496331": "electric ray, crampfish, numbfish, torpedo",
|
| 12 |
+
"n01498041": "stingray",
|
| 13 |
+
"n01514668": "cock",
|
| 14 |
+
"n01514859": "hen",
|
| 15 |
+
"n01518878": "ostrich, Struthio camelus",
|
| 16 |
+
"n01530575": "brambling, Fringilla montifringilla",
|
| 17 |
+
"n01531178": "goldfinch, Carduelis carduelis",
|
| 18 |
+
"n01532829": "house finch, linnet, Carpodacus mexicanus",
|
| 19 |
+
"n01534433": "junco, snowbird",
|
| 20 |
+
"n01537544": "indigo bunting, indigo finch, indigo bird, Passerina cyanea",
|
| 21 |
+
"n01558993": "robin, American robin, Turdus migratorius",
|
| 22 |
+
"n01560419": "bulbul",
|
| 23 |
+
"n01580077": "jay",
|
| 24 |
+
"n01582220": "magpie",
|
| 25 |
+
"n01592084": "chickadee",
|
| 26 |
+
"n01601694": "water ouzel, dipper",
|
| 27 |
+
"n01608432": "kite",
|
| 28 |
+
"n01614925": "bald eagle, American eagle, Haliaeetus leucocephalus",
|
| 29 |
+
"n01616318": "vulture",
|
| 30 |
+
"n01622779": "great grey owl, great gray owl, Strix nebulosa",
|
| 31 |
+
"n01629819": "European fire salamander, Salamandra salamandra",
|
| 32 |
+
"n01630670": "common newt, Triturus vulgaris",
|
| 33 |
+
"n01631663": "eft",
|
| 34 |
+
"n01632458": "spotted salamander, Ambystoma maculatum",
|
| 35 |
+
"n01632777": "axolotl, mud puppy, Ambystoma mexicanum",
|
| 36 |
+
"n01641577": "bullfrog, Rana catesbeiana",
|
| 37 |
+
"n01644373": "tree frog, tree-frog",
|
| 38 |
+
"n01644900": "tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui",
|
| 39 |
+
"n01664065": "loggerhead, loggerhead turtle, Caretta caretta",
|
| 40 |
+
"n01665541": "leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea",
|
| 41 |
+
"n01667114": "mud turtle",
|
| 42 |
+
"n01667778": "terrapin",
|
| 43 |
+
"n01669191": "box turtle, box tortoise",
|
| 44 |
+
"n01675722": "banded gecko",
|
| 45 |
+
"n01677366": "common iguana, iguana, Iguana iguana",
|
| 46 |
+
"n01682714": "American chameleon, anole, Anolis carolinensis",
|
| 47 |
+
"n01685808": "whiptail, whiptail lizard",
|
| 48 |
+
"n01687978": "agama",
|
| 49 |
+
"n01688243": "frilled lizard, Chlamydosaurus kingi",
|
| 50 |
+
"n01689811": "alligator lizard",
|
| 51 |
+
"n01692333": "Gila monster, Heloderma suspectum",
|
| 52 |
+
"n01693334": "green lizard, Lacerta viridis",
|
| 53 |
+
"n01694178": "African chameleon, Chamaeleo chamaeleon",
|
| 54 |
+
"n01695060": "Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis",
|
| 55 |
+
"n01697457": "African crocodile, Nile crocodile, Crocodylus niloticus",
|
| 56 |
+
"n01698640": "American alligator, Alligator mississipiensis",
|
| 57 |
+
"n01704323": "triceratops",
|
| 58 |
+
"n01728572": "thunder snake, worm snake, Carphophis amoenus",
|
| 59 |
+
"n01728920": "ringneck snake, ring-necked snake, ring snake",
|
| 60 |
+
"n01729322": "hognose snake, puff adder, sand viper",
|
| 61 |
+
"n01729977": "green snake, grass snake",
|
| 62 |
+
"n01734418": "king snake, kingsnake",
|
| 63 |
+
"n01735189": "garter snake, grass snake",
|
| 64 |
+
"n01737021": "water snake",
|
| 65 |
+
"n01739381": "vine snake",
|
| 66 |
+
"n01740131": "night snake, Hypsiglena torquata",
|
| 67 |
+
"n01742172": "boa constrictor, Constrictor constrictor",
|
| 68 |
+
"n01744401": "rock python, rock snake, Python sebae",
|
| 69 |
+
"n01748264": "Indian cobra, Naja naja",
|
| 70 |
+
"n01749939": "green mamba",
|
| 71 |
+
"n01751748": "sea snake",
|
| 72 |
+
"n01753488": "horned viper, cerastes, sand viper, horned asp, Cerastes cornutus",
|
| 73 |
+
"n01755581": "diamondback, diamondback rattlesnake, Crotalus adamanteus",
|
| 74 |
+
"n01756291": "sidewinder, horned rattlesnake, Crotalus cerastes",
|
| 75 |
+
"n01768244": "trilobite",
|
| 76 |
+
"n01770081": "harvestman, daddy longlegs, Phalangium opilio",
|
| 77 |
+
"n01770393": "scorpion",
|
| 78 |
+
"n01773157": "black and gold garden spider, Argiope aurantia",
|
| 79 |
+
"n01773549": "barn spider, Araneus cavaticus",
|
| 80 |
+
"n01773797": "garden spider, Aranea diademata",
|
| 81 |
+
"n01774384": "black widow, Latrodectus mactans",
|
| 82 |
+
"n01774750": "tarantula",
|
| 83 |
+
"n01775062": "wolf spider, hunting spider",
|
| 84 |
+
"n01776313": "tick",
|
| 85 |
+
"n01784675": "centipede",
|
| 86 |
+
"n01795545": "black grouse",
|
| 87 |
+
"n01796340": "ptarmigan",
|
| 88 |
+
"n01797886": "ruffed grouse, partridge, Bonasa umbellus",
|
| 89 |
+
"n01798484": "prairie chicken, prairie grouse, prairie fowl",
|
| 90 |
+
"n01806143": "peacock",
|
| 91 |
+
"n01806567": "quail",
|
| 92 |
+
"n01807496": "partridge",
|
| 93 |
+
"n01817953": "African grey, African gray, Psittacus erithacus",
|
| 94 |
+
"n01818515": "macaw",
|
| 95 |
+
"n01819313": "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
|
| 96 |
+
"n01820546": "lorikeet",
|
| 97 |
+
"n01824575": "coucal",
|
| 98 |
+
"n01828970": "bee eater",
|
| 99 |
+
"n01829413": "hornbill",
|
| 100 |
+
"n01833805": "hummingbird",
|
| 101 |
+
"n01843065": "jacamar",
|
| 102 |
+
"n01843383": "toucan",
|
| 103 |
+
"n01847000": "drake",
|
| 104 |
+
"n01855032": "red-breasted merganser, Mergus serrator",
|
| 105 |
+
"n01855672": "goose",
|
| 106 |
+
"n01860187": "black swan, Cygnus atratus",
|
| 107 |
+
"n01871265": "tusker",
|
| 108 |
+
"n01872401": "echidna, spiny anteater, anteater",
|
| 109 |
+
"n01873310": "platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus",
|
| 110 |
+
"n01877812": "wallaby, brush kangaroo",
|
| 111 |
+
"n01882714": "koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus",
|
| 112 |
+
"n01883070": "wombat",
|
| 113 |
+
"n01910747": "jellyfish",
|
| 114 |
+
"n01914609": "sea anemone, anemone",
|
| 115 |
+
"n01917289": "brain coral",
|
| 116 |
+
"n01924916": "flatworm, platyhelminth",
|
| 117 |
+
"n01930112": "nematode, nematode worm, roundworm",
|
| 118 |
+
"n01943899": "conch",
|
| 119 |
+
"n01944390": "snail",
|
| 120 |
+
"n01945685": "slug",
|
| 121 |
+
"n01950731": "sea slug, nudibranch",
|
| 122 |
+
"n01955084": "chiton, coat-of-mail shell, sea cradle, polyplacophore",
|
| 123 |
+
"n01968897": "chambered nautilus, pearly nautilus, nautilus",
|
| 124 |
+
"n01978287": "Dungeness crab, Cancer magister",
|
| 125 |
+
"n01978455": "rock crab, Cancer irroratus",
|
| 126 |
+
"n01980166": "fiddler crab",
|
| 127 |
+
"n01981276": "king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica",
|
| 128 |
+
"n01983481": "American lobster, Northern lobster, Maine lobster, Homarus americanus",
|
| 129 |
+
"n01984695": "spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish",
|
| 130 |
+
"n01985128": "crayfish, crawfish, crawdad, crawdaddy",
|
| 131 |
+
"n01986214": "hermit crab",
|
| 132 |
+
"n01990800": "isopod",
|
| 133 |
+
"n02002556": "white stork, Ciconia ciconia",
|
| 134 |
+
"n02002724": "black stork, Ciconia nigra",
|
| 135 |
+
"n02006656": "spoonbill",
|
| 136 |
+
"n02007558": "flamingo",
|
| 137 |
+
"n02009229": "little blue heron, Egretta caerulea",
|
| 138 |
+
"n02009912": "American egret, great white heron, Egretta albus",
|
| 139 |
+
"n02011460": "bittern",
|
| 140 |
+
"n02012849": "crane",
|
| 141 |
+
"n02013706": "limpkin, Aramus pictus",
|
| 142 |
+
"n02017213": "European gallinule, Porphyrio porphyrio",
|
| 143 |
+
"n02018207": "American coot, marsh hen, mud hen, water hen, Fulica americana",
|
| 144 |
+
"n02018795": "bustard",
|
| 145 |
+
"n02025239": "ruddy turnstone, Arenaria interpres",
|
| 146 |
+
"n02027492": "red-backed sandpiper, dunlin, Erolia alpina",
|
| 147 |
+
"n02028035": "redshank, Tringa totanus",
|
| 148 |
+
"n02033041": "dowitcher",
|
| 149 |
+
"n02037110": "oystercatcher, oyster catcher",
|
| 150 |
+
"n02051845": "pelican",
|
| 151 |
+
"n02056570": "king penguin, Aptenodytes patagonica",
|
| 152 |
+
"n02058221": "albatross, mollymawk",
|
| 153 |
+
"n02066245": "grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus",
|
| 154 |
+
"n02071294": "killer whale, killer, orca, grampus, sea wolf, Orcinus orca",
|
| 155 |
+
"n02074367": "dugong, Dugong dugon",
|
| 156 |
+
"n02077923": "sea lion",
|
| 157 |
+
"n02085620": "Chihuahua",
|
| 158 |
+
"n02085782": "Japanese spaniel",
|
| 159 |
+
"n02085936": "Maltese dog, Maltese terrier, Maltese",
|
| 160 |
+
"n02086079": "Pekinese, Pekingese, Peke",
|
| 161 |
+
"n02086240": "Shih-Tzu",
|
| 162 |
+
"n02086646": "Blenheim spaniel",
|
| 163 |
+
"n02086910": "papillon",
|
| 164 |
+
"n02087046": "toy terrier",
|
| 165 |
+
"n02087394": "Rhodesian ridgeback",
|
| 166 |
+
"n02088094": "Afghan hound, Afghan",
|
| 167 |
+
"n02088238": "basset, basset hound",
|
| 168 |
+
"n02088364": "beagle",
|
| 169 |
+
"n02088466": "bloodhound, sleuthhound",
|
| 170 |
+
"n02088632": "bluetick",
|
| 171 |
+
"n02089078": "black-and-tan coonhound",
|
| 172 |
+
"n02089867": "Walker hound, Walker foxhound",
|
| 173 |
+
"n02089973": "English foxhound",
|
| 174 |
+
"n02090379": "redbone",
|
| 175 |
+
"n02090622": "borzoi, Russian wolfhound",
|
| 176 |
+
"n02090721": "Irish wolfhound",
|
| 177 |
+
"n02091032": "Italian greyhound",
|
| 178 |
+
"n02091134": "whippet",
|
| 179 |
+
"n02091244": "Ibizan hound, Ibizan Podenco",
|
| 180 |
+
"n02091467": "Norwegian elkhound, elkhound",
|
| 181 |
+
"n02091635": "otterhound, otter hound",
|
| 182 |
+
"n02091831": "Saluki, gazelle hound",
|
| 183 |
+
"n02092002": "Scottish deerhound, deerhound",
|
| 184 |
+
"n02092339": "Weimaraner",
|
| 185 |
+
"n02093256": "Staffordshire bullterrier, Staffordshire bull terrier",
|
| 186 |
+
"n02093428": "American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier",
|
| 187 |
+
"n02093647": "Bedlington terrier",
|
| 188 |
+
"n02093754": "Border terrier",
|
| 189 |
+
"n02093859": "Kerry blue terrier",
|
| 190 |
+
"n02093991": "Irish terrier",
|
| 191 |
+
"n02094114": "Norfolk terrier",
|
| 192 |
+
"n02094258": "Norwich terrier",
|
| 193 |
+
"n02094433": "Yorkshire terrier",
|
| 194 |
+
"n02095314": "wire-haired fox terrier",
|
| 195 |
+
"n02095570": "Lakeland terrier",
|
| 196 |
+
"n02095889": "Sealyham terrier, Sealyham",
|
| 197 |
+
"n02096051": "Airedale, Airedale terrier",
|
| 198 |
+
"n02096177": "cairn, cairn terrier",
|
| 199 |
+
"n02096294": "Australian terrier",
|
| 200 |
+
"n02096437": "Dandie Dinmont, Dandie Dinmont terrier",
|
| 201 |
+
"n02096585": "Boston bull, Boston terrier",
|
| 202 |
+
"n02097047": "miniature schnauzer",
|
| 203 |
+
"n02097130": "giant schnauzer",
|
| 204 |
+
"n02097209": "standard schnauzer",
|
| 205 |
+
"n02097298": "Scotch terrier, Scottish terrier, Scottie",
|
| 206 |
+
"n02097474": "Tibetan terrier, chrysanthemum dog",
|
| 207 |
+
"n02097658": "silky terrier, Sydney silky",
|
| 208 |
+
"n02098105": "soft-coated wheaten terrier",
|
| 209 |
+
"n02098286": "West Highland white terrier",
|
| 210 |
+
"n02098413": "Lhasa, Lhasa apso",
|
| 211 |
+
"n02099267": "flat-coated retriever",
|
| 212 |
+
"n02099429": "curly-coated retriever",
|
| 213 |
+
"n02099601": "golden retriever",
|
| 214 |
+
"n02099712": "Labrador retriever",
|
| 215 |
+
"n02099849": "Chesapeake Bay retriever",
|
| 216 |
+
"n02100236": "German short-haired pointer",
|
| 217 |
+
"n02100583": "vizsla, Hungarian pointer",
|
| 218 |
+
"n02100735": "English setter",
|
| 219 |
+
"n02100877": "Irish setter, red setter",
|
| 220 |
+
"n02101006": "Gordon setter",
|
| 221 |
+
"n02101388": "Brittany spaniel",
|
| 222 |
+
"n02101556": "clumber, clumber spaniel",
|
| 223 |
+
"n02102040": "English springer, English springer spaniel",
|
| 224 |
+
"n02102177": "Welsh springer spaniel",
|
| 225 |
+
"n02102318": "cocker spaniel, English cocker spaniel, cocker",
|
| 226 |
+
"n02102480": "Sussex spaniel",
|
| 227 |
+
"n02102973": "Irish water spaniel",
|
| 228 |
+
"n02104029": "kuvasz",
|
| 229 |
+
"n02104365": "schipperke",
|
| 230 |
+
"n02105056": "groenendael",
|
| 231 |
+
"n02105162": "malinois",
|
| 232 |
+
"n02105251": "briard",
|
| 233 |
+
"n02105412": "kelpie",
|
| 234 |
+
"n02105505": "komondor",
|
| 235 |
+
"n02105641": "Old English sheepdog, bobtail",
|
| 236 |
+
"n02105855": "Shetland sheepdog, Shetland sheep dog, Shetland",
|
| 237 |
+
"n02106030": "collie",
|
| 238 |
+
"n02106166": "Border collie",
|
| 239 |
+
"n02106382": "Bouvier des Flandres, Bouviers des Flandres",
|
| 240 |
+
"n02106550": "Rottweiler",
|
| 241 |
+
"n02106662": "German shepherd, German shepherd dog, German police dog, alsatian",
|
| 242 |
+
"n02107142": "Doberman, Doberman pinscher",
|
| 243 |
+
"n02107312": "miniature pinscher",
|
| 244 |
+
"n02107574": "Greater Swiss Mountain dog",
|
| 245 |
+
"n02107683": "Bernese mountain dog",
|
| 246 |
+
"n02107908": "Appenzeller",
|
| 247 |
+
"n02108000": "EntleBucher",
|
| 248 |
+
"n02108089": "boxer",
|
| 249 |
+
"n02108422": "bull mastiff",
|
| 250 |
+
"n02108551": "Tibetan mastiff",
|
| 251 |
+
"n02108915": "French bulldog",
|
| 252 |
+
"n02109047": "Great Dane",
|
| 253 |
+
"n02109525": "Saint Bernard, St Bernard",
|
| 254 |
+
"n02109961": "Eskimo dog, husky",
|
| 255 |
+
"n02110063": "malamute, malemute, Alaskan malamute",
|
| 256 |
+
"n02110185": "Siberian husky",
|
| 257 |
+
"n02110341": "dalmatian, coach dog, carriage dog",
|
| 258 |
+
"n02110627": "affenpinscher, monkey pinscher, monkey dog",
|
| 259 |
+
"n02110806": "basenji",
|
| 260 |
+
"n02110958": "pug, pug-dog",
|
| 261 |
+
"n02111129": "Leonberg",
|
| 262 |
+
"n02111277": "Newfoundland, Newfoundland dog",
|
| 263 |
+
"n02111500": "Great Pyrenees",
|
| 264 |
+
"n02111889": "Samoyed, Samoyede",
|
| 265 |
+
"n02112018": "Pomeranian",
|
| 266 |
+
"n02112137": "chow, chow chow",
|
| 267 |
+
"n02112350": "keeshond",
|
| 268 |
+
"n02112706": "Brabancon griffon",
|
| 269 |
+
"n02113023": "Pembroke, Pembroke Welsh corgi",
|
| 270 |
+
"n02113186": "Cardigan, Cardigan Welsh corgi",
|
| 271 |
+
"n02113624": "toy poodle",
|
| 272 |
+
"n02113712": "miniature poodle",
|
| 273 |
+
"n02113799": "standard poodle",
|
| 274 |
+
"n02113978": "Mexican hairless",
|
| 275 |
+
"n02114367": "timber wolf, grey wolf, gray wolf, Canis lupus",
|
| 276 |
+
"n02114548": "white wolf, Arctic wolf, Canis lupus tundrarum",
|
| 277 |
+
"n02114712": "red wolf, maned wolf, Canis rufus, Canis niger",
|
| 278 |
+
"n02114855": "coyote, prairie wolf, brush wolf, Canis latrans",
|
| 279 |
+
"n02115641": "dingo, warrigal, warragal, Canis dingo",
|
| 280 |
+
"n02115913": "dhole, Cuon alpinus",
|
| 281 |
+
"n02116738": "African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus",
|
| 282 |
+
"n02117135": "hyena, hyaena",
|
| 283 |
+
"n02119022": "red fox, Vulpes vulpes",
|
| 284 |
+
"n02119789": "kit fox, Vulpes macrotis",
|
| 285 |
+
"n02120079": "Arctic fox, white fox, Alopex lagopus",
|
| 286 |
+
"n02120505": "grey fox, gray fox, Urocyon cinereoargenteus",
|
| 287 |
+
"n02123045": "tabby, tabby cat",
|
| 288 |
+
"n02123159": "tiger cat",
|
| 289 |
+
"n02123394": "Persian cat",
|
| 290 |
+
"n02123597": "Siamese cat, Siamese",
|
| 291 |
+
"n02124075": "Egyptian cat",
|
| 292 |
+
"n02125311": "cougar, puma, catamount, mountain lion, painter, panther, Felis concolor",
|
| 293 |
+
"n02127052": "lynx, catamount",
|
| 294 |
+
"n02128385": "leopard, Panthera pardus",
|
| 295 |
+
"n02128757": "snow leopard, ounce, Panthera uncia",
|
| 296 |
+
"n02128925": "jaguar, panther, Panthera onca, Felis onca",
|
| 297 |
+
"n02129165": "lion, king of beasts, Panthera leo",
|
| 298 |
+
"n02129604": "tiger, Panthera tigris",
|
| 299 |
+
"n02130308": "cheetah, chetah, Acinonyx jubatus",
|
| 300 |
+
"n02132136": "brown bear, bruin, Ursus arctos",
|
| 301 |
+
"n02133161": "American black bear, black bear, Ursus americanus, Euarctos americanus",
|
| 302 |
+
"n02134084": "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus",
|
| 303 |
+
"n02134418": "sloth bear, Melursus ursinus, Ursus ursinus",
|
| 304 |
+
"n02137549": "mongoose",
|
| 305 |
+
"n02138441": "meerkat, mierkat",
|
| 306 |
+
"n02165105": "tiger beetle",
|
| 307 |
+
"n02165456": "ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle",
|
| 308 |
+
"n02167151": "ground beetle, carabid beetle",
|
| 309 |
+
"n02168699": "long-horned beetle, longicorn, longicorn beetle",
|
| 310 |
+
"n02169497": "leaf beetle, chrysomelid",
|
| 311 |
+
"n02172182": "dung beetle",
|
| 312 |
+
"n02174001": "rhinoceros beetle",
|
| 313 |
+
"n02177972": "weevil",
|
| 314 |
+
"n02190166": "fly",
|
| 315 |
+
"n02206856": "bee",
|
| 316 |
+
"n02219486": "ant, emmet, pismire",
|
| 317 |
+
"n02226429": "grasshopper, hopper",
|
| 318 |
+
"n02229544": "cricket",
|
| 319 |
+
"n02231487": "walking stick, walkingstick, stick insect",
|
| 320 |
+
"n02233338": "cockroach, roach",
|
| 321 |
+
"n02236044": "mantis, mantid",
|
| 322 |
+
"n02256656": "cicada, cicala",
|
| 323 |
+
"n02259212": "leafhopper",
|
| 324 |
+
"n02264363": "lacewing, lacewing fly",
|
| 325 |
+
"n02268443": "dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk",
|
| 326 |
+
"n02268853": "damselfly",
|
| 327 |
+
"n02276258": "admiral",
|
| 328 |
+
"n02277742": "ringlet, ringlet butterfly",
|
| 329 |
+
"n02279972": "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus",
|
| 330 |
+
"n02280649": "cabbage butterfly",
|
| 331 |
+
"n02281406": "sulphur butterfly, sulfur butterfly",
|
| 332 |
+
"n02281787": "lycaenid, lycaenid butterfly",
|
| 333 |
+
"n02317335": "starfish, sea star",
|
| 334 |
+
"n02319095": "sea urchin",
|
| 335 |
+
"n02321529": "sea cucumber, holothurian",
|
| 336 |
+
"n02325366": "wood rabbit, cottontail, cottontail rabbit",
|
| 337 |
+
"n02326432": "hare",
|
| 338 |
+
"n02328150": "Angora, Angora rabbit",
|
| 339 |
+
"n02342885": "hamster",
|
| 340 |
+
"n02346627": "porcupine, hedgehog",
|
| 341 |
+
"n02356798": "fox squirrel, eastern fox squirrel, Sciurus niger",
|
| 342 |
+
"n02361337": "marmot",
|
| 343 |
+
"n02363005": "beaver",
|
| 344 |
+
"n02364673": "guinea pig, Cavia cobaya",
|
| 345 |
+
"n02389026": "sorrel",
|
| 346 |
+
"n02391049": "zebra",
|
| 347 |
+
"n02395406": "hog, pig, grunter, squealer, Sus scrofa",
|
| 348 |
+
"n02396427": "wild boar, boar, Sus scrofa",
|
| 349 |
+
"n02397096": "warthog",
|
| 350 |
+
"n02398521": "hippopotamus, hippo, river horse, Hippopotamus amphibius",
|
| 351 |
+
"n02403003": "ox",
|
| 352 |
+
"n02408429": "water buffalo, water ox, Asiatic buffalo, Bubalus bubalis",
|
| 353 |
+
"n02410509": "bison",
|
| 354 |
+
"n02412080": "ram, tup",
|
| 355 |
+
"n02415577": "bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis",
|
| 356 |
+
"n02417914": "ibex, Capra ibex",
|
| 357 |
+
"n02422106": "hartebeest",
|
| 358 |
+
"n02422699": "impala, Aepyceros melampus",
|
| 359 |
+
"n02423022": "gazelle",
|
| 360 |
+
"n02437312": "Arabian camel, dromedary, Camelus dromedarius",
|
| 361 |
+
"n02437616": "llama",
|
| 362 |
+
"n02441942": "weasel",
|
| 363 |
+
"n02442845": "mink",
|
| 364 |
+
"n02443114": "polecat, fitch, foulmart, foumart, Mustela putorius",
|
| 365 |
+
"n02443484": "black-footed ferret, ferret, Mustela nigripes",
|
| 366 |
+
"n02444819": "otter",
|
| 367 |
+
"n02445715": "skunk, polecat, wood pussy",
|
| 368 |
+
"n02447366": "badger",
|
| 369 |
+
"n02454379": "armadillo",
|
| 370 |
+
"n02457408": "three-toed sloth, ai, Bradypus tridactylus",
|
| 371 |
+
"n02480495": "orangutan, orang, orangutang, Pongo pygmaeus",
|
| 372 |
+
"n02480855": "gorilla, Gorilla gorilla",
|
| 373 |
+
"n02481823": "chimpanzee, chimp, Pan troglodytes",
|
| 374 |
+
"n02483362": "gibbon, Hylobates lar",
|
| 375 |
+
"n02483708": "siamang, Hylobates syndactylus, Symphalangus syndactylus",
|
| 376 |
+
"n02484975": "guenon, guenon monkey",
|
| 377 |
+
"n02486261": "patas, hussar monkey, Erythrocebus patas",
|
| 378 |
+
"n02486410": "baboon",
|
| 379 |
+
"n02487347": "macaque",
|
| 380 |
+
"n02488291": "langur",
|
| 381 |
+
"n02488702": "colobus, colobus monkey",
|
| 382 |
+
"n02489166": "proboscis monkey, Nasalis larvatus",
|
| 383 |
+
"n02490219": "marmoset",
|
| 384 |
+
"n02492035": "capuchin, ringtail, Cebus capucinus",
|
| 385 |
+
"n02492660": "howler monkey, howler",
|
| 386 |
+
"n02493509": "titi, titi monkey",
|
| 387 |
+
"n02493793": "spider monkey, Ateles geoffroyi",
|
| 388 |
+
"n02494079": "squirrel monkey, Saimiri sciureus",
|
| 389 |
+
"n02497673": "Madagascar cat, ring-tailed lemur, Lemur catta",
|
| 390 |
+
"n02500267": "indri, indris, Indri indri, Indri brevicaudatus",
|
| 391 |
+
"n02504013": "Indian elephant, Elephas maximus",
|
| 392 |
+
"n02504458": "African elephant, Loxodonta africana",
|
| 393 |
+
"n02509815": "lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens",
|
| 394 |
+
"n02510455": "giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca",
|
| 395 |
+
"n02514041": "barracouta, snoek",
|
| 396 |
+
"n02526121": "eel",
|
| 397 |
+
"n02536864": "coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch",
|
| 398 |
+
"n02606052": "rock beauty, Holocanthus tricolor",
|
| 399 |
+
"n02607072": "anemone fish",
|
| 400 |
+
"n02640242": "sturgeon",
|
| 401 |
+
"n02641379": "gar, garfish, garpike, billfish, Lepisosteus osseus",
|
| 402 |
+
"n02643566": "lionfish",
|
| 403 |
+
"n02655020": "puffer, pufferfish, blowfish, globefish",
|
| 404 |
+
"n02666196": "abacus",
|
| 405 |
+
"n02667093": "abaya",
|
| 406 |
+
"n02669723": "academic gown, academic robe, judge's robe",
|
| 407 |
+
"n02672831": "accordion, piano accordion, squeeze box",
|
| 408 |
+
"n02676566": "acoustic guitar",
|
| 409 |
+
"n02687172": "aircraft carrier, carrier, flattop, attack aircraft carrier",
|
| 410 |
+
"n02690373": "airliner",
|
| 411 |
+
"n02692877": "airship, dirigible",
|
| 412 |
+
"n02699494": "altar",
|
| 413 |
+
"n02701002": "ambulance",
|
| 414 |
+
"n02704792": "amphibian, amphibious vehicle",
|
| 415 |
+
"n02708093": "analog clock",
|
| 416 |
+
"n02727426": "apiary, bee house",
|
| 417 |
+
"n02730930": "apron",
|
| 418 |
+
"n02747177": "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin",
|
| 419 |
+
"n02749479": "assault rifle, assault gun",
|
| 420 |
+
"n02769748": "backpack, back pack, knapsack, packsack, rucksack, haversack",
|
| 421 |
+
"n02776631": "bakery, bakeshop, bakehouse",
|
| 422 |
+
"n02777292": "balance beam, beam",
|
| 423 |
+
"n02782093": "balloon",
|
| 424 |
+
"n02783161": "ballpoint, ballpoint pen, ballpen, Biro",
|
| 425 |
+
"n02786058": "Band Aid",
|
| 426 |
+
"n02787622": "banjo",
|
| 427 |
+
"n02788148": "bannister, banister, balustrade, balusters, handrail",
|
| 428 |
+
"n02790996": "barbell",
|
| 429 |
+
"n02791124": "barber chair",
|
| 430 |
+
"n02791270": "barbershop",
|
| 431 |
+
"n02793495": "barn",
|
| 432 |
+
"n02794156": "barometer",
|
| 433 |
+
"n02795169": "barrel, cask",
|
| 434 |
+
"n02797295": "barrow, garden cart, lawn cart, wheelbarrow",
|
| 435 |
+
"n02799071": "baseball",
|
| 436 |
+
"n02802426": "basketball",
|
| 437 |
+
"n02804414": "bassinet",
|
| 438 |
+
"n02804610": "bassoon",
|
| 439 |
+
"n02807133": "bathing cap, swimming cap",
|
| 440 |
+
"n02808304": "bath towel",
|
| 441 |
+
"n02808440": "bathtub, bathing tub, bath, tub",
|
| 442 |
+
"n02814533": "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
|
| 443 |
+
"n02814860": "beacon, lighthouse, beacon light, pharos",
|
| 444 |
+
"n02815834": "beaker",
|
| 445 |
+
"n02817516": "bearskin, busby, shako",
|
| 446 |
+
"n02823428": "beer bottle",
|
| 447 |
+
"n02823750": "beer glass",
|
| 448 |
+
"n02825657": "bell cote, bell cot",
|
| 449 |
+
"n02834397": "bib",
|
| 450 |
+
"n02835271": "bicycle-built-for-two, tandem bicycle, tandem",
|
| 451 |
+
"n02837789": "bikini, two-piece",
|
| 452 |
+
"n02840245": "binder, ring-binder",
|
| 453 |
+
"n02841315": "binoculars, field glasses, opera glasses",
|
| 454 |
+
"n02843684": "birdhouse",
|
| 455 |
+
"n02859443": "boathouse",
|
| 456 |
+
"n02860847": "bobsled, bobsleigh, bob",
|
| 457 |
+
"n02865351": "bolo tie, bolo, bola tie, bola",
|
| 458 |
+
"n02869837": "bonnet, poke bonnet",
|
| 459 |
+
"n02870880": "bookcase",
|
| 460 |
+
"n02871525": "bookshop, bookstore, bookstall",
|
| 461 |
+
"n02877765": "bottlecap",
|
| 462 |
+
"n02879718": "bow",
|
| 463 |
+
"n02883205": "bow tie, bow-tie, bowtie",
|
| 464 |
+
"n02892201": "brass, memorial tablet, plaque",
|
| 465 |
+
"n02892767": "brassiere, bra, bandeau",
|
| 466 |
+
"n02894605": "breakwater, groin, groyne, mole, bulwark, seawall, jetty",
|
| 467 |
+
"n02895154": "breastplate, aegis, egis",
|
| 468 |
+
"n02906734": "broom",
|
| 469 |
+
"n02909870": "bucket, pail",
|
| 470 |
+
"n02910353": "buckle",
|
| 471 |
+
"n02916936": "bulletproof vest",
|
| 472 |
+
"n02917067": "bullet train, bullet",
|
| 473 |
+
"n02927161": "butcher shop, meat market",
|
| 474 |
+
"n02930766": "cab, hack, taxi, taxicab",
|
| 475 |
+
"n02939185": "caldron, cauldron",
|
| 476 |
+
"n02948072": "candle, taper, wax light",
|
| 477 |
+
"n02950826": "cannon",
|
| 478 |
+
"n02951358": "canoe",
|
| 479 |
+
"n02951585": "can opener, tin opener",
|
| 480 |
+
"n02963159": "cardigan",
|
| 481 |
+
"n02965783": "car mirror",
|
| 482 |
+
"n02966193": "carousel, carrousel, merry-go-round, roundabout, whirligig",
|
| 483 |
+
"n02966687": "carpenter's kit, tool kit",
|
| 484 |
+
"n02971356": "carton",
|
| 485 |
+
"n02974003": "car wheel",
|
| 486 |
+
"n02977058": "cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM",
|
| 487 |
+
"n02978881": "cassette",
|
| 488 |
+
"n02979186": "cassette player",
|
| 489 |
+
"n02980441": "castle",
|
| 490 |
+
"n02981792": "catamaran",
|
| 491 |
+
"n02988304": "CD player",
|
| 492 |
+
"n02992211": "cello, violoncello",
|
| 493 |
+
"n02992529": "cellular telephone, cellular phone, cellphone, cell, mobile phone",
|
| 494 |
+
"n02999410": "chain",
|
| 495 |
+
"n03000134": "chainlink fence",
|
| 496 |
+
"n03000247": "chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour",
|
| 497 |
+
"n03000684": "chain saw, chainsaw",
|
| 498 |
+
"n03014705": "chest",
|
| 499 |
+
"n03016953": "chiffonier, commode",
|
| 500 |
+
"n03017168": "chime, bell, gong",
|
| 501 |
+
"n03018349": "china cabinet, china closet",
|
| 502 |
+
"n03026506": "Christmas stocking",
|
| 503 |
+
"n03028079": "church, church building",
|
| 504 |
+
"n03032252": "cinema, movie theater, movie theatre, movie house, picture palace",
|
| 505 |
+
"n03041632": "cleaver, meat cleaver, chopper",
|
| 506 |
+
"n03042490": "cliff dwelling",
|
| 507 |
+
"n03045698": "cloak",
|
| 508 |
+
"n03047690": "clog, geta, patten, sabot",
|
| 509 |
+
"n03062245": "cocktail shaker",
|
| 510 |
+
"n03063599": "coffee mug",
|
| 511 |
+
"n03063689": "coffeepot",
|
| 512 |
+
"n03065424": "coil, spiral, volute, whorl, helix",
|
| 513 |
+
"n03075370": "combination lock",
|
| 514 |
+
"n03085013": "computer keyboard, keypad",
|
| 515 |
+
"n03089624": "confectionery, confectionary, candy store",
|
| 516 |
+
"n03095699": "container ship, containership, container vessel",
|
| 517 |
+
"n03100240": "convertible",
|
| 518 |
+
"n03109150": "corkscrew, bottle screw",
|
| 519 |
+
"n03110669": "cornet, horn, trumpet, trump",
|
| 520 |
+
"n03124043": "cowboy boot",
|
| 521 |
+
"n03124170": "cowboy hat, ten-gallon hat",
|
| 522 |
+
"n03125729": "cradle",
|
| 523 |
+
"n03126707": "crane2",
|
| 524 |
+
"n03127747": "crash helmet",
|
| 525 |
+
"n03127925": "crate",
|
| 526 |
+
"n03131574": "crib, cot",
|
| 527 |
+
"n03133878": "Crock Pot",
|
| 528 |
+
"n03134739": "croquet ball",
|
| 529 |
+
"n03141823": "crutch",
|
| 530 |
+
"n03146219": "cuirass",
|
| 531 |
+
"n03160309": "dam, dike, dyke",
|
| 532 |
+
"n03179701": "desk",
|
| 533 |
+
"n03180011": "desktop computer",
|
| 534 |
+
"n03187595": "dial telephone, dial phone",
|
| 535 |
+
"n03188531": "diaper, nappy, napkin",
|
| 536 |
+
"n03196217": "digital clock",
|
| 537 |
+
"n03197337": "digital watch",
|
| 538 |
+
"n03201208": "dining table, board",
|
| 539 |
+
"n03207743": "dishrag, dishcloth",
|
| 540 |
+
"n03207941": "dishwasher, dish washer, dishwashing machine",
|
| 541 |
+
"n03208938": "disk brake, disc brake",
|
| 542 |
+
"n03216828": "dock, dockage, docking facility",
|
| 543 |
+
"n03218198": "dogsled, dog sled, dog sleigh",
|
| 544 |
+
"n03220513": "dome",
|
| 545 |
+
"n03223299": "doormat, welcome mat",
|
| 546 |
+
"n03240683": "drilling platform, offshore rig",
|
| 547 |
+
"n03249569": "drum, membranophone, tympan",
|
| 548 |
+
"n03250847": "drumstick",
|
| 549 |
+
"n03255030": "dumbbell",
|
| 550 |
+
"n03259280": "Dutch oven",
|
| 551 |
+
"n03271574": "electric fan, blower",
|
| 552 |
+
"n03272010": "electric guitar",
|
| 553 |
+
"n03272562": "electric locomotive",
|
| 554 |
+
"n03290653": "entertainment center",
|
| 555 |
+
"n03291819": "envelope",
|
| 556 |
+
"n03297495": "espresso maker",
|
| 557 |
+
"n03314780": "face powder",
|
| 558 |
+
"n03325584": "feather boa, boa",
|
| 559 |
+
"n03337140": "file, file cabinet, filing cabinet",
|
| 560 |
+
"n03344393": "fireboat",
|
| 561 |
+
"n03345487": "fire engine, fire truck",
|
| 562 |
+
"n03347037": "fire screen, fireguard",
|
| 563 |
+
"n03355925": "flagpole, flagstaff",
|
| 564 |
+
"n03372029": "flute, transverse flute",
|
| 565 |
+
"n03376595": "folding chair",
|
| 566 |
+
"n03379051": "football helmet",
|
| 567 |
+
"n03384352": "forklift",
|
| 568 |
+
"n03388043": "fountain",
|
| 569 |
+
"n03388183": "fountain pen",
|
| 570 |
+
"n03388549": "four-poster",
|
| 571 |
+
"n03393912": "freight car",
|
| 572 |
+
"n03394916": "French horn, horn",
|
| 573 |
+
"n03400231": "frying pan, frypan, skillet",
|
| 574 |
+
"n03404251": "fur coat",
|
| 575 |
+
"n03417042": "garbage truck, dustcart",
|
| 576 |
+
"n03424325": "gasmask, respirator, gas helmet",
|
| 577 |
+
"n03425413": "gas pump, gasoline pump, petrol pump, island dispenser",
|
| 578 |
+
"n03443371": "goblet",
|
| 579 |
+
"n03444034": "go-kart",
|
| 580 |
+
"n03445777": "golf ball",
|
| 581 |
+
"n03445924": "golfcart, golf cart",
|
| 582 |
+
"n03447447": "gondola",
|
| 583 |
+
"n03447721": "gong, tam-tam",
|
| 584 |
+
"n03450230": "gown",
|
| 585 |
+
"n03452741": "grand piano, grand",
|
| 586 |
+
"n03457902": "greenhouse, nursery, glasshouse",
|
| 587 |
+
"n03459775": "grille, radiator grille",
|
| 588 |
+
"n03461385": "grocery store, grocery, food market, market",
|
| 589 |
+
"n03467068": "guillotine",
|
| 590 |
+
"n03476684": "hair slide",
|
| 591 |
+
"n03476991": "hair spray",
|
| 592 |
+
"n03478589": "half track",
|
| 593 |
+
"n03481172": "hammer",
|
| 594 |
+
"n03482405": "hamper",
|
| 595 |
+
"n03483316": "hand blower, blow dryer, blow drier, hair dryer, hair drier",
|
| 596 |
+
"n03485407": "hand-held computer, hand-held microcomputer",
|
| 597 |
+
"n03485794": "handkerchief, hankie, hanky, hankey",
|
| 598 |
+
"n03492542": "hard disc, hard disk, fixed disk",
|
| 599 |
+
"n03494278": "harmonica, mouth organ, harp, mouth harp",
|
| 600 |
+
"n03495258": "harp",
|
| 601 |
+
"n03496892": "harvester, reaper",
|
| 602 |
+
"n03498962": "hatchet",
|
| 603 |
+
"n03527444": "holster",
|
| 604 |
+
"n03529860": "home theater, home theatre",
|
| 605 |
+
"n03530642": "honeycomb",
|
| 606 |
+
"n03532672": "hook, claw",
|
| 607 |
+
"n03534580": "hoopskirt, crinoline",
|
| 608 |
+
"n03535780": "horizontal bar, high bar",
|
| 609 |
+
"n03538406": "horse cart, horse-cart",
|
| 610 |
+
"n03544143": "hourglass",
|
| 611 |
+
"n03584254": "iPod",
|
| 612 |
+
"n03584829": "iron, smoothing iron",
|
| 613 |
+
"n03590841": "jack-o'-lantern",
|
| 614 |
+
"n03594734": "jean, blue jean, denim",
|
| 615 |
+
"n03594945": "jeep, landrover",
|
| 616 |
+
"n03595614": "jersey, T-shirt, tee shirt",
|
| 617 |
+
"n03598930": "jigsaw puzzle",
|
| 618 |
+
"n03599486": "jinrikisha, ricksha, rickshaw",
|
| 619 |
+
"n03602883": "joystick",
|
| 620 |
+
"n03617480": "kimono",
|
| 621 |
+
"n03623198": "knee pad",
|
| 622 |
+
"n03627232": "knot",
|
| 623 |
+
"n03630383": "lab coat, laboratory coat",
|
| 624 |
+
"n03633091": "ladle",
|
| 625 |
+
"n03637318": "lampshade, lamp shade",
|
| 626 |
+
"n03642806": "laptop, laptop computer",
|
| 627 |
+
"n03649909": "lawn mower, mower",
|
| 628 |
+
"n03657121": "lens cap, lens cover",
|
| 629 |
+
"n03658185": "letter opener, paper knife, paperknife",
|
| 630 |
+
"n03661043": "library",
|
| 631 |
+
"n03662601": "lifeboat",
|
| 632 |
+
"n03666591": "lighter, light, igniter, ignitor",
|
| 633 |
+
"n03670208": "limousine, limo",
|
| 634 |
+
"n03673027": "liner, ocean liner",
|
| 635 |
+
"n03676483": "lipstick, lip rouge",
|
| 636 |
+
"n03680355": "Loafer",
|
| 637 |
+
"n03690938": "lotion",
|
| 638 |
+
"n03691459": "loudspeaker, speaker, speaker unit, loudspeaker system, speaker system",
|
| 639 |
+
"n03692522": "loupe, jeweler's loupe",
|
| 640 |
+
"n03697007": "lumbermill, sawmill",
|
| 641 |
+
"n03706229": "magnetic compass",
|
| 642 |
+
"n03709823": "mailbag, postbag",
|
| 643 |
+
"n03710193": "mailbox, letter box",
|
| 644 |
+
"n03710637": "maillot",
|
| 645 |
+
"n03710721": "maillot, tank suit",
|
| 646 |
+
"n03717622": "manhole cover",
|
| 647 |
+
"n03720891": "maraca",
|
| 648 |
+
"n03721384": "marimba, xylophone",
|
| 649 |
+
"n03724870": "mask",
|
| 650 |
+
"n03729826": "matchstick",
|
| 651 |
+
"n03733131": "maypole",
|
| 652 |
+
"n03733281": "maze, labyrinth",
|
| 653 |
+
"n03733805": "measuring cup",
|
| 654 |
+
"n03742115": "medicine chest, medicine cabinet",
|
| 655 |
+
"n03743016": "megalith, megalithic structure",
|
| 656 |
+
"n03759954": "microphone, mike",
|
| 657 |
+
"n03761084": "microwave, microwave oven",
|
| 658 |
+
"n03763968": "military uniform",
|
| 659 |
+
"n03764736": "milk can",
|
| 660 |
+
"n03769881": "minibus",
|
| 661 |
+
"n03770439": "miniskirt, mini",
|
| 662 |
+
"n03770679": "minivan",
|
| 663 |
+
"n03773504": "missile",
|
| 664 |
+
"n03775071": "mitten",
|
| 665 |
+
"n03775546": "mixing bowl",
|
| 666 |
+
"n03776460": "mobile home, manufactured home",
|
| 667 |
+
"n03777568": "Model T",
|
| 668 |
+
"n03777754": "modem",
|
| 669 |
+
"n03781244": "monastery",
|
| 670 |
+
"n03782006": "monitor",
|
| 671 |
+
"n03785016": "moped",
|
| 672 |
+
"n03786901": "mortar",
|
| 673 |
+
"n03787032": "mortarboard",
|
| 674 |
+
"n03788195": "mosque",
|
| 675 |
+
"n03788365": "mosquito net",
|
| 676 |
+
"n03791053": "motor scooter, scooter",
|
| 677 |
+
"n03792782": "mountain bike, all-terrain bike, off-roader",
|
| 678 |
+
"n03792972": "mountain tent",
|
| 679 |
+
"n03793489": "mouse, computer mouse",
|
| 680 |
+
"n03794056": "mousetrap",
|
| 681 |
+
"n03796401": "moving van",
|
| 682 |
+
"n03803284": "muzzle",
|
| 683 |
+
"n03804744": "nail",
|
| 684 |
+
"n03814639": "neck brace",
|
| 685 |
+
"n03814906": "necklace",
|
| 686 |
+
"n03825788": "nipple",
|
| 687 |
+
"n03832673": "notebook, notebook computer",
|
| 688 |
+
"n03837869": "obelisk",
|
| 689 |
+
"n03838899": "oboe, hautboy, hautbois",
|
| 690 |
+
"n03840681": "ocarina, sweet potato",
|
| 691 |
+
"n03841143": "odometer, hodometer, mileometer, milometer",
|
| 692 |
+
"n03843555": "oil filter",
|
| 693 |
+
"n03854065": "organ, pipe organ",
|
| 694 |
+
"n03857828": "oscilloscope, scope, cathode-ray oscilloscope, CRO",
|
| 695 |
+
"n03866082": "overskirt",
|
| 696 |
+
"n03868242": "oxcart",
|
| 697 |
+
"n03868863": "oxygen mask",
|
| 698 |
+
"n03871628": "packet",
|
| 699 |
+
"n03873416": "paddle, boat paddle",
|
| 700 |
+
"n03874293": "paddlewheel, paddle wheel",
|
| 701 |
+
"n03874599": "padlock",
|
| 702 |
+
"n03876231": "paintbrush",
|
| 703 |
+
"n03877472": "pajama, pyjama, pj's, jammies",
|
| 704 |
+
"n03877845": "palace",
|
| 705 |
+
"n03884397": "panpipe, pandean pipe, syrinx",
|
| 706 |
+
"n03887697": "paper towel",
|
| 707 |
+
"n03888257": "parachute, chute",
|
| 708 |
+
"n03888605": "parallel bars, bars",
|
| 709 |
+
"n03891251": "park bench",
|
| 710 |
+
"n03891332": "parking meter",
|
| 711 |
+
"n03895866": "passenger car, coach, carriage",
|
| 712 |
+
"n03899768": "patio, terrace",
|
| 713 |
+
"n03902125": "pay-phone, pay-station",
|
| 714 |
+
"n03903868": "pedestal, plinth, footstall",
|
| 715 |
+
"n03908618": "pencil box, pencil case",
|
| 716 |
+
"n03908714": "pencil sharpener",
|
| 717 |
+
"n03916031": "perfume, essence",
|
| 718 |
+
"n03920288": "Petri dish",
|
| 719 |
+
"n03924679": "photocopier",
|
| 720 |
+
"n03929660": "pick, plectrum, plectron",
|
| 721 |
+
"n03929855": "pickelhaube",
|
| 722 |
+
"n03930313": "picket fence, paling",
|
| 723 |
+
"n03930630": "pickup, pickup truck",
|
| 724 |
+
"n03933933": "pier",
|
| 725 |
+
"n03935335": "piggy bank, penny bank",
|
| 726 |
+
"n03937543": "pill bottle",
|
| 727 |
+
"n03938244": "pillow",
|
| 728 |
+
"n03942813": "ping-pong ball",
|
| 729 |
+
"n03944341": "pinwheel",
|
| 730 |
+
"n03947888": "pirate, pirate ship",
|
| 731 |
+
"n03950228": "pitcher, ewer",
|
| 732 |
+
"n03954731": "plane, carpenter's plane, woodworking plane",
|
| 733 |
+
"n03956157": "planetarium",
|
| 734 |
+
"n03958227": "plastic bag",
|
| 735 |
+
"n03961711": "plate rack",
|
| 736 |
+
"n03967562": "plow, plough",
|
| 737 |
+
"n03970156": "plunger, plumber's helper",
|
| 738 |
+
"n03976467": "Polaroid camera, Polaroid Land camera",
|
| 739 |
+
"n03976657": "pole",
|
| 740 |
+
"n03977966": "police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria",
|
| 741 |
+
"n03980874": "poncho",
|
| 742 |
+
"n03982430": "pool table, billiard table, snooker table",
|
| 743 |
+
"n03983396": "pop bottle, soda bottle",
|
| 744 |
+
"n03991062": "pot, flowerpot",
|
| 745 |
+
"n03992509": "potter's wheel",
|
| 746 |
+
"n03995372": "power drill",
|
| 747 |
+
"n03998194": "prayer rug, prayer mat",
|
| 748 |
+
"n04004767": "printer",
|
| 749 |
+
"n04005630": "prison, prison house",
|
| 750 |
+
"n04008634": "projectile, missile",
|
| 751 |
+
"n04009552": "projector",
|
| 752 |
+
"n04019541": "puck, hockey puck",
|
| 753 |
+
"n04023962": "punching bag, punch bag, punching ball, punchball",
|
| 754 |
+
"n04026417": "purse",
|
| 755 |
+
"n04033901": "quill, quill pen",
|
| 756 |
+
"n04033995": "quilt, comforter, comfort, puff",
|
| 757 |
+
"n04037443": "racer, race car, racing car",
|
| 758 |
+
"n04039381": "racket, racquet",
|
| 759 |
+
"n04040759": "radiator",
|
| 760 |
+
"n04041544": "radio, wireless",
|
| 761 |
+
"n04044716": "radio telescope, radio reflector",
|
| 762 |
+
"n04049303": "rain barrel",
|
| 763 |
+
"n04065272": "recreational vehicle, RV, R.V.",
|
| 764 |
+
"n04067472": "reel",
|
| 765 |
+
"n04069434": "reflex camera",
|
| 766 |
+
"n04070727": "refrigerator, icebox",
|
| 767 |
+
"n04074963": "remote control, remote",
|
| 768 |
+
"n04081281": "restaurant, eating house, eating place, eatery",
|
| 769 |
+
"n04086273": "revolver, six-gun, six-shooter",
|
| 770 |
+
"n04090263": "rifle",
|
| 771 |
+
"n04099969": "rocking chair, rocker",
|
| 772 |
+
"n04111531": "rotisserie",
|
| 773 |
+
"n04116512": "rubber eraser, rubber, pencil eraser",
|
| 774 |
+
"n04118538": "rugby ball",
|
| 775 |
+
"n04118776": "rule, ruler",
|
| 776 |
+
"n04120489": "running shoe",
|
| 777 |
+
"n04125021": "safe",
|
| 778 |
+
"n04127249": "safety pin",
|
| 779 |
+
"n04131690": "saltshaker, salt shaker",
|
| 780 |
+
"n04133789": "sandal",
|
| 781 |
+
"n04136333": "sarong",
|
| 782 |
+
"n04141076": "sax, saxophone",
|
| 783 |
+
"n04141327": "scabbard",
|
| 784 |
+
"n04141975": "scale, weighing machine",
|
| 785 |
+
"n04146614": "school bus",
|
| 786 |
+
"n04147183": "schooner",
|
| 787 |
+
"n04149813": "scoreboard",
|
| 788 |
+
"n04152593": "screen, CRT screen",
|
| 789 |
+
"n04153751": "screw",
|
| 790 |
+
"n04154565": "screwdriver",
|
| 791 |
+
"n04162706": "seat belt, seatbelt",
|
| 792 |
+
"n04179913": "sewing machine",
|
| 793 |
+
"n04192698": "shield, buckler",
|
| 794 |
+
"n04200800": "shoe shop, shoe-shop, shoe store",
|
| 795 |
+
"n04201297": "shoji",
|
| 796 |
+
"n04204238": "shopping basket",
|
| 797 |
+
"n04204347": "shopping cart",
|
| 798 |
+
"n04208210": "shovel",
|
| 799 |
+
"n04209133": "shower cap",
|
| 800 |
+
"n04209239": "shower curtain",
|
| 801 |
+
"n04228054": "ski",
|
| 802 |
+
"n04229816": "ski mask",
|
| 803 |
+
"n04235860": "sleeping bag",
|
| 804 |
+
"n04238763": "slide rule, slipstick",
|
| 805 |
+
"n04239074": "sliding door",
|
| 806 |
+
"n04243546": "slot, one-armed bandit",
|
| 807 |
+
"n04251144": "snorkel",
|
| 808 |
+
"n04252077": "snowmobile",
|
| 809 |
+
"n04252225": "snowplow, snowplough",
|
| 810 |
+
"n04254120": "soap dispenser",
|
| 811 |
+
"n04254680": "soccer ball",
|
| 812 |
+
"n04254777": "sock",
|
| 813 |
+
"n04258138": "solar dish, solar collector, solar furnace",
|
| 814 |
+
"n04259630": "sombrero",
|
| 815 |
+
"n04263257": "soup bowl",
|
| 816 |
+
"n04264628": "space bar",
|
| 817 |
+
"n04265275": "space heater",
|
| 818 |
+
"n04266014": "space shuttle",
|
| 819 |
+
"n04270147": "spatula",
|
| 820 |
+
"n04273569": "speedboat",
|
| 821 |
+
"n04275548": "spider web, spider's web",
|
| 822 |
+
"n04277352": "spindle",
|
| 823 |
+
"n04285008": "sports car, sport car",
|
| 824 |
+
"n04286575": "spotlight, spot",
|
| 825 |
+
"n04296562": "stage",
|
| 826 |
+
"n04310018": "steam locomotive",
|
| 827 |
+
"n04311004": "steel arch bridge",
|
| 828 |
+
"n04311174": "steel drum",
|
| 829 |
+
"n04317175": "stethoscope",
|
| 830 |
+
"n04325704": "stole",
|
| 831 |
+
"n04326547": "stone wall",
|
| 832 |
+
"n04328186": "stopwatch, stop watch",
|
| 833 |
+
"n04330267": "stove",
|
| 834 |
+
"n04332243": "strainer",
|
| 835 |
+
"n04335435": "streetcar, tram, tramcar, trolley, trolley car",
|
| 836 |
+
"n04336792": "stretcher",
|
| 837 |
+
"n04344873": "studio couch, day bed",
|
| 838 |
+
"n04346328": "stupa, tope",
|
| 839 |
+
"n04347754": "submarine, pigboat, sub, U-boat",
|
| 840 |
+
"n04350905": "suit, suit of clothes",
|
| 841 |
+
"n04355338": "sundial",
|
| 842 |
+
"n04355933": "sunglass",
|
| 843 |
+
"n04356056": "sunglasses, dark glasses, shades",
|
| 844 |
+
"n04357314": "sunscreen, sunblock, sun blocker",
|
| 845 |
+
"n04366367": "suspension bridge",
|
| 846 |
+
"n04367480": "swab, swob, mop",
|
| 847 |
+
"n04370456": "sweatshirt",
|
| 848 |
+
"n04371430": "swimming trunks, bathing trunks",
|
| 849 |
+
"n04371774": "swing",
|
| 850 |
+
"n04372370": "switch, electric switch, electrical switch",
|
| 851 |
+
"n04376876": "syringe",
|
| 852 |
+
"n04380533": "table lamp",
|
| 853 |
+
"n04389033": "tank, army tank, armored combat vehicle, armoured combat vehicle",
|
| 854 |
+
"n04392985": "tape player",
|
| 855 |
+
"n04398044": "teapot",
|
| 856 |
+
"n04399382": "teddy, teddy bear",
|
| 857 |
+
"n04404412": "television, television system",
|
| 858 |
+
"n04409515": "tennis ball",
|
| 859 |
+
"n04417672": "thatch, thatched roof",
|
| 860 |
+
"n04418357": "theater curtain, theatre curtain",
|
| 861 |
+
"n04423845": "thimble",
|
| 862 |
+
"n04428191": "thresher, thrasher, threshing machine",
|
| 863 |
+
"n04429376": "throne",
|
| 864 |
+
"n04435653": "tile roof",
|
| 865 |
+
"n04442312": "toaster",
|
| 866 |
+
"n04443257": "tobacco shop, tobacconist shop, tobacconist",
|
| 867 |
+
"n04447861": "toilet seat",
|
| 868 |
+
"n04456115": "torch",
|
| 869 |
+
"n04458633": "totem pole",
|
| 870 |
+
"n04461696": "tow truck, tow car, wrecker",
|
| 871 |
+
"n04462240": "toyshop",
|
| 872 |
+
"n04465501": "tractor",
|
| 873 |
+
"n04467665": "trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi",
|
| 874 |
+
"n04476259": "tray",
|
| 875 |
+
"n04479046": "trench coat",
|
| 876 |
+
"n04482393": "tricycle, trike, velocipede",
|
| 877 |
+
"n04483307": "trimaran",
|
| 878 |
+
"n04485082": "tripod",
|
| 879 |
+
"n04486054": "triumphal arch",
|
| 880 |
+
"n04487081": "trolleybus, trolley coach, trackless trolley",
|
| 881 |
+
"n04487394": "trombone",
|
| 882 |
+
"n04493381": "tub, vat",
|
| 883 |
+
"n04501370": "turnstile",
|
| 884 |
+
"n04505470": "typewriter keyboard",
|
| 885 |
+
"n04507155": "umbrella",
|
| 886 |
+
"n04509417": "unicycle, monocycle",
|
| 887 |
+
"n04515003": "upright, upright piano",
|
| 888 |
+
"n04517823": "vacuum, vacuum cleaner",
|
| 889 |
+
"n04522168": "vase",
|
| 890 |
+
"n04523525": "vault",
|
| 891 |
+
"n04525038": "velvet",
|
| 892 |
+
"n04525305": "vending machine",
|
| 893 |
+
"n04532106": "vestment",
|
| 894 |
+
"n04532670": "viaduct",
|
| 895 |
+
"n04536866": "violin, fiddle",
|
| 896 |
+
"n04540053": "volleyball",
|
| 897 |
+
"n04542943": "waffle iron",
|
| 898 |
+
"n04548280": "wall clock",
|
| 899 |
+
"n04548362": "wallet, billfold, notecase, pocketbook",
|
| 900 |
+
"n04550184": "wardrobe, closet, press",
|
| 901 |
+
"n04552348": "warplane, military plane",
|
| 902 |
+
"n04553703": "washbasin, handbasin, washbowl, lavabo, wash-hand basin",
|
| 903 |
+
"n04554684": "washer, automatic washer, washing machine",
|
| 904 |
+
"n04557648": "water bottle",
|
| 905 |
+
"n04560804": "water jug",
|
| 906 |
+
"n04562935": "water tower",
|
| 907 |
+
"n04579145": "whiskey jug",
|
| 908 |
+
"n04579432": "whistle",
|
| 909 |
+
"n04584207": "wig",
|
| 910 |
+
"n04589890": "window screen",
|
| 911 |
+
"n04590129": "window shade",
|
| 912 |
+
"n04591157": "Windsor tie",
|
| 913 |
+
"n04591713": "wine bottle",
|
| 914 |
+
"n04592741": "wing",
|
| 915 |
+
"n04596742": "wok",
|
| 916 |
+
"n04597913": "wooden spoon",
|
| 917 |
+
"n04599235": "wool, woolen, woollen",
|
| 918 |
+
"n04604644": "worm fence, snake fence, snake-rail fence, Virginia fence",
|
| 919 |
+
"n04606251": "wreck",
|
| 920 |
+
"n04612504": "yawl",
|
| 921 |
+
"n04613696": "yurt",
|
| 922 |
+
"n06359193": "web site, website, internet site, site",
|
| 923 |
+
"n06596364": "comic book",
|
| 924 |
+
"n06785654": "crossword puzzle, crossword",
|
| 925 |
+
"n06794110": "street sign",
|
| 926 |
+
"n06874185": "traffic light, traffic signal, stoplight",
|
| 927 |
+
"n07248320": "book jacket, dust cover, dust jacket, dust wrapper",
|
| 928 |
+
"n07565083": "menu",
|
| 929 |
+
"n07579787": "plate",
|
| 930 |
+
"n07583066": "guacamole",
|
| 931 |
+
"n07584110": "consomme",
|
| 932 |
+
"n07590611": "hot pot, hotpot",
|
| 933 |
+
"n07613480": "trifle",
|
| 934 |
+
"n07614500": "ice cream, icecream",
|
| 935 |
+
"n07615774": "ice lolly, lolly, lollipop, popsicle",
|
| 936 |
+
"n07684084": "French loaf",
|
| 937 |
+
"n07693725": "bagel, beigel",
|
| 938 |
+
"n07695742": "pretzel",
|
| 939 |
+
"n07697313": "cheeseburger",
|
| 940 |
+
"n07697537": "hotdog, hot dog, red hot",
|
| 941 |
+
"n07711569": "mashed potato",
|
| 942 |
+
"n07714571": "head cabbage",
|
| 943 |
+
"n07714990": "broccoli",
|
| 944 |
+
"n07715103": "cauliflower",
|
| 945 |
+
"n07716358": "zucchini, courgette",
|
| 946 |
+
"n07716906": "spaghetti squash",
|
| 947 |
+
"n07717410": "acorn squash",
|
| 948 |
+
"n07717556": "butternut squash",
|
| 949 |
+
"n07718472": "cucumber, cuke",
|
| 950 |
+
"n07718747": "artichoke, globe artichoke",
|
| 951 |
+
"n07720875": "bell pepper",
|
| 952 |
+
"n07730033": "cardoon",
|
| 953 |
+
"n07734744": "mushroom",
|
| 954 |
+
"n07742313": "Granny Smith",
|
| 955 |
+
"n07745940": "strawberry",
|
| 956 |
+
"n07747607": "orange",
|
| 957 |
+
"n07749582": "lemon",
|
| 958 |
+
"n07753113": "fig",
|
| 959 |
+
"n07753275": "pineapple, ananas",
|
| 960 |
+
"n07753592": "banana",
|
| 961 |
+
"n07754684": "jackfruit, jak, jack",
|
| 962 |
+
"n07760859": "custard apple",
|
| 963 |
+
"n07768694": "pomegranate",
|
| 964 |
+
"n07802026": "hay",
|
| 965 |
+
"n07831146": "carbonara",
|
| 966 |
+
"n07836838": "chocolate sauce, chocolate syrup",
|
| 967 |
+
"n07860988": "dough",
|
| 968 |
+
"n07871810": "meat loaf, meatloaf",
|
| 969 |
+
"n07873807": "pizza, pizza pie",
|
| 970 |
+
"n07875152": "potpie",
|
| 971 |
+
"n07880968": "burrito",
|
| 972 |
+
"n07892512": "red wine",
|
| 973 |
+
"n07920052": "espresso",
|
| 974 |
+
"n07930864": "cup",
|
| 975 |
+
"n07932039": "eggnog",
|
| 976 |
+
"n09193705": "alp",
|
| 977 |
+
"n09229709": "bubble",
|
| 978 |
+
"n09246464": "cliff, drop, drop-off",
|
| 979 |
+
"n09256479": "coral reef",
|
| 980 |
+
"n09288635": "geyser",
|
| 981 |
+
"n09332890": "lakeside, lakeshore",
|
| 982 |
+
"n09399592": "promontory, headland, head, foreland",
|
| 983 |
+
"n09421951": "sandbar, sand bar",
|
| 984 |
+
"n09428293": "seashore, coast, seacoast, sea-coast",
|
| 985 |
+
"n09468604": "valley, vale",
|
| 986 |
+
"n09472597": "volcano",
|
| 987 |
+
"n09835506": "ballplayer, baseball player",
|
| 988 |
+
"n10148035": "groom, bridegroom",
|
| 989 |
+
"n10565667": "scuba diver",
|
| 990 |
+
"n11879895": "rapeseed",
|
| 991 |
+
"n11939491": "daisy",
|
| 992 |
+
"n12057211": "yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum",
|
| 993 |
+
"n12144580": "corn",
|
| 994 |
+
"n12267677": "acorn",
|
| 995 |
+
"n12620546": "hip, rose hip, rosehip",
|
| 996 |
+
"n12768682": "buckeye, horse chestnut, conker",
|
| 997 |
+
"n12985857": "coral fungus",
|
| 998 |
+
"n12998815": "agaric",
|
| 999 |
+
"n13037406": "gyromitra",
|
| 1000 |
+
"n13040303": "stinkhorn, carrion fungus",
|
| 1001 |
+
"n13044778": "earthstar",
|
| 1002 |
+
"n13052670": "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa",
|
| 1003 |
+
"n13054560": "bolete",
|
| 1004 |
+
"n13133613": "ear, spike, capitulum",
|
| 1005 |
+
"n15075141": "toilet tissue, toilet paper, bathroom tissue",
|
| 1006 |
+
}
|
| 1007 |
+
)
|
tasks/image_classification/plotting.py
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import imageio
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import matplotlib as mpl
|
| 9 |
+
from matplotlib import patheffects
|
| 10 |
+
mpl.use('Agg')
|
| 11 |
+
import seaborn as sns
|
| 12 |
+
import numpy as np
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
sns.set_style('darkgrid')
|
| 15 |
+
|
| 16 |
+
from tqdm.auto import tqdm
|
| 17 |
+
from scipy import ndimage
|
| 18 |
+
import umap
|
| 19 |
+
from scipy.special import softmax
|
| 20 |
+
|
| 21 |
+
import subprocess as sp
|
| 22 |
+
import cv2 # Still potentially useful for color conversion checks if needed
|
| 23 |
+
import os
|
| 24 |
+
|
| 25 |
+
def save_frames_to_mp4(frames, output_filename, fps=15.0, gop_size=None, crf=23, preset='medium', pix_fmt='yuv420p'):
|
| 26 |
+
"""
|
| 27 |
+
Saves a list of NumPy array frames to an MP4 video file using FFmpeg via subprocess.
|
| 28 |
+
|
| 29 |
+
Includes fix for odd frame dimensions by padding to the nearest even number using -vf pad.
|
| 30 |
+
|
| 31 |
+
Requires FFmpeg to be installed and available in the system PATH.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
frames (list): A list of NumPy arrays representing the video frames.
|
| 35 |
+
Expected format: uint8, (height, width, 3) for BGR color
|
| 36 |
+
or (height, width) for grayscale. Should be consistent.
|
| 37 |
+
output_filename (str): The path and name for the output MP4 file.
|
| 38 |
+
fps (float, optional): Frames per second for the output video. Defaults to 15.0.
|
| 39 |
+
gop_size (int, optional): Group of Pictures (GOP) size. This determines the
|
| 40 |
+
maximum interval between keyframes. Lower values
|
| 41 |
+
mean more frequent keyframes (better seeking, larger file).
|
| 42 |
+
Defaults to int(fps) (approx 1 keyframe per second).
|
| 43 |
+
crf (int, optional): Constant Rate Factor for H.264 encoding. Lower values mean
|
| 44 |
+
better quality and larger files. Typical range: 18-28.
|
| 45 |
+
Defaults to 23.
|
| 46 |
+
preset (str, optional): FFmpeg encoding speed preset. Affects encoding time
|
| 47 |
+
and compression efficiency. Options include 'ultrafast',
|
| 48 |
+
'superfast', 'veryfast', 'faster', 'fast', 'medium',
|
| 49 |
+
'slow', 'slower', 'veryslow'. Defaults to 'medium'.
|
| 50 |
+
"""
|
| 51 |
+
if not frames:
|
| 52 |
+
print("Error: The 'frames' list is empty. No video to save.")
|
| 53 |
+
return
|
| 54 |
+
|
| 55 |
+
# --- Determine Parameters from First Frame ---
|
| 56 |
+
try:
|
| 57 |
+
first_frame = frames[0]
|
| 58 |
+
print(first_frame.shape)
|
| 59 |
+
if not isinstance(first_frame, np.ndarray):
|
| 60 |
+
print(f"Error: Frame 0 is not a NumPy array (type: {type(first_frame)}).")
|
| 61 |
+
return
|
| 62 |
+
|
| 63 |
+
frame_height, frame_width = first_frame.shape[:2]
|
| 64 |
+
frame_size_str = f"{frame_width}x{frame_height}"
|
| 65 |
+
|
| 66 |
+
# Determine input pixel format based on first frame's shape
|
| 67 |
+
if len(first_frame.shape) == 3 and first_frame.shape[2] == 3:
|
| 68 |
+
input_pixel_format = 'bgr24' # Assume OpenCV's default BGR uint8
|
| 69 |
+
expected_dims = 3
|
| 70 |
+
print(f"Info: Detected color frames (shape: {first_frame.shape}). Expecting BGR input.")
|
| 71 |
+
elif len(first_frame.shape) == 2:
|
| 72 |
+
input_pixel_format = 'gray'
|
| 73 |
+
expected_dims = 2
|
| 74 |
+
print(f"Info: Detected grayscale frames (shape: {first_frame.shape}).")
|
| 75 |
+
else:
|
| 76 |
+
print(f"Error: Unsupported frame shape {first_frame.shape}. Must be (h, w) or (h, w, 3).")
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
if first_frame.dtype != np.uint8:
|
| 80 |
+
print(f"Warning: First frame dtype is {first_frame.dtype}. Will attempt conversion to uint8.")
|
| 81 |
+
|
| 82 |
+
except IndexError:
|
| 83 |
+
print("Error: Could not access the first frame to determine dimensions.")
|
| 84 |
+
return
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"Error processing first frame: {e}")
|
| 87 |
+
return
|
| 88 |
+
|
| 89 |
+
# --- Set GOP size default if not provided ---
|
| 90 |
+
if gop_size is None:
|
| 91 |
+
gop_size = int(fps)
|
| 92 |
+
print(f"Info: GOP size not specified, defaulting to {gop_size} (approx 1 keyframe/sec).")
|
| 93 |
+
|
| 94 |
+
# --- Construct FFmpeg Command ---
|
| 95 |
+
# ADDED -vf pad filter to ensure even dimensions for libx264/yuv420p
|
| 96 |
+
# It calculates the nearest even dimensions >= original dimensions
|
| 97 |
+
# Example: 1600x1351 -> 1600x1352
|
| 98 |
+
pad_filter = "pad=ceil(iw/2)*2:ceil(ih/2)*2"
|
| 99 |
+
|
| 100 |
+
command = [
|
| 101 |
+
'ffmpeg',
|
| 102 |
+
'-y',
|
| 103 |
+
'-f', 'rawvideo',
|
| 104 |
+
'-vcodec', 'rawvideo',
|
| 105 |
+
'-pix_fmt', input_pixel_format,
|
| 106 |
+
'-s', frame_size_str,
|
| 107 |
+
'-r', str(float(fps)),
|
| 108 |
+
'-i', '-',
|
| 109 |
+
'-vf', pad_filter, # <--- ADDED VIDEO FILTER HERE
|
| 110 |
+
'-c:v', 'libx264',
|
| 111 |
+
'-pix_fmt', pix_fmt,
|
| 112 |
+
'-preset', preset,
|
| 113 |
+
'-crf', str(crf),
|
| 114 |
+
'-g', str(gop_size),
|
| 115 |
+
'-movflags', '+faststart',
|
| 116 |
+
output_filename
|
| 117 |
+
]
|
| 118 |
+
|
| 119 |
+
print(f"\n--- Starting FFmpeg ---")
|
| 120 |
+
print(f"Output File: {output_filename}")
|
| 121 |
+
print(f"Parameters: FPS={fps}, Size={frame_size_str}, GOP={gop_size}, CRF={crf}, Preset={preset}")
|
| 122 |
+
print(f"Applying Filter: -vf {pad_filter} (Ensures even dimensions)")
|
| 123 |
+
# print(f"FFmpeg Command: {' '.join(command)}") # Uncomment for debugging
|
| 124 |
+
|
| 125 |
+
# --- Execute FFmpeg via Subprocess ---
|
| 126 |
+
try:
|
| 127 |
+
process = sp.Popen(command, stdin=sp.PIPE, stdout=sp.PIPE, stderr=sp.PIPE)
|
| 128 |
+
|
| 129 |
+
print(f"\nWriting {len(frames)} frames to FFmpeg...")
|
| 130 |
+
progress_interval = max(1, len(frames) // 10) # Print progress roughly 10 times
|
| 131 |
+
|
| 132 |
+
for i, frame in enumerate(frames):
|
| 133 |
+
# Basic validation and conversion for each frame
|
| 134 |
+
if not isinstance(frame, np.ndarray):
|
| 135 |
+
print(f"Warning: Frame {i} is not a numpy array (type: {type(frame)}). Skipping.")
|
| 136 |
+
continue
|
| 137 |
+
if frame.shape[0] != frame_height or frame.shape[1] != frame_width:
|
| 138 |
+
print(f"Warning: Frame {i} has different dimensions {frame.shape[:2]}! Expected ({frame_height},{frame_width}). Skipping.")
|
| 139 |
+
continue
|
| 140 |
+
|
| 141 |
+
current_dims = len(frame.shape)
|
| 142 |
+
if current_dims != expected_dims:
|
| 143 |
+
print(f"Warning: Frame {i} has inconsistent dimensions ({current_dims}D vs expected {expected_dims}D). Skipping.")
|
| 144 |
+
continue
|
| 145 |
+
if expected_dims == 3 and frame.shape[2] != 3:
|
| 146 |
+
print(f"Warning: Frame {i} is color but doesn't have 3 channels ({frame.shape}). Skipping.")
|
| 147 |
+
continue
|
| 148 |
+
|
| 149 |
+
if frame.dtype != np.uint8:
|
| 150 |
+
try:
|
| 151 |
+
frame = np.clip(frame, 0, 255).astype(np.uint8)
|
| 152 |
+
except Exception as clip_err:
|
| 153 |
+
print(f"Error clipping/converting frame {i} dtype: {clip_err}. Skipping.")
|
| 154 |
+
continue
|
| 155 |
+
|
| 156 |
+
# Write frame bytes to FFmpeg's stdin
|
| 157 |
+
try:
|
| 158 |
+
process.stdin.write(frame.tobytes())
|
| 159 |
+
except (OSError, BrokenPipeError) as pipe_err:
|
| 160 |
+
print(f"\nError writing frame {i} to FFmpeg stdin: {pipe_err}")
|
| 161 |
+
print("FFmpeg process likely terminated prematurely. Check FFmpeg errors below.")
|
| 162 |
+
try:
|
| 163 |
+
# Immediately try to read stderr if pipe breaks
|
| 164 |
+
stderr_output_on_error = process.stderr.read()
|
| 165 |
+
if stderr_output_on_error:
|
| 166 |
+
print("\n--- FFmpeg stderr output on error ---")
|
| 167 |
+
print(stderr_output_on_error.decode(errors='ignore'))
|
| 168 |
+
print("--- End FFmpeg stderr ---")
|
| 169 |
+
except Exception as read_err:
|
| 170 |
+
print(f"(Could not read stderr after pipe error: {read_err})")
|
| 171 |
+
return
|
| 172 |
+
except Exception as write_err:
|
| 173 |
+
print(f"Unexpected error writing frame {i}: {write_err}. Skipping.")
|
| 174 |
+
continue
|
| 175 |
+
|
| 176 |
+
if (i + 1) % progress_interval == 0 or (i + 1) == len(frames):
|
| 177 |
+
print(f" Processed frame {i + 1}/{len(frames)}")
|
| 178 |
+
|
| 179 |
+
print("\nFinished writing frames. Closing FFmpeg stdin and waiting for completion...")
|
| 180 |
+
process.stdin.close()
|
| 181 |
+
stdout, stderr = process.communicate()
|
| 182 |
+
return_code = process.wait()
|
| 183 |
+
|
| 184 |
+
print("\n--- FFmpeg Final Status ---")
|
| 185 |
+
if return_code == 0:
|
| 186 |
+
print(f"FFmpeg process completed successfully.")
|
| 187 |
+
print(f"Video saved as: {output_filename}")
|
| 188 |
+
else:
|
| 189 |
+
print(f"FFmpeg process failed with return code {return_code}.")
|
| 190 |
+
print("--- FFmpeg Standard Error Output: ---")
|
| 191 |
+
print(stderr.decode(errors='replace')) # Print stderr captured by communicate()
|
| 192 |
+
print("--- End FFmpeg Output ---")
|
| 193 |
+
print("Review the FFmpeg error message above for details (e.g., dimension errors, parameter issues).")
|
| 194 |
+
|
| 195 |
+
except FileNotFoundError:
|
| 196 |
+
print("\n--- FATAL ERROR ---")
|
| 197 |
+
print("Error: 'ffmpeg' command not found.")
|
| 198 |
+
print("Please ensure FFmpeg is installed and its directory is included in your system's PATH environment variable.")
|
| 199 |
+
print("Download from: https://ffmpeg.org/")
|
| 200 |
+
print("-------------------")
|
| 201 |
+
except Exception as e:
|
| 202 |
+
print(f"\nAn unexpected error occurred during FFmpeg execution: {e}")
|
| 203 |
+
|
| 204 |
+
def find_island_centers(array_2d, threshold):
|
| 205 |
+
"""
|
| 206 |
+
Finds the center of mass of each island (connected component) in a 2D array.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
array_2d: A 2D numpy array of values.
|
| 210 |
+
threshold: The threshold to binarize the array.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
A list of tuples (y, x) representing the center of mass of each island.
|
| 214 |
+
"""
|
| 215 |
+
binary_image = array_2d > threshold
|
| 216 |
+
labeled_image, num_labels = ndimage.label(binary_image)
|
| 217 |
+
centers = []
|
| 218 |
+
areas = [] # Store the area of each island
|
| 219 |
+
for i in range(1, num_labels + 1):
|
| 220 |
+
island = (labeled_image == i)
|
| 221 |
+
total_mass = np.sum(array_2d[island])
|
| 222 |
+
if total_mass > 0:
|
| 223 |
+
y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
|
| 224 |
+
x_center = np.average(x_coords[island], weights=array_2d[island])
|
| 225 |
+
y_center = np.average(y_coords[island], weights=array_2d[island])
|
| 226 |
+
centers.append((round(y_center, 4), round(x_center, 4)))
|
| 227 |
+
areas.append(np.sum(island)) # Calculate area of the island
|
| 228 |
+
return centers, areas
|
| 229 |
+
|
| 230 |
+
def plot_neural_dynamics(post_activations_history, N_to_plot, save_location, axis_snap=False, N_per_row=5, which_neurons_mid=None, mid_colours=None, use_most_active_neurons=False):
|
| 231 |
+
assert N_to_plot%N_per_row==0, f'For nice visualisation, N_to_plot={N_to_plot} must be a multiple of N_per_row={N_per_row}'
|
| 232 |
+
assert post_activations_history.shape[-1] >= N_to_plot
|
| 233 |
+
figscale = 2
|
| 234 |
+
aspect_ratio = 3
|
| 235 |
+
mosaic = np.array([[f'{i}'] for i in range(N_to_plot)]).flatten().reshape(-1, N_per_row)
|
| 236 |
+
fig_synch, axes_synch = plt.subplot_mosaic(mosaic=mosaic, figsize=(figscale*mosaic.shape[1]*aspect_ratio*0.2, figscale*mosaic.shape[0]*0.2))
|
| 237 |
+
fig_mid, axes_mid = plt.subplot_mosaic(mosaic=mosaic, figsize=(figscale*mosaic.shape[1]*aspect_ratio*0.2, figscale*mosaic.shape[0]*0.2), dpi=200)
|
| 238 |
+
|
| 239 |
+
palette = sns.color_palette("husl", 8)
|
| 240 |
+
|
| 241 |
+
which_neurons_synch = np.arange(N_to_plot)
|
| 242 |
+
# which_neurons_mid = np.arange(N_to_plot, N_to_plot*2) if post_activations_history.shape[-1] >= 2*N_to_plot else np.random.choice(np.arange(post_activations_history.shape[-1]), size=N_to_plot, replace=True)
|
| 243 |
+
random_indices = np.random.choice(np.arange(post_activations_history.shape[-1]), size=N_to_plot, replace=post_activations_history.shape[-1] < N_to_plot)
|
| 244 |
+
if use_most_active_neurons:
|
| 245 |
+
metric = np.abs(np.fft.rfft(post_activations_history, axis=0))[3:].mean(0).std(0)
|
| 246 |
+
random_indices = np.argsort(metric)[-N_to_plot:]
|
| 247 |
+
np.random.shuffle(random_indices)
|
| 248 |
+
which_neurons_mid = which_neurons_mid if which_neurons_mid is not None else random_indices
|
| 249 |
+
|
| 250 |
+
if mid_colours is None:
|
| 251 |
+
mid_colours = [palette[np.random.randint(0, 8)] for ndx in range(N_to_plot)]
|
| 252 |
+
with tqdm(total=N_to_plot, initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 253 |
+
pbar_inner.set_description('Plotting neural dynamics')
|
| 254 |
+
for ndx in range(N_to_plot):
|
| 255 |
+
|
| 256 |
+
ax_s = axes_synch[f'{ndx}']
|
| 257 |
+
ax_m = axes_mid[f'{ndx}']
|
| 258 |
+
|
| 259 |
+
traces_s = post_activations_history[:,:,which_neurons_synch[ndx]].T
|
| 260 |
+
traces_m = post_activations_history[:,:,which_neurons_mid[ndx]].T
|
| 261 |
+
c_s = palette[np.random.randint(0, 8)]
|
| 262 |
+
c_m = mid_colours[ndx]
|
| 263 |
+
|
| 264 |
+
for traces_s_here, traces_m_here in zip(traces_s, traces_m):
|
| 265 |
+
ax_s.plot(np.arange(len(traces_s_here)), traces_s_here, linestyle='-', color=c_s, alpha=0.05, linewidth=0.6)
|
| 266 |
+
ax_m.plot(np.arange(len(traces_m_here)), traces_m_here, linestyle='-', color=c_m, alpha=0.05, linewidth=0.6)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color='white', alpha=1, linewidth=2.5)
|
| 270 |
+
ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color=c_s, alpha=1, linewidth=1.3)
|
| 271 |
+
ax_s.plot(np.arange(len(traces_s[0])), traces_s[0], linestyle='-', color='black', alpha=1, linewidth=0.3)
|
| 272 |
+
ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color='white', alpha=1, linewidth=2.5)
|
| 273 |
+
ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color=c_m, alpha=1, linewidth=1.3)
|
| 274 |
+
ax_m.plot(np.arange(len(traces_m[0])), traces_m[0], linestyle='-', color='black', alpha=1, linewidth=0.3)
|
| 275 |
+
if axis_snap and np.all(np.isfinite(traces_s[0])):
|
| 276 |
+
ax_s.set_ylim([np.min(traces_s[0])-np.ptp(traces_s[0])*0.05, np.max(traces_s[0])+np.ptp(traces_s[0])*0.05])
|
| 277 |
+
ax_m.set_ylim([np.min(traces_m[0])-np.ptp(traces_m[0])*0.05, np.max(traces_m[0])+np.ptp(traces_m[0])*0.05])
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
ax_s.grid(False)
|
| 281 |
+
ax_m.grid(False)
|
| 282 |
+
ax_s.set_xlim([0, len(traces_s[0])-1])
|
| 283 |
+
ax_m.set_xlim([0, len(traces_m[0])-1])
|
| 284 |
+
|
| 285 |
+
ax_s.set_xticklabels([])
|
| 286 |
+
ax_s.set_yticklabels([])
|
| 287 |
+
|
| 288 |
+
ax_m.set_xticklabels([])
|
| 289 |
+
ax_m.set_yticklabels([])
|
| 290 |
+
pbar_inner.update(1)
|
| 291 |
+
fig_synch.tight_layout(pad=0.05)
|
| 292 |
+
fig_mid.tight_layout(pad=0.05)
|
| 293 |
+
if save_location is not None:
|
| 294 |
+
fig_synch.savefig(f'{save_location}/neural_dynamics_synch.pdf', dpi=200)
|
| 295 |
+
fig_synch.savefig(f'{save_location}/neural_dynamics_synch.png', dpi=200)
|
| 296 |
+
fig_mid.savefig(f'{save_location}/neural_dynamics_other.pdf', dpi=200)
|
| 297 |
+
fig_mid.savefig(f'{save_location}/neural_dynamics_other.png', dpi=200)
|
| 298 |
+
plt.close(fig_synch)
|
| 299 |
+
plt.close(fig_mid)
|
| 300 |
+
return fig_synch, fig_mid, which_neurons_mid, mid_colours
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def make_classification_gif(image, target, predictions, certainties, post_activations, attention_tracking, class_labels, save_location):
|
| 305 |
+
cmap_viridis = sns.color_palette('viridis', as_cmap=True)
|
| 306 |
+
cmap_spectral = sns.color_palette("Spectral", as_cmap=True)
|
| 307 |
+
figscale = 2
|
| 308 |
+
with tqdm(total=post_activations.shape[0]+1, initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 309 |
+
pbar_inner.set_description('Computing UMAP')
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
low = np.percentile(post_activations, 1, axis=0, keepdims=True)
|
| 313 |
+
high = np.percentile(post_activations, 99, axis=0, keepdims=True)
|
| 314 |
+
post_activations_normed = np.clip((post_activations - low)/(high - low), 0, 1)
|
| 315 |
+
metric = 'cosine'
|
| 316 |
+
reducer = umap.UMAP(n_components=2,
|
| 317 |
+
n_neighbors=100,
|
| 318 |
+
min_dist=3,
|
| 319 |
+
spread=3.0,
|
| 320 |
+
metric=metric,
|
| 321 |
+
random_state=None,
|
| 322 |
+
# low_memory=True,
|
| 323 |
+
) if post_activations.shape[-1] > 2048 else umap.UMAP(n_components=2,
|
| 324 |
+
n_neighbors=20,
|
| 325 |
+
min_dist=1,
|
| 326 |
+
spread=1.0,
|
| 327 |
+
metric=metric,
|
| 328 |
+
random_state=None,
|
| 329 |
+
# low_memory=True,
|
| 330 |
+
)
|
| 331 |
+
positions = reducer.fit_transform(post_activations_normed.T)
|
| 332 |
+
|
| 333 |
+
x_umap = positions[:, 0]
|
| 334 |
+
y_umap = positions[:, 1]
|
| 335 |
+
|
| 336 |
+
pbar_inner.update(1)
|
| 337 |
+
pbar_inner.set_description('Iterating through to build frames')
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
frames = []
|
| 342 |
+
route_steps = {}
|
| 343 |
+
route_colours = []
|
| 344 |
+
|
| 345 |
+
n_steps = len(post_activations)
|
| 346 |
+
n_heads = attention_tracking.shape[1]
|
| 347 |
+
step_linspace = np.linspace(0, 1, n_steps)
|
| 348 |
+
|
| 349 |
+
for stepi in np.arange(0, n_steps, 1):
|
| 350 |
+
pbar_inner.set_description('Making frames for gif')
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
attention_now = attention_tracking[max(0, stepi-5):stepi+1].mean(0) # Make it smooth for pretty
|
| 354 |
+
# attention_now[:,0,0] = 0 # Corners can be weird looking
|
| 355 |
+
# attention_now[:,0,-1] = 0
|
| 356 |
+
# attention_now[:,-1,0] = 0
|
| 357 |
+
# attention_now[:,-1,-1] = 0
|
| 358 |
+
# attention_now = (attention_tracking[:stepi+1, 0] * decay).sum(0)/(decay.sum(0))
|
| 359 |
+
certainties_now = certainties[1, :stepi+1]
|
| 360 |
+
attention_interp = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), image.shape[:2], mode='bilinear')[0]
|
| 361 |
+
attention_interp = (attention_interp.flatten(1) - attention_interp.flatten(1).min(-1, keepdim=True)[0])/(attention_interp.flatten(1).max(-1, keepdim=True)[0] - attention_interp.flatten(1).min(-1, keepdim=True)[0])
|
| 362 |
+
attention_interp = attention_interp.reshape(n_heads, image.shape[0], image.shape[1])
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
colour = list(cmap_spectral(step_linspace[stepi]))
|
| 366 |
+
route_colours.append(colour)
|
| 367 |
+
for headi in range(min(8, n_heads)):
|
| 368 |
+
com_attn = np.copy(attention_interp[headi])
|
| 369 |
+
com_attn[com_attn < np.percentile(com_attn, 97)] = 0.0
|
| 370 |
+
if headi not in route_steps:
|
| 371 |
+
A = attention_interp[headi].detach().cpu().numpy()
|
| 372 |
+
centres, areas = find_island_centers(A, threshold=0.7)
|
| 373 |
+
route_steps[headi] = [centres[np.argmax(areas)]]
|
| 374 |
+
else:
|
| 375 |
+
A = attention_interp[headi].detach().cpu().numpy()
|
| 376 |
+
centres, areas = find_island_centers(A, threshold=0.7)
|
| 377 |
+
route_steps[headi] = route_steps[headi] + [centres[np.argmax(areas)]]
|
| 378 |
+
|
| 379 |
+
mosaic = [['head_0', 'head_0_overlay', 'head_1', 'head_1_overlay'],
|
| 380 |
+
['head_2', 'head_2_overlay', 'head_3', 'head_3_overlay'],
|
| 381 |
+
['head_4', 'head_4_overlay', 'head_5', 'head_5_overlay'],
|
| 382 |
+
['head_6', 'head_6_overlay', 'head_7', 'head_7_overlay'],
|
| 383 |
+
['probabilities', 'probabilities','certainty', 'certainty'],
|
| 384 |
+
['umap', 'umap', 'umap', 'umap'],
|
| 385 |
+
['umap', 'umap', 'umap', 'umap'],
|
| 386 |
+
['umap', 'umap', 'umap', 'umap'],
|
| 387 |
+
|
| 388 |
+
]
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
img_aspect = image.shape[0]/image.shape[1]
|
| 392 |
+
# print(img_aspect)
|
| 393 |
+
aspect_ratio = (4*figscale, 8*figscale*img_aspect)
|
| 394 |
+
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 395 |
+
for ax in axes.values():
|
| 396 |
+
ax.axis('off')
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
axes['certainty'].plot(np.arange(len(certainties_now)), certainties_now, 'k-', linewidth=figscale*1, label='1-(normalised entropy)')
|
| 400 |
+
for ii, (x, y) in enumerate(zip(np.arange(len(certainties_now)), certainties_now)):
|
| 401 |
+
is_correct = predictions[:, ii].argmax(-1)==target
|
| 402 |
+
if is_correct: axes['certainty'].axvspan(ii, ii + 1, facecolor='limegreen', edgecolor=None, lw=0, alpha=0.3)
|
| 403 |
+
else:
|
| 404 |
+
axes['certainty'].axvspan(ii, ii + 1, facecolor='orchid', edgecolor=None, lw=0, alpha=0.3)
|
| 405 |
+
axes['certainty'].plot(len(certainties_now)-1, certainties_now[-1], 'k.', markersize=figscale*4)
|
| 406 |
+
axes['certainty'].axis('off')
|
| 407 |
+
axes['certainty'].set_ylim([-0.05, 1.05])
|
| 408 |
+
axes['certainty'].set_xlim([0, certainties.shape[-1]+1])
|
| 409 |
+
|
| 410 |
+
ps = torch.softmax(torch.from_numpy(predictions[:, stepi]), -1)
|
| 411 |
+
k = 15 if len(class_labels) > 15 else len(class_labels)
|
| 412 |
+
topk = torch.topk (ps, k, dim = 0, largest=True).indices.detach().cpu().numpy()
|
| 413 |
+
top_classes = np.array(class_labels)[topk]
|
| 414 |
+
true_class = target
|
| 415 |
+
colours = [('b' if ci != true_class else 'g') for ci in topk]
|
| 416 |
+
bar_heights = ps[topk].detach().cpu().numpy()
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
axes['probabilities'].bar(np.arange(len(bar_heights))[::-1], bar_heights, color=np.array(colours), alpha=1)
|
| 420 |
+
axes['probabilities'].set_ylim([0, 1])
|
| 421 |
+
axes['probabilities'].axis('off')
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
for i, (name) in enumerate(top_classes):
|
| 425 |
+
prob = ps[i]
|
| 426 |
+
is_correct = name==class_labels[true_class]
|
| 427 |
+
fg_color = 'darkgreen' if is_correct else 'crimson'
|
| 428 |
+
text_str = f'{name[:40]}'
|
| 429 |
+
axes['probabilities'].text(
|
| 430 |
+
0.05,
|
| 431 |
+
0.95 - i * 0.055, # Adjust vertical position for each line
|
| 432 |
+
text_str,
|
| 433 |
+
transform=axes['probabilities'].transAxes,
|
| 434 |
+
verticalalignment='top',
|
| 435 |
+
fontsize=8, # Increased font size
|
| 436 |
+
color=fg_color,
|
| 437 |
+
alpha=0.5,
|
| 438 |
+
path_effects=[
|
| 439 |
+
patheffects.Stroke(linewidth=3, foreground='aliceblue'),
|
| 440 |
+
patheffects.Normal()
|
| 441 |
+
])
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
attention_now = attention_tracking[max(0, stepi-5):stepi+1].mean(0) # Make it smooth for pretty
|
| 446 |
+
# attention_now = (attention_tracking[:stepi+1, 0] * decay).sum(0)/(decay.sum(0))
|
| 447 |
+
certainties_now = certainties[1, :stepi+1]
|
| 448 |
+
attention_interp = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), image.shape[:2], mode='nearest')[0]
|
| 449 |
+
attention_interp = (attention_interp.flatten(1) - attention_interp.flatten(1).min(-1, keepdim=True)[0])/(attention_interp.flatten(1).max(-1, keepdim=True)[0] - attention_interp.flatten(1).min(-1, keepdim=True)[0])
|
| 450 |
+
attention_interp = attention_interp.reshape(n_heads, image.shape[0], image.shape[1])
|
| 451 |
+
|
| 452 |
+
for hi in range(min(8, n_heads)):
|
| 453 |
+
ax = axes[f'head_{hi}']
|
| 454 |
+
img_to_plot = cmap_viridis(attention_interp[hi].detach().cpu().numpy())
|
| 455 |
+
ax.imshow(img_to_plot)
|
| 456 |
+
|
| 457 |
+
ax_overlay = axes[f'head_{hi}_overlay']
|
| 458 |
+
|
| 459 |
+
these_route_steps = route_steps[hi]
|
| 460 |
+
y_coords, x_coords = zip(*these_route_steps)
|
| 461 |
+
y_coords = image.shape[-2] - np.array(list(y_coords))-1
|
| 462 |
+
|
| 463 |
+
ax_overlay.imshow(np.flip(image, axis=0), origin='lower')
|
| 464 |
+
# ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
|
| 465 |
+
arrow_scale = 1.5 if image.shape[0] > 32 else 0.8
|
| 466 |
+
for i in range(len(these_route_steps)-1):
|
| 467 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 468 |
+
dy = y_coords[i+1] - y_coords[i]
|
| 469 |
+
|
| 470 |
+
ax_overlay.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=1.6*arrow_scale*1.3, head_width=1.9*arrow_scale*1.3, head_length=1.4*arrow_scale*1.45, fc='white', ec='white', length_includes_head = True, alpha=1)
|
| 471 |
+
ax_overlay.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=1.6*arrow_scale, head_width=1.9*arrow_scale, head_length=1.4*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
|
| 472 |
+
|
| 473 |
+
ax_overlay.set_xlim([0,image.shape[1]-1])
|
| 474 |
+
ax_overlay.set_ylim([0,image.shape[0]-1])
|
| 475 |
+
ax_overlay.axis('off')
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
z = post_activations_normed[stepi]
|
| 479 |
+
|
| 480 |
+
axes['umap'].scatter(x_umap, y_umap, s=30, c=cmap_spectral(z))
|
| 481 |
+
|
| 482 |
+
fig.tight_layout(pad=0.1)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
canvas = fig.canvas
|
| 487 |
+
canvas.draw()
|
| 488 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 489 |
+
image_numpy = (image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3])
|
| 490 |
+
frames.append(image_numpy)
|
| 491 |
+
plt.close(fig)
|
| 492 |
+
pbar_inner.update(1)
|
| 493 |
+
pbar_inner.set_description('Saving gif')
|
| 494 |
+
imageio.mimsave(save_location, frames, fps=15, loop=100)
|
tasks/image_classification/scripts/train_cifar10.sh
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python -m tasks.image_classification.train \
|
| 2 |
+
--log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \
|
| 3 |
+
--model ctm
|
| 4 |
+
--dataset cifar10 \
|
| 5 |
+
--d_model 256 \
|
| 6 |
+
--d_input 64 \
|
| 7 |
+
--synapse_depth 5 \
|
| 8 |
+
--heads 16 \
|
| 9 |
+
--n_synch_out 256 \
|
| 10 |
+
--n_synch_action 512 \
|
| 11 |
+
--n_random_pairing_self 0 \
|
| 12 |
+
--neuron_select_type random-pairing \
|
| 13 |
+
--iterations 50 \
|
| 14 |
+
--memory_length 15 \
|
| 15 |
+
--deep_memory \
|
| 16 |
+
--memory_hidden_dims 64 \
|
| 17 |
+
--dropout 0.0 \
|
| 18 |
+
--dropout_nlm 0 \
|
| 19 |
+
--no-do_normalisation \
|
| 20 |
+
--positional_embedding_type none \
|
| 21 |
+
--backbone_type resnet18-1 \
|
| 22 |
+
--training_iterations 600001 \
|
| 23 |
+
--warmup_steps 1000 \
|
| 24 |
+
--use_scheduler \
|
| 25 |
+
--scheduler_type cosine \
|
| 26 |
+
--weight_decay 0.0001 \
|
| 27 |
+
--save_every 1000 \
|
| 28 |
+
--track_every 2000 \
|
| 29 |
+
--n_test_batches 50 \
|
| 30 |
+
--num_workers_train 8 \
|
| 31 |
+
--batch_size 512 \
|
| 32 |
+
--batch_size_test 512 \
|
| 33 |
+
--lr 1e-4 \
|
| 34 |
+
--device 0 \
|
| 35 |
+
--seed 1
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
python -m tasks.image_classification.train \
|
| 39 |
+
--log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \
|
| 40 |
+
--model ctm
|
| 41 |
+
--dataset cifar10 \
|
| 42 |
+
--d_model 256 \
|
| 43 |
+
--d_input 64 \
|
| 44 |
+
--synapse_depth 5 \
|
| 45 |
+
--heads 16 \
|
| 46 |
+
--n_synch_out 256 \
|
| 47 |
+
--n_synch_action 512 \
|
| 48 |
+
--n_random_pairing_self 0 \
|
| 49 |
+
--neuron_select_type random-pairing \
|
| 50 |
+
--iterations 50 \
|
| 51 |
+
--memory_length 15 \
|
| 52 |
+
--deep_memory \
|
| 53 |
+
--memory_hidden_dims 64 \
|
| 54 |
+
--dropout 0.0 \
|
| 55 |
+
--dropout_nlm 0 \
|
| 56 |
+
--no-do_normalisation \
|
| 57 |
+
--positional_embedding_type none \
|
| 58 |
+
--backbone_type resnet18-1 \
|
| 59 |
+
--training_iterations 600001 \
|
| 60 |
+
--warmup_steps 1000 \
|
| 61 |
+
--use_scheduler \
|
| 62 |
+
--scheduler_type cosine \
|
| 63 |
+
--weight_decay 0.0001 \
|
| 64 |
+
--save_every 1000 \
|
| 65 |
+
--track_every 2000 \
|
| 66 |
+
--n_test_batches 50 \
|
| 67 |
+
--num_workers_train 8 \
|
| 68 |
+
--batch_size 512 \
|
| 69 |
+
--batch_size_test 512 \
|
| 70 |
+
--lr 1e-4 \
|
| 71 |
+
--device 0 \
|
| 72 |
+
--seed 2
|
| 73 |
+
|
| 74 |
+
python -m tasks.image_classification.train \
|
| 75 |
+
--log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \
|
| 76 |
+
--model ctm
|
| 77 |
+
--dataset cifar10 \
|
| 78 |
+
--d_model 256 \
|
| 79 |
+
--d_input 64 \
|
| 80 |
+
--synapse_depth 5 \
|
| 81 |
+
--heads 16 \
|
| 82 |
+
--n_synch_out 256 \
|
| 83 |
+
--n_synch_action 512 \
|
| 84 |
+
--n_random_pairing_self 0 \
|
| 85 |
+
--neuron_select_type random-pairing \
|
| 86 |
+
--iterations 50 \
|
| 87 |
+
--memory_length 15 \
|
| 88 |
+
--deep_memory \
|
| 89 |
+
--memory_hidden_dims 64 \
|
| 90 |
+
--dropout 0.0 \
|
| 91 |
+
--dropout_nlm 0 \
|
| 92 |
+
--no-do_normalisation \
|
| 93 |
+
--positional_embedding_type none \
|
| 94 |
+
--backbone_type resnet18-1 \
|
| 95 |
+
--training_iterations 600001 \
|
| 96 |
+
--warmup_steps 1000 \
|
| 97 |
+
--use_scheduler \
|
| 98 |
+
--scheduler_type cosine \
|
| 99 |
+
--weight_decay 0.0001 \
|
| 100 |
+
--save_every 1000 \
|
| 101 |
+
--track_every 2000 \
|
| 102 |
+
--n_test_batches 50 \
|
| 103 |
+
--num_workers_train 8 \
|
| 104 |
+
--batch_size 512 \
|
| 105 |
+
--batch_size_test 512 \
|
| 106 |
+
--lr 1e-4 \
|
| 107 |
+
--device 0 \
|
| 108 |
+
--seed 42
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
python -m tasks.image_classification.train \
|
| 116 |
+
--log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \
|
| 117 |
+
--dataset cifar10 \
|
| 118 |
+
--model lstm \
|
| 119 |
+
--num_layers 2 \
|
| 120 |
+
--d_model 256 \
|
| 121 |
+
--d_input 64 \
|
| 122 |
+
--heads 16 \
|
| 123 |
+
--iterations 50 \
|
| 124 |
+
--dropout 0.0 \
|
| 125 |
+
--positional_embedding_type none \
|
| 126 |
+
--backbone_type resnet18-1 \
|
| 127 |
+
--training_iterations 600001 \
|
| 128 |
+
--warmup_steps 2000 \
|
| 129 |
+
--use_scheduler \
|
| 130 |
+
--scheduler_type cosine \
|
| 131 |
+
--weight_decay 0.0001 \
|
| 132 |
+
--save_every 1000 \
|
| 133 |
+
--track_every 2000 \
|
| 134 |
+
--n_test_batches 50 \
|
| 135 |
+
--reload \
|
| 136 |
+
--num_workers_train 8 \
|
| 137 |
+
--batch_size 512 \
|
| 138 |
+
--batch_size_test 512 \
|
| 139 |
+
--lr 1e-4 \
|
| 140 |
+
--device 0 \
|
| 141 |
+
--seed 1 \
|
| 142 |
+
--no-reload
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
python -m tasks.image_classification.train \
|
| 146 |
+
--log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \
|
| 147 |
+
--dataset cifar10 \
|
| 148 |
+
--model lstm \
|
| 149 |
+
--num_layers 2 \
|
| 150 |
+
--d_model 256 \
|
| 151 |
+
--d_input 64 \
|
| 152 |
+
--heads 16 \
|
| 153 |
+
--iterations 50 \
|
| 154 |
+
--dropout 0.0 \
|
| 155 |
+
--positional_embedding_type none \
|
| 156 |
+
--backbone_type resnet18-1 \
|
| 157 |
+
--training_iterations 600001 \
|
| 158 |
+
--warmup_steps 2000 \
|
| 159 |
+
--use_scheduler \
|
| 160 |
+
--scheduler_type cosine \
|
| 161 |
+
--weight_decay 0.0001 \
|
| 162 |
+
--save_every 1000 \
|
| 163 |
+
--track_every 2000 \
|
| 164 |
+
--n_test_batches 50 \
|
| 165 |
+
--reload \
|
| 166 |
+
--num_workers_train 8 \
|
| 167 |
+
--batch_size 512 \
|
| 168 |
+
--batch_size_test 512 \
|
| 169 |
+
--lr 1e-4 \
|
| 170 |
+
--device 0 \
|
| 171 |
+
--seed 2 \
|
| 172 |
+
--no-reload
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
python -m tasks.image_classification.train \
|
| 176 |
+
--log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \
|
| 177 |
+
--dataset cifar10 \
|
| 178 |
+
--model lstm \
|
| 179 |
+
--num_layers 2 \
|
| 180 |
+
--d_model 256 \
|
| 181 |
+
--d_input 64 \
|
| 182 |
+
--heads 16 \
|
| 183 |
+
--iterations 50 \
|
| 184 |
+
--dropout 0.0 \
|
| 185 |
+
--positional_embedding_type none \
|
| 186 |
+
--backbone_type resnet18-1 \
|
| 187 |
+
--training_iterations 600001 \
|
| 188 |
+
--warmup_steps 2000 \
|
| 189 |
+
--use_scheduler \
|
| 190 |
+
--scheduler_type cosine \
|
| 191 |
+
--weight_decay 0.0001 \
|
| 192 |
+
--save_every 1000 \
|
| 193 |
+
--track_every 2000 \
|
| 194 |
+
--n_test_batches 50 \
|
| 195 |
+
--reload \
|
| 196 |
+
--num_workers_train 8 \
|
| 197 |
+
--batch_size 512 \
|
| 198 |
+
--batch_size_test 512 \
|
| 199 |
+
--lr 1e-4 \
|
| 200 |
+
--device 0 \
|
| 201 |
+
--seed 42 \
|
| 202 |
+
--no-reload
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
python -m tasks.image_classification.train \
|
| 209 |
+
--log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=1 \
|
| 210 |
+
--dataset cifar10 \
|
| 211 |
+
--model ff \
|
| 212 |
+
--d_model 256 \
|
| 213 |
+
--memory_hidden_dims 64 \
|
| 214 |
+
--dropout 0.0 \
|
| 215 |
+
--dropout_nlm 0 \
|
| 216 |
+
--backbone_type resnet18-1 \
|
| 217 |
+
--training_iterations 600001 \
|
| 218 |
+
--warmup_steps 1000 \
|
| 219 |
+
--use_scheduler \
|
| 220 |
+
--scheduler_type cosine \
|
| 221 |
+
--weight_decay 0.0001 \
|
| 222 |
+
--save_every 1000 \
|
| 223 |
+
--track_every 2000 \
|
| 224 |
+
--n_test_batches 50 \
|
| 225 |
+
--num_workers_train 8 \
|
| 226 |
+
--batch_size 512 \
|
| 227 |
+
--batch_size_test 512 \
|
| 228 |
+
--lr 1e-4 \
|
| 229 |
+
--device 0 \
|
| 230 |
+
--seed 1
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
python -m tasks.image_classification.train \
|
| 234 |
+
--log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=2 \
|
| 235 |
+
--dataset cifar10 \
|
| 236 |
+
--model ff \
|
| 237 |
+
--d_model 256 \
|
| 238 |
+
--memory_hidden_dims 64 \
|
| 239 |
+
--dropout 0.0 \
|
| 240 |
+
--dropout_nlm 0 \
|
| 241 |
+
--backbone_type resnet18-1 \
|
| 242 |
+
--training_iterations 600001 \
|
| 243 |
+
--warmup_steps 1000 \
|
| 244 |
+
--use_scheduler \
|
| 245 |
+
--scheduler_type cosine \
|
| 246 |
+
--weight_decay 0.0001 \
|
| 247 |
+
--save_every 1000 \
|
| 248 |
+
--track_every 2000 \
|
| 249 |
+
--n_test_batches 50 \
|
| 250 |
+
--num_workers_train 8 \
|
| 251 |
+
--batch_size 512 \
|
| 252 |
+
--batch_size_test 512 \
|
| 253 |
+
--lr 1e-4 \
|
| 254 |
+
--device 0 \
|
| 255 |
+
--seed 2
|
| 256 |
+
|
| 257 |
+
python -m tasks.image_classification.train \
|
| 258 |
+
--log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=42 \
|
| 259 |
+
--dataset cifar10 \
|
| 260 |
+
--model ff \
|
| 261 |
+
--d_model 256 \
|
| 262 |
+
--memory_hidden_dims 64 \
|
| 263 |
+
--dropout 0.0 \
|
| 264 |
+
--dropout_nlm 0 \
|
| 265 |
+
--backbone_type resnet18-1 \
|
| 266 |
+
--training_iterations 600001 \
|
| 267 |
+
--warmup_steps 1000 \
|
| 268 |
+
--use_scheduler \
|
| 269 |
+
--scheduler_type cosine \
|
| 270 |
+
--weight_decay 0.0001 \
|
| 271 |
+
--save_every 1000 \
|
| 272 |
+
--track_every 2000 \
|
| 273 |
+
--n_test_batches 50 \
|
| 274 |
+
--num_workers_train 8 \
|
| 275 |
+
--batch_size 512 \
|
| 276 |
+
--batch_size_test 512 \
|
| 277 |
+
--lr 1e-4 \
|
| 278 |
+
--device 0 \
|
| 279 |
+
--seed 42
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
|
tasks/image_classification/scripts/train_imagenet.sh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed \
|
| 2 |
+
--log_dir logs/imagenet/d=4096--i=1024--heads=16--sd=8--nlm=64--synch=8192-2048-32-h=64-random-pairing--iters=50x25--backbone=152x4 \
|
| 3 |
+
--model ctm \
|
| 4 |
+
--dataset imagenet \
|
| 5 |
+
--d_model 4096 \
|
| 6 |
+
--d_input 1024 \
|
| 7 |
+
--synapse_depth 8 \
|
| 8 |
+
--heads 16 \
|
| 9 |
+
--n_synch_out 8196 \
|
| 10 |
+
--n_synch_action 2048 \
|
| 11 |
+
--n_random_pairing_self 32 \
|
| 12 |
+
--neuron_select_type random-pairing \
|
| 13 |
+
--iterations 50 \
|
| 14 |
+
--memory_length 25 \
|
| 15 |
+
--deep_memory \
|
| 16 |
+
--memory_hidden_dims 64 \
|
| 17 |
+
--dropout 0.2 \
|
| 18 |
+
--dropout_nlm 0 \
|
| 19 |
+
--no-do_normalisation \
|
| 20 |
+
--positional_embedding_type none \
|
| 21 |
+
--backbone_type resnet152-4 \
|
| 22 |
+
--batch_size 64 \
|
| 23 |
+
--batch_size_test 64 \
|
| 24 |
+
--n_test_batches 200 \
|
| 25 |
+
--lr 5e-4 \
|
| 26 |
+
--gradient_clipping 20 \
|
| 27 |
+
--training_iterations 500001 \
|
| 28 |
+
--save_every 1000 \
|
| 29 |
+
--track_every 5000 \
|
| 30 |
+
--warmup_steps 10000 \
|
| 31 |
+
--use_scheduler \
|
| 32 |
+
--scheduler_type cosine \
|
| 33 |
+
--weight_decay 0.0 \
|
| 34 |
+
--seed 1 \
|
| 35 |
+
--use_amp \
|
| 36 |
+
--reload \
|
| 37 |
+
--num_workers_train 8 \
|
| 38 |
+
--use_custom_sampler
|
tasks/image_classification/train.py
ADDED
|
@@ -0,0 +1,685 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
sns.set_style('darkgrid')
|
| 9 |
+
import torch
|
| 10 |
+
if torch.cuda.is_available():
|
| 11 |
+
# For faster
|
| 12 |
+
torch.set_float32_matmul_precision('high')
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from tqdm.auto import tqdm
|
| 15 |
+
|
| 16 |
+
from data.custom_datasets import ImageNet
|
| 17 |
+
from torchvision import datasets
|
| 18 |
+
from torchvision import transforms
|
| 19 |
+
from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
|
| 20 |
+
from models.ctm import ContinuousThoughtMachine
|
| 21 |
+
from models.lstm import LSTMBaseline
|
| 22 |
+
from models.ff import FFBaseline
|
| 23 |
+
from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
|
| 24 |
+
from utils.housekeeping import set_seed, zip_python_code
|
| 25 |
+
from utils.losses import image_classification_loss # Used by CTM, LSTM
|
| 26 |
+
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
|
| 27 |
+
|
| 28 |
+
from autoclip.torch import QuantileClip
|
| 29 |
+
|
| 30 |
+
import gc
|
| 31 |
+
import torchvision
|
| 32 |
+
torchvision.disable_beta_transforms_warning()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
import warnings
|
| 36 |
+
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
|
| 37 |
+
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
|
| 38 |
+
warnings.filterwarnings(
|
| 39 |
+
"ignore",
|
| 40 |
+
"Corrupt EXIF data",
|
| 41 |
+
UserWarning,
|
| 42 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 43 |
+
)
|
| 44 |
+
warnings.filterwarnings(
|
| 45 |
+
"ignore",
|
| 46 |
+
"UserWarning: Metadata Warning",
|
| 47 |
+
UserWarning,
|
| 48 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 49 |
+
)
|
| 50 |
+
warnings.filterwarnings(
|
| 51 |
+
"ignore",
|
| 52 |
+
"UserWarning: Truncated File Read",
|
| 53 |
+
UserWarning,
|
| 54 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def parse_args():
|
| 59 |
+
parser = argparse.ArgumentParser()
|
| 60 |
+
|
| 61 |
+
# Model Selection
|
| 62 |
+
parser.add_argument('--model', type=str, default='ctm', choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
|
| 63 |
+
|
| 64 |
+
# Model Architecture
|
| 65 |
+
# Common
|
| 66 |
+
parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
|
| 67 |
+
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
|
| 68 |
+
parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featureiser.')
|
| 69 |
+
# CTM / LSTM specific
|
| 70 |
+
parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
|
| 71 |
+
parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
|
| 72 |
+
parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
|
| 73 |
+
parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
|
| 74 |
+
choices=['none',
|
| 75 |
+
'learnable-fourier',
|
| 76 |
+
'multi-learnable-fourier',
|
| 77 |
+
'custom-rotational'])
|
| 78 |
+
# CTM specific
|
| 79 |
+
parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
|
| 80 |
+
parser.add_argument('--n_synch_out', type=int, default=512, help='Number of neurons to use for output synch (CTM only).')
|
| 81 |
+
parser.add_argument('--n_synch_action', type=int, default=512, help='Number of neurons to use for observation/action synch (CTM only).')
|
| 82 |
+
parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
|
| 83 |
+
parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
|
| 84 |
+
parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
|
| 85 |
+
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
|
| 86 |
+
parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
|
| 87 |
+
parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
|
| 88 |
+
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
|
| 89 |
+
# LSTM specific
|
| 90 |
+
parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
|
| 91 |
+
|
| 92 |
+
# Training
|
| 93 |
+
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training.')
|
| 94 |
+
parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing.')
|
| 95 |
+
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
|
| 96 |
+
parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
|
| 97 |
+
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
|
| 98 |
+
parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
|
| 99 |
+
parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
|
| 100 |
+
parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
|
| 101 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
|
| 102 |
+
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
|
| 103 |
+
parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
|
| 104 |
+
parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
|
| 105 |
+
parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components (backbone, synapses if CTM).')
|
| 106 |
+
parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')
|
| 107 |
+
|
| 108 |
+
# Housekeeping
|
| 109 |
+
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 110 |
+
parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
|
| 111 |
+
parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
|
| 112 |
+
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
| 113 |
+
parser.add_argument('--seed', type=int, default=412, help='Random seed.')
|
| 114 |
+
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
|
| 115 |
+
parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
|
| 116 |
+
parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
|
| 117 |
+
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
|
| 118 |
+
parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
|
| 119 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
|
| 120 |
+
parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
args = parser.parse_args()
|
| 124 |
+
return args
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def get_dataset(dataset, root):
|
| 128 |
+
if dataset=='imagenet':
|
| 129 |
+
dataset_mean = [0.485, 0.456, 0.406]
|
| 130 |
+
dataset_std = [0.229, 0.224, 0.225]
|
| 131 |
+
|
| 132 |
+
normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
|
| 133 |
+
train_transform = transforms.Compose([
|
| 134 |
+
transforms.RandomResizedCrop(224),
|
| 135 |
+
transforms.RandomHorizontalFlip(),
|
| 136 |
+
transforms.ToTensor(),
|
| 137 |
+
normalize])
|
| 138 |
+
test_transform = transforms.Compose([
|
| 139 |
+
transforms.Resize(256),
|
| 140 |
+
transforms.CenterCrop(224),
|
| 141 |
+
transforms.ToTensor(),
|
| 142 |
+
normalize])
|
| 143 |
+
|
| 144 |
+
class_labels = list(IMAGENET2012_CLASSES.values())
|
| 145 |
+
|
| 146 |
+
train_data = ImageNet(which_split='train', transform=train_transform)
|
| 147 |
+
test_data = ImageNet(which_split='validation', transform=test_transform)
|
| 148 |
+
elif dataset=='cifar10':
|
| 149 |
+
dataset_mean = [0.49139968, 0.48215827, 0.44653124]
|
| 150 |
+
dataset_std = [0.24703233, 0.24348505, 0.26158768]
|
| 151 |
+
normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
|
| 152 |
+
train_transform = transforms.Compose(
|
| 153 |
+
[transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
|
| 154 |
+
transforms.ToTensor(),
|
| 155 |
+
normalize,
|
| 156 |
+
])
|
| 157 |
+
|
| 158 |
+
test_transform = transforms.Compose(
|
| 159 |
+
[transforms.ToTensor(),
|
| 160 |
+
normalize,
|
| 161 |
+
])
|
| 162 |
+
train_data = datasets.CIFAR10(root, train=True, transform=train_transform, download=True)
|
| 163 |
+
test_data = datasets.CIFAR10(root, train=False, transform=test_transform, download=True)
|
| 164 |
+
class_labels = ['air', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
|
| 165 |
+
elif dataset=='cifar100':
|
| 166 |
+
dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
|
| 167 |
+
dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
|
| 168 |
+
normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
|
| 169 |
+
|
| 170 |
+
train_transform = transforms.Compose(
|
| 171 |
+
[transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
|
| 172 |
+
transforms.ToTensor(),
|
| 173 |
+
normalize,
|
| 174 |
+
])
|
| 175 |
+
test_transform = transforms.Compose(
|
| 176 |
+
[transforms.ToTensor(),
|
| 177 |
+
normalize,
|
| 178 |
+
])
|
| 179 |
+
train_data = datasets.CIFAR100(root, train=True, transform=train_transform, download=True)
|
| 180 |
+
test_data = datasets.CIFAR100(root, train=False, transform=test_transform, download=True)
|
| 181 |
+
idx_order = np.argsort(np.array(list(train_data.class_to_idx.values())))
|
| 182 |
+
class_labels = list(np.array(list(train_data.class_to_idx.keys()))[idx_order])
|
| 183 |
+
else:
|
| 184 |
+
raise NotImplementedError
|
| 185 |
+
|
| 186 |
+
return train_data, test_data, class_labels, dataset_mean, dataset_std
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
if __name__=='__main__':
|
| 191 |
+
|
| 192 |
+
# Hosuekeeping
|
| 193 |
+
args = parse_args()
|
| 194 |
+
|
| 195 |
+
set_seed(args.seed, False)
|
| 196 |
+
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 197 |
+
|
| 198 |
+
assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
|
| 199 |
+
|
| 200 |
+
# Data
|
| 201 |
+
train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
|
| 202 |
+
|
| 203 |
+
num_workers_test = 1 # Defaulting to 1, change if needed
|
| 204 |
+
trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train)
|
| 205 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
|
| 206 |
+
|
| 207 |
+
prediction_reshaper = [-1] # Problem specific
|
| 208 |
+
args.out_dims = len(class_labels)
|
| 209 |
+
|
| 210 |
+
# For total reproducibility
|
| 211 |
+
zip_python_code(f'{args.log_dir}/repo_state.zip')
|
| 212 |
+
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 213 |
+
print(args, file=f)
|
| 214 |
+
|
| 215 |
+
# Configure device string
|
| 216 |
+
device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
|
| 217 |
+
print(f'Running model {args.model} on {device}')
|
| 218 |
+
|
| 219 |
+
# Build model conditionally
|
| 220 |
+
model = None
|
| 221 |
+
if args.model == 'ctm':
|
| 222 |
+
model = ContinuousThoughtMachine(
|
| 223 |
+
iterations=args.iterations,
|
| 224 |
+
d_model=args.d_model,
|
| 225 |
+
d_input=args.d_input,
|
| 226 |
+
heads=args.heads,
|
| 227 |
+
n_synch_out=args.n_synch_out,
|
| 228 |
+
n_synch_action=args.n_synch_action,
|
| 229 |
+
synapse_depth=args.synapse_depth,
|
| 230 |
+
memory_length=args.memory_length,
|
| 231 |
+
deep_nlms=args.deep_memory,
|
| 232 |
+
memory_hidden_dims=args.memory_hidden_dims,
|
| 233 |
+
do_layernorm_nlm=args.do_normalisation,
|
| 234 |
+
backbone_type=args.backbone_type,
|
| 235 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 236 |
+
out_dims=args.out_dims,
|
| 237 |
+
prediction_reshaper=prediction_reshaper,
|
| 238 |
+
dropout=args.dropout,
|
| 239 |
+
dropout_nlm=args.dropout_nlm,
|
| 240 |
+
neuron_select_type=args.neuron_select_type,
|
| 241 |
+
n_random_pairing_self=args.n_random_pairing_self,
|
| 242 |
+
).to(device)
|
| 243 |
+
elif args.model == 'lstm':
|
| 244 |
+
model = LSTMBaseline(
|
| 245 |
+
num_layers=args.num_layers,
|
| 246 |
+
iterations=args.iterations,
|
| 247 |
+
d_model=args.d_model,
|
| 248 |
+
d_input=args.d_input,
|
| 249 |
+
heads=args.heads,
|
| 250 |
+
backbone_type=args.backbone_type,
|
| 251 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 252 |
+
out_dims=args.out_dims,
|
| 253 |
+
prediction_reshaper=prediction_reshaper,
|
| 254 |
+
dropout=args.dropout,
|
| 255 |
+
).to(device)
|
| 256 |
+
elif args.model == 'ff':
|
| 257 |
+
model = FFBaseline(
|
| 258 |
+
d_model=args.d_model,
|
| 259 |
+
backbone_type=args.backbone_type,
|
| 260 |
+
out_dims=args.out_dims,
|
| 261 |
+
dropout=args.dropout,
|
| 262 |
+
).to(device)
|
| 263 |
+
else:
|
| 264 |
+
raise ValueError(f"Unknown model type: {args.model}")
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# For lazy modules so that we can get param count
|
| 268 |
+
pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
|
| 269 |
+
model(pseudo_inputs)
|
| 270 |
+
|
| 271 |
+
model.train()
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
print(f'Total params: {sum(p.numel() for p in model.parameters())}')
|
| 275 |
+
decay_params = []
|
| 276 |
+
no_decay_params = []
|
| 277 |
+
no_decay_names = []
|
| 278 |
+
for name, param in model.named_parameters():
|
| 279 |
+
if not param.requires_grad:
|
| 280 |
+
continue # Skip parameters that don't require gradients
|
| 281 |
+
if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
|
| 282 |
+
no_decay_params.append(param)
|
| 283 |
+
no_decay_names.append(name)
|
| 284 |
+
else:
|
| 285 |
+
decay_params.append(param)
|
| 286 |
+
if len(no_decay_names):
|
| 287 |
+
print(f'WARNING, excluding: {no_decay_names}')
|
| 288 |
+
|
| 289 |
+
# Optimizer and scheduler (Common setup)
|
| 290 |
+
if len(no_decay_names) and args.weight_decay!=0:
|
| 291 |
+
optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
|
| 292 |
+
{'params': no_decay_params, 'weight_decay':0}],
|
| 293 |
+
lr=args.lr,
|
| 294 |
+
eps=1e-8 if not args.use_amp else 1e-6)
|
| 295 |
+
else:
|
| 296 |
+
optimizer = torch.optim.AdamW(model.parameters(),
|
| 297 |
+
lr=args.lr,
|
| 298 |
+
eps=1e-8 if not args.use_amp else 1e-6,
|
| 299 |
+
weight_decay=args.weight_decay)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
warmup_schedule = warmup(args.warmup_steps)
|
| 303 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
|
| 304 |
+
if args.use_scheduler:
|
| 305 |
+
if args.scheduler_type == 'multistep':
|
| 306 |
+
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
|
| 307 |
+
elif args.scheduler_type == 'cosine':
|
| 308 |
+
scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
|
| 309 |
+
else:
|
| 310 |
+
raise NotImplementedError
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
# Metrics tracking
|
| 314 |
+
start_iter = 0
|
| 315 |
+
train_losses = []
|
| 316 |
+
test_losses = []
|
| 317 |
+
train_accuracies = []
|
| 318 |
+
test_accuracies = []
|
| 319 |
+
iters = []
|
| 320 |
+
# Conditional metrics for CTM/LSTM
|
| 321 |
+
train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
|
| 322 |
+
test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
|
| 323 |
+
|
| 324 |
+
scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
|
| 325 |
+
|
| 326 |
+
# Reloading logic
|
| 327 |
+
if args.reload:
|
| 328 |
+
checkpoint_path = f'{args.log_dir}/checkpoint.pt'
|
| 329 |
+
if os.path.isfile(checkpoint_path):
|
| 330 |
+
print(f'Reloading from: {checkpoint_path}')
|
| 331 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 332 |
+
if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
|
| 333 |
+
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
|
| 334 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 335 |
+
|
| 336 |
+
if not args.reload_model_only:
|
| 337 |
+
print('Reloading optimizer etc.')
|
| 338 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 339 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 340 |
+
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
| 341 |
+
start_iter = checkpoint['iteration']
|
| 342 |
+
# Load common metrics
|
| 343 |
+
train_losses = checkpoint['train_losses']
|
| 344 |
+
test_losses = checkpoint['test_losses']
|
| 345 |
+
train_accuracies = checkpoint['train_accuracies']
|
| 346 |
+
test_accuracies = checkpoint['test_accuracies']
|
| 347 |
+
iters = checkpoint['iters']
|
| 348 |
+
|
| 349 |
+
# Load conditional metrics if they exist in checkpoint and are expected for current model
|
| 350 |
+
if args.model in ['ctm', 'lstm']:
|
| 351 |
+
train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
|
| 352 |
+
test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
|
| 353 |
+
|
| 354 |
+
else:
|
| 355 |
+
print('Only reloading model!')
|
| 356 |
+
|
| 357 |
+
if 'torch_rng_state' in checkpoint:
|
| 358 |
+
# Reset seeds
|
| 359 |
+
torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
|
| 360 |
+
np.random.set_state(checkpoint['numpy_rng_state'])
|
| 361 |
+
random.setstate(checkpoint['random_rng_state'])
|
| 362 |
+
|
| 363 |
+
del checkpoint
|
| 364 |
+
gc.collect()
|
| 365 |
+
if torch.cuda.is_available():
|
| 366 |
+
torch.cuda.empty_cache()
|
| 367 |
+
|
| 368 |
+
# Conditional Compilation
|
| 369 |
+
if args.do_compile:
|
| 370 |
+
print('Compiling...')
|
| 371 |
+
if hasattr(model, 'backbone'):
|
| 372 |
+
model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
|
| 373 |
+
|
| 374 |
+
# Compile synapses only for CTM
|
| 375 |
+
if args.model == 'ctm':
|
| 376 |
+
model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
|
| 377 |
+
|
| 378 |
+
# Training
|
| 379 |
+
iterator = iter(trainloader)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
|
| 383 |
+
for bi in range(start_iter, args.training_iterations):
|
| 384 |
+
current_lr = optimizer.param_groups[-1]['lr']
|
| 385 |
+
|
| 386 |
+
try:
|
| 387 |
+
inputs, targets = next(iterator)
|
| 388 |
+
except StopIteration:
|
| 389 |
+
iterator = iter(trainloader)
|
| 390 |
+
inputs, targets = next(iterator)
|
| 391 |
+
|
| 392 |
+
inputs = inputs.to(device)
|
| 393 |
+
targets = targets.to(device)
|
| 394 |
+
|
| 395 |
+
loss = None
|
| 396 |
+
accuracy = None
|
| 397 |
+
# Model-specific forward and loss calculation
|
| 398 |
+
with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
|
| 399 |
+
if args.do_compile: # CUDAGraph marking for clean compile
|
| 400 |
+
torch.compiler.cudagraph_mark_step_begin()
|
| 401 |
+
|
| 402 |
+
if args.model == 'ctm':
|
| 403 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 404 |
+
loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 405 |
+
accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
|
| 406 |
+
pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
|
| 407 |
+
|
| 408 |
+
elif args.model == 'lstm':
|
| 409 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 410 |
+
loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 411 |
+
# LSTM where_most_certain will just be -1 because use_most_certain is False owing to stability issues with LSTM training
|
| 412 |
+
accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
|
| 413 |
+
pbar_desc = f'LSTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
|
| 414 |
+
|
| 415 |
+
elif args.model == 'ff':
|
| 416 |
+
predictions = model(inputs)
|
| 417 |
+
loss = nn.CrossEntropyLoss()(predictions, targets)
|
| 418 |
+
accuracy = (predictions.argmax(1) == targets).float().mean().item()
|
| 419 |
+
pbar_desc = f'FF Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
|
| 420 |
+
|
| 421 |
+
scaler.scale(loss).backward()
|
| 422 |
+
|
| 423 |
+
if args.gradient_clipping!=-1:
|
| 424 |
+
scaler.unscale_(optimizer)
|
| 425 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
|
| 426 |
+
|
| 427 |
+
scaler.step(optimizer)
|
| 428 |
+
scaler.update()
|
| 429 |
+
optimizer.zero_grad(set_to_none=True)
|
| 430 |
+
scheduler.step()
|
| 431 |
+
|
| 432 |
+
pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
# Metrics tracking and plotting (conditional logic needed)
|
| 436 |
+
if (bi % args.track_every == 0 or bi == args.warmup_steps) and (bi != 0 or args.reload_model_only):
|
| 437 |
+
|
| 438 |
+
iters.append(bi)
|
| 439 |
+
current_train_losses = []
|
| 440 |
+
current_test_losses = []
|
| 441 |
+
current_train_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
|
| 442 |
+
current_test_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
|
| 443 |
+
current_train_accuracies_most_certain = [] # Only for CTM/LSTM
|
| 444 |
+
current_test_accuracies_most_certain = [] # Only for CTM/LSTM
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
# Reset BN stats using train mode
|
| 448 |
+
pbar.set_description('Resetting BN')
|
| 449 |
+
model.train()
|
| 450 |
+
for module in model.modules():
|
| 451 |
+
if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
|
| 452 |
+
module.reset_running_stats()
|
| 453 |
+
|
| 454 |
+
pbar.set_description('Tracking: Computing TRAIN metrics')
|
| 455 |
+
with torch.no_grad(): # Should use inference_mode? CTM/LSTM scripts used no_grad
|
| 456 |
+
loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
|
| 457 |
+
all_targets_list = []
|
| 458 |
+
all_predictions_list = [] # List to store raw predictions (B, C, T) or (B, C)
|
| 459 |
+
all_predictions_most_certain_list = [] # Only for CTM/LSTM
|
| 460 |
+
all_losses = []
|
| 461 |
+
|
| 462 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 463 |
+
for inferi, (inputs, targets) in enumerate(loader):
|
| 464 |
+
inputs = inputs.to(device)
|
| 465 |
+
targets = targets.to(device)
|
| 466 |
+
all_targets_list.append(targets.detach().cpu().numpy())
|
| 467 |
+
|
| 468 |
+
# Model-specific forward and loss for evaluation
|
| 469 |
+
if args.model == 'ctm':
|
| 470 |
+
these_predictions, certainties, _ = model(inputs)
|
| 471 |
+
loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
|
| 472 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
|
| 473 |
+
all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
|
| 474 |
+
|
| 475 |
+
elif args.model == 'lstm':
|
| 476 |
+
these_predictions, certainties, _ = model(inputs)
|
| 477 |
+
loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
|
| 478 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
|
| 479 |
+
all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
|
| 480 |
+
|
| 481 |
+
elif args.model == 'ff':
|
| 482 |
+
these_predictions = model(inputs)
|
| 483 |
+
loss = nn.CrossEntropyLoss()(these_predictions, targets)
|
| 484 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B,)
|
| 485 |
+
|
| 486 |
+
all_losses.append(loss.item())
|
| 487 |
+
|
| 488 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break # Check condition >= N-1
|
| 489 |
+
pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
|
| 490 |
+
pbar_inner.update(1)
|
| 491 |
+
|
| 492 |
+
all_targets = np.concatenate(all_targets_list)
|
| 493 |
+
all_predictions = np.concatenate(all_predictions_list) # Shape (N, T) or (N,)
|
| 494 |
+
train_losses.append(np.mean(all_losses))
|
| 495 |
+
|
| 496 |
+
if args.model in ['ctm', 'lstm']:
|
| 497 |
+
# Accuracies per tick for CTM/LSTM
|
| 498 |
+
current_train_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0) # Mean over batch dim -> Shape (T,)
|
| 499 |
+
train_accuracies.append(current_train_accuracies)
|
| 500 |
+
# Most certain accuracy
|
| 501 |
+
all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
|
| 502 |
+
current_train_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
|
| 503 |
+
train_accuracies_most_certain.append(current_train_accuracies_most_certain)
|
| 504 |
+
else: # FF
|
| 505 |
+
current_train_accuracies = (all_targets == all_predictions).mean() # Shape scalar
|
| 506 |
+
train_accuracies.append(current_train_accuracies)
|
| 507 |
+
|
| 508 |
+
del these_predictions
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
# Switch to eval mode for test metrics (fixed BN stats)
|
| 512 |
+
model.eval()
|
| 513 |
+
pbar.set_description('Tracking: Computing TEST metrics')
|
| 514 |
+
with torch.inference_mode(): # Use inference_mode for test eval
|
| 515 |
+
loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
|
| 516 |
+
all_targets_list = []
|
| 517 |
+
all_predictions_list = []
|
| 518 |
+
all_predictions_most_certain_list = [] # Only for CTM/LSTM
|
| 519 |
+
all_losses = []
|
| 520 |
+
|
| 521 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 522 |
+
for inferi, (inputs, targets) in enumerate(loader):
|
| 523 |
+
inputs = inputs.to(device)
|
| 524 |
+
targets = targets.to(device)
|
| 525 |
+
all_targets_list.append(targets.detach().cpu().numpy())
|
| 526 |
+
|
| 527 |
+
# Model-specific forward and loss for evaluation
|
| 528 |
+
if args.model == 'ctm':
|
| 529 |
+
these_predictions, certainties, _ = model(inputs)
|
| 530 |
+
loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
|
| 531 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
|
| 532 |
+
all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
|
| 533 |
+
|
| 534 |
+
elif args.model == 'lstm':
|
| 535 |
+
these_predictions, certainties, _ = model(inputs)
|
| 536 |
+
loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
|
| 537 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
|
| 538 |
+
all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
|
| 539 |
+
|
| 540 |
+
elif args.model == 'ff':
|
| 541 |
+
these_predictions = model(inputs)
|
| 542 |
+
loss = nn.CrossEntropyLoss()(these_predictions, targets)
|
| 543 |
+
all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
|
| 544 |
+
|
| 545 |
+
all_losses.append(loss.item())
|
| 546 |
+
|
| 547 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 548 |
+
pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
|
| 549 |
+
pbar_inner.update(1)
|
| 550 |
+
|
| 551 |
+
all_targets = np.concatenate(all_targets_list)
|
| 552 |
+
all_predictions = np.concatenate(all_predictions_list)
|
| 553 |
+
test_losses.append(np.mean(all_losses))
|
| 554 |
+
|
| 555 |
+
if args.model in ['ctm', 'lstm']:
|
| 556 |
+
current_test_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0)
|
| 557 |
+
test_accuracies.append(current_test_accuracies)
|
| 558 |
+
all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
|
| 559 |
+
current_test_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
|
| 560 |
+
test_accuracies_most_certain.append(current_test_accuracies_most_certain)
|
| 561 |
+
else: # FF
|
| 562 |
+
current_test_accuracies = (all_targets == all_predictions).mean()
|
| 563 |
+
test_accuracies.append(current_test_accuracies)
|
| 564 |
+
|
| 565 |
+
# Plotting (conditional)
|
| 566 |
+
figacc = plt.figure(figsize=(10, 10))
|
| 567 |
+
axacc_train = figacc.add_subplot(211)
|
| 568 |
+
axacc_test = figacc.add_subplot(212)
|
| 569 |
+
cm = sns.color_palette("viridis", as_cmap=True)
|
| 570 |
+
|
| 571 |
+
if args.model in ['ctm', 'lstm']:
|
| 572 |
+
# Plot per-tick accuracy for CTM/LSTM
|
| 573 |
+
train_acc_arr = np.array(train_accuracies) # Shape (N_iters, T)
|
| 574 |
+
test_acc_arr = np.array(test_accuracies) # Shape (N_iters, T)
|
| 575 |
+
num_ticks = train_acc_arr.shape[1]
|
| 576 |
+
for ti in range(num_ticks):
|
| 577 |
+
axacc_train.plot(iters, train_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
|
| 578 |
+
axacc_test.plot(iters, test_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
|
| 579 |
+
# Plot most certain accuracy
|
| 580 |
+
axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
|
| 581 |
+
axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
|
| 582 |
+
else: # FF
|
| 583 |
+
axacc_train.plot(iters, train_accuracies, 'k-', alpha=0.7, label='Accuracy') # Simple line
|
| 584 |
+
axacc_test.plot(iters, test_accuracies, 'k-', alpha=0.7, label='Accuracy')
|
| 585 |
+
|
| 586 |
+
axacc_train.set_title('Train Accuracy')
|
| 587 |
+
axacc_test.set_title('Test Accuracy')
|
| 588 |
+
axacc_train.legend(loc='lower right')
|
| 589 |
+
axacc_test.legend(loc='lower right')
|
| 590 |
+
axacc_train.set_xlim([0, args.training_iterations])
|
| 591 |
+
axacc_test.set_xlim([0, args.training_iterations])
|
| 592 |
+
if args.dataset=='cifar10':
|
| 593 |
+
axacc_train.set_ylim([0.75, 1])
|
| 594 |
+
axacc_test.set_ylim([0.75, 1])
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
figacc.tight_layout()
|
| 599 |
+
figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
|
| 600 |
+
plt.close(figacc)
|
| 601 |
+
|
| 602 |
+
figloss = plt.figure(figsize=(10, 5))
|
| 603 |
+
axloss = figloss.add_subplot(111)
|
| 604 |
+
axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
|
| 605 |
+
axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
|
| 606 |
+
axloss.legend(loc='upper right')
|
| 607 |
+
axloss.set_xlim([0, args.training_iterations])
|
| 608 |
+
axloss.set_ylim(bottom=0)
|
| 609 |
+
|
| 610 |
+
figloss.tight_layout()
|
| 611 |
+
figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
|
| 612 |
+
plt.close(figloss)
|
| 613 |
+
|
| 614 |
+
# Conditional Visualization (Only for CTM/LSTM)
|
| 615 |
+
if args.model in ['ctm', 'lstm']:
|
| 616 |
+
try: # For safety
|
| 617 |
+
inputs_viz, targets_viz = next(iter(testloader)) # Get a fresh batch
|
| 618 |
+
inputs_viz = inputs_viz.to(device)
|
| 619 |
+
targets_viz = targets_viz.to(device)
|
| 620 |
+
|
| 621 |
+
pbar.set_description('Tracking: Processing test data for viz')
|
| 622 |
+
predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
|
| 623 |
+
|
| 624 |
+
att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
|
| 625 |
+
attention_tracking_viz = attention_tracking_viz.reshape(
|
| 626 |
+
attention_tracking_viz.shape[0],
|
| 627 |
+
attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
|
| 628 |
+
|
| 629 |
+
pbar.set_description('Tracking: Neural dynamics plot')
|
| 630 |
+
plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
|
| 631 |
+
|
| 632 |
+
imgi = 0 # Visualize the first image in the batch
|
| 633 |
+
img_to_gif = np.moveaxis(np.clip(inputs_viz[imgi].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
|
| 634 |
+
|
| 635 |
+
pbar.set_description('Tracking: Producing attention gif')
|
| 636 |
+
make_classification_gif(img_to_gif,
|
| 637 |
+
targets_viz[imgi].item(),
|
| 638 |
+
predictions_viz[imgi].detach().cpu().numpy(),
|
| 639 |
+
certainties_viz[imgi].detach().cpu().numpy(),
|
| 640 |
+
post_activations_viz[:,imgi],
|
| 641 |
+
attention_tracking_viz[:,imgi],
|
| 642 |
+
class_labels,
|
| 643 |
+
f'{args.log_dir}/{imgi}_attention.gif',
|
| 644 |
+
)
|
| 645 |
+
del predictions_viz, certainties_viz, pre_activations_viz, post_activations_viz, attention_tracking_viz
|
| 646 |
+
except Exception as e:
|
| 647 |
+
print(f"Visualization failed for model {args.model}: {e}")
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
|
| 651 |
+
gc.collect()
|
| 652 |
+
if torch.cuda.is_available():
|
| 653 |
+
torch.cuda.empty_cache()
|
| 654 |
+
model.train() # Switch back to train mode
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
# Save model checkpoint (conditional metrics)
|
| 658 |
+
if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
|
| 659 |
+
pbar.set_description('Saving model checkpoint...')
|
| 660 |
+
checkpoint_data = {
|
| 661 |
+
'model_state_dict': model.state_dict(),
|
| 662 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 663 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 664 |
+
'scaler_state_dict': scaler.state_dict(),
|
| 665 |
+
'iteration': bi,
|
| 666 |
+
# Always save these
|
| 667 |
+
'train_losses': train_losses,
|
| 668 |
+
'test_losses': test_losses,
|
| 669 |
+
'train_accuracies': train_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
|
| 670 |
+
'test_accuracies': test_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
|
| 671 |
+
'iters': iters,
|
| 672 |
+
'args': args, # Save args used for this run
|
| 673 |
+
# RNG states
|
| 674 |
+
'torch_rng_state': torch.get_rng_state(),
|
| 675 |
+
'numpy_rng_state': np.random.get_state(),
|
| 676 |
+
'random_rng_state': random.getstate(),
|
| 677 |
+
}
|
| 678 |
+
# Conditionally add metrics specific to CTM/LSTM
|
| 679 |
+
if args.model in ['ctm', 'lstm']:
|
| 680 |
+
checkpoint_data['train_accuracies_most_certain'] = train_accuracies_most_certain
|
| 681 |
+
checkpoint_data['test_accuracies_most_certain'] = test_accuracies_most_certain
|
| 682 |
+
|
| 683 |
+
torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
|
| 684 |
+
|
| 685 |
+
pbar.update(1)
|
tasks/image_classification/train_distributed.py
ADDED
|
@@ -0,0 +1,799 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
sns.set_style('darkgrid')
|
| 10 |
+
import torch
|
| 11 |
+
if torch.cuda.is_available():
|
| 12 |
+
# For faster
|
| 13 |
+
torch.set_float32_matmul_precision('high')
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.distributed as dist
|
| 16 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 17 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 18 |
+
from utils.samplers import FastRandomDistributedSampler
|
| 19 |
+
from tqdm.auto import tqdm
|
| 20 |
+
|
| 21 |
+
from tasks.image_classification.train import get_dataset # Use shared get_dataset
|
| 22 |
+
|
| 23 |
+
# Model Imports
|
| 24 |
+
from models.ctm import ContinuousThoughtMachine
|
| 25 |
+
from models.lstm import LSTMBaseline
|
| 26 |
+
from models.ff import FFBaseline
|
| 27 |
+
|
| 28 |
+
# Plotting/Utils Imports
|
| 29 |
+
from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
|
| 30 |
+
from utils.housekeeping import set_seed, zip_python_code
|
| 31 |
+
from utils.losses import image_classification_loss # For CTM, LSTM
|
| 32 |
+
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
|
| 33 |
+
|
| 34 |
+
import torchvision
|
| 35 |
+
torchvision.disable_beta_transforms_warning()
|
| 36 |
+
|
| 37 |
+
import warnings
|
| 38 |
+
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
|
| 39 |
+
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
|
| 40 |
+
warnings.filterwarnings("ignore", message="UserWarning: Metadata Warning, tag 274 had too many entries: 4, expected 1")
|
| 41 |
+
warnings.filterwarnings(
|
| 42 |
+
"ignore",
|
| 43 |
+
"Corrupt EXIF data",
|
| 44 |
+
UserWarning,
|
| 45 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 46 |
+
)
|
| 47 |
+
warnings.filterwarnings(
|
| 48 |
+
"ignore",
|
| 49 |
+
"UserWarning: Metadata Warning",
|
| 50 |
+
UserWarning,
|
| 51 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 52 |
+
)
|
| 53 |
+
warnings.filterwarnings(
|
| 54 |
+
"ignore",
|
| 55 |
+
"UserWarning: Truncated File Read",
|
| 56 |
+
UserWarning,
|
| 57 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def parse_args():
|
| 62 |
+
parser = argparse.ArgumentParser()
|
| 63 |
+
|
| 64 |
+
# Model Selection
|
| 65 |
+
parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
|
| 66 |
+
|
| 67 |
+
# Model Architecture
|
| 68 |
+
# Common
|
| 69 |
+
parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
|
| 70 |
+
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
|
| 71 |
+
parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featureiser.')
|
| 72 |
+
# CTM / LSTM specific
|
| 73 |
+
parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
|
| 74 |
+
parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
|
| 75 |
+
parser.add_argument('--iterations', type=int, default=50, help='Number of internal ticks (CTM, LSTM).')
|
| 76 |
+
parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
|
| 77 |
+
choices=['none',
|
| 78 |
+
'learnable-fourier',
|
| 79 |
+
'multi-learnable-fourier',
|
| 80 |
+
'custom-rotational'])
|
| 81 |
+
# CTM specific
|
| 82 |
+
parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
|
| 83 |
+
parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).')
|
| 84 |
+
parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).')
|
| 85 |
+
parser.add_argument('--neuron_select_type', type=str, default='first-last', help='Protocol for selecting neuron subset (CTM only).')
|
| 86 |
+
parser.add_argument('--n_random_pairing_self', type=int, default=256, help='Number of neurons paired self-to-self for synch (CTM only).')
|
| 87 |
+
parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
|
| 88 |
+
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
|
| 89 |
+
parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
|
| 90 |
+
parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
|
| 91 |
+
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
|
| 92 |
+
# LSTM specific
|
| 93 |
+
parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
|
| 94 |
+
|
| 95 |
+
# Training
|
| 96 |
+
parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training (per GPU).')
|
| 97 |
+
parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing (per GPU).')
|
| 98 |
+
parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
|
| 99 |
+
parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
|
| 100 |
+
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
|
| 101 |
+
parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
|
| 102 |
+
parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
|
| 103 |
+
parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
|
| 104 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
|
| 105 |
+
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
|
| 106 |
+
parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
|
| 107 |
+
parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
|
| 108 |
+
parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')
|
| 109 |
+
parser.add_argument('--use_custom_sampler', action=argparse.BooleanOptionalAction, default=False, help='Use custom fast sampler to avoid reshuffling.')
|
| 110 |
+
parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
|
| 111 |
+
|
| 112 |
+
# Housekeeping
|
| 113 |
+
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 114 |
+
parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
|
| 115 |
+
parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
|
| 116 |
+
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
| 117 |
+
parser.add_argument('--seed', type=int, default=412, help='Random seed.')
|
| 118 |
+
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
|
| 119 |
+
parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
|
| 120 |
+
parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.')
|
| 121 |
+
parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading?')
|
| 122 |
+
|
| 123 |
+
# Tracking
|
| 124 |
+
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
|
| 125 |
+
parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
|
| 126 |
+
parser.add_argument('--plot_indices', type=int, default=[0], nargs='+', help='Which indices in test data to plot?') # Defaulted to 0
|
| 127 |
+
|
| 128 |
+
# Precision
|
| 129 |
+
parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
|
| 130 |
+
args = parser.parse_args()
|
| 131 |
+
return args
|
| 132 |
+
|
| 133 |
+
# --- DDP Setup Functions ---
|
| 134 |
+
def setup_ddp():
|
| 135 |
+
if 'RANK' not in os.environ:
|
| 136 |
+
# Basic setup for non-distributed run
|
| 137 |
+
os.environ['RANK'] = '0'
|
| 138 |
+
os.environ['WORLD_SIZE'] = '1'
|
| 139 |
+
os.environ['MASTER_ADDR'] = 'localhost'
|
| 140 |
+
os.environ['MASTER_PORT'] = '12355' # Ensure this port is free
|
| 141 |
+
os.environ['LOCAL_RANK'] = '0'
|
| 142 |
+
print("Running in non-distributed mode (simulated DDP setup).")
|
| 143 |
+
# Need to manually init if only 1 process desired for non-GPU testing
|
| 144 |
+
if not torch.cuda.is_available() or int(os.environ['WORLD_SIZE']) == 1:
|
| 145 |
+
dist.init_process_group(backend='gloo') # Gloo backend for CPU
|
| 146 |
+
print("Initialized process group with Gloo backend for single/CPU process.")
|
| 147 |
+
rank = int(os.environ['RANK'])
|
| 148 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 149 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 150 |
+
return rank, world_size, local_rank
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# Standard DDP setup
|
| 154 |
+
dist.init_process_group(backend='nccl') # 'nccl' for NVIDIA GPUs
|
| 155 |
+
rank = int(os.environ['RANK'])
|
| 156 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 157 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 158 |
+
if torch.cuda.is_available():
|
| 159 |
+
torch.cuda.set_device(local_rank)
|
| 160 |
+
print(f"Rank {rank} setup on GPU {local_rank}")
|
| 161 |
+
else:
|
| 162 |
+
print(f"Rank {rank} setup on CPU (GPU not available or requested)")
|
| 163 |
+
return rank, world_size, local_rank
|
| 164 |
+
|
| 165 |
+
def cleanup_ddp():
|
| 166 |
+
if dist.is_initialized():
|
| 167 |
+
dist.destroy_process_group()
|
| 168 |
+
print("DDP cleanup complete.")
|
| 169 |
+
|
| 170 |
+
def is_main_process(rank):
|
| 171 |
+
return rank == 0
|
| 172 |
+
# --- End DDP Setup ---
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
if __name__=='__main__':
|
| 176 |
+
|
| 177 |
+
args = parse_args()
|
| 178 |
+
|
| 179 |
+
rank, world_size, local_rank = setup_ddp()
|
| 180 |
+
|
| 181 |
+
set_seed(args.seed + rank, False) # Add rank for different seeds per process
|
| 182 |
+
|
| 183 |
+
# Rank 0 handles directory creation and initial logging
|
| 184 |
+
if is_main_process(rank):
|
| 185 |
+
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 186 |
+
zip_python_code(f'{args.log_dir}/repo_state.zip')
|
| 187 |
+
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 188 |
+
print(args, file=f)
|
| 189 |
+
if world_size > 1: dist.barrier() # Sync after rank 0 setup
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
|
| 193 |
+
|
| 194 |
+
# Data Loading
|
| 195 |
+
train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
|
| 196 |
+
|
| 197 |
+
# Setup Samplers
|
| 198 |
+
# This custom sampler is useful when using large batch sizes for Cifar. Otherwise the reshuffle happens tediously often
|
| 199 |
+
train_sampler = (FastRandomDistributedSampler(train_data, num_replicas=world_size, rank=rank, seed=args.seed, epoch_steps=int(10e10))
|
| 200 |
+
if args.use_custom_sampler else
|
| 201 |
+
DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed))
|
| 202 |
+
test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed) # No shuffle needed for test; consistent
|
| 203 |
+
|
| 204 |
+
# Setup DataLoaders
|
| 205 |
+
trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler,
|
| 206 |
+
num_workers=args.num_workers_train, pin_memory=True, drop_last=True) # drop_last=True often used in DDP
|
| 207 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, sampler=test_sampler,
|
| 208 |
+
num_workers=1, pin_memory=True, drop_last=False)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
prediction_reshaper = [-1] # Task specific
|
| 212 |
+
args.out_dims = len(class_labels)
|
| 213 |
+
|
| 214 |
+
# Setup Device
|
| 215 |
+
if torch.cuda.is_available():
|
| 216 |
+
device = torch.device(f'cuda:{local_rank}')
|
| 217 |
+
else:
|
| 218 |
+
device = torch.device('cpu')
|
| 219 |
+
if world_size > 1:
|
| 220 |
+
warnings.warn("Running DDP on CPU is not recommended.")
|
| 221 |
+
if is_main_process(rank):
|
| 222 |
+
print(f'Main process (Rank {rank}): Using device {device}. World size: {world_size}. Model: {args.model}')
|
| 223 |
+
|
| 224 |
+
# --- Model Definition (Conditional) ---
|
| 225 |
+
model_base = None # Base model before DDP wrapping
|
| 226 |
+
if args.model == 'ctm':
|
| 227 |
+
model_base = ContinuousThoughtMachine(
|
| 228 |
+
iterations=args.iterations,
|
| 229 |
+
d_model=args.d_model,
|
| 230 |
+
d_input=args.d_input,
|
| 231 |
+
heads=args.heads,
|
| 232 |
+
n_synch_out=args.n_synch_out,
|
| 233 |
+
n_synch_action=args.n_synch_action,
|
| 234 |
+
synapse_depth=args.synapse_depth,
|
| 235 |
+
memory_length=args.memory_length,
|
| 236 |
+
deep_nlms=args.deep_memory,
|
| 237 |
+
memory_hidden_dims=args.memory_hidden_dims,
|
| 238 |
+
do_layernorm_nlm=args.do_normalisation,
|
| 239 |
+
backbone_type=args.backbone_type,
|
| 240 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 241 |
+
out_dims=args.out_dims,
|
| 242 |
+
prediction_reshaper=prediction_reshaper,
|
| 243 |
+
dropout=args.dropout,
|
| 244 |
+
dropout_nlm=args.dropout_nlm,
|
| 245 |
+
neuron_select_type=args.neuron_select_type,
|
| 246 |
+
n_random_pairing_self=args.n_random_pairing_self,
|
| 247 |
+
).to(device)
|
| 248 |
+
elif args.model == 'lstm':
|
| 249 |
+
model_base = LSTMBaseline(
|
| 250 |
+
num_layers=args.num_layers,
|
| 251 |
+
iterations=args.iterations,
|
| 252 |
+
d_model=args.d_model,
|
| 253 |
+
d_input=args.d_input,
|
| 254 |
+
heads=args.heads,
|
| 255 |
+
backbone_type=args.backbone_type,
|
| 256 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 257 |
+
out_dims=args.out_dims,
|
| 258 |
+
prediction_reshaper=prediction_reshaper,
|
| 259 |
+
dropout=args.dropout,
|
| 260 |
+
start_type=args.start_type,
|
| 261 |
+
).to(device)
|
| 262 |
+
elif args.model == 'ff':
|
| 263 |
+
model_base = FFBaseline(
|
| 264 |
+
d_model=args.d_model,
|
| 265 |
+
backbone_type=args.backbone_type,
|
| 266 |
+
out_dims=args.out_dims,
|
| 267 |
+
dropout=args.dropout,
|
| 268 |
+
).to(device)
|
| 269 |
+
else:
|
| 270 |
+
raise ValueError(f"Unknown model type: {args.model}")
|
| 271 |
+
|
| 272 |
+
# Initialize lazy modules if any
|
| 273 |
+
try:
|
| 274 |
+
pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
|
| 275 |
+
model_base(pseudo_inputs)
|
| 276 |
+
except Exception as e:
|
| 277 |
+
print(f"Warning: Pseudo forward pass failed: {e}")
|
| 278 |
+
|
| 279 |
+
# Wrap model with DDP
|
| 280 |
+
if device.type == 'cuda' and world_size > 1:
|
| 281 |
+
model = DDP(model_base, device_ids=[local_rank], output_device=local_rank)
|
| 282 |
+
elif device.type == 'cpu' and world_size > 1:
|
| 283 |
+
model = DDP(model_base) # No device_ids for CPU
|
| 284 |
+
else: # Single process run
|
| 285 |
+
model = model_base # No DDP wrapping needed
|
| 286 |
+
|
| 287 |
+
if is_main_process(rank):
|
| 288 |
+
# Access underlying model for param count
|
| 289 |
+
param_count = sum(p.numel() for p in model.module.parameters() if p.requires_grad) if world_size > 1 else sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 290 |
+
print(f'Total trainable params: {param_count}')
|
| 291 |
+
# --- End Model Definition ---
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# Optimizer and scheduler
|
| 295 |
+
# Use model.parameters() directly, DDP handles it
|
| 296 |
+
decay_params = []
|
| 297 |
+
no_decay_params = []
|
| 298 |
+
no_decay_names = []
|
| 299 |
+
for name, param in model.named_parameters():
|
| 300 |
+
if not param.requires_grad:
|
| 301 |
+
continue # Skip parameters that don't require gradients
|
| 302 |
+
if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
|
| 303 |
+
no_decay_params.append(param)
|
| 304 |
+
no_decay_names.append(name)
|
| 305 |
+
else:
|
| 306 |
+
decay_params.append(param)
|
| 307 |
+
if len(no_decay_names) and is_main_process(rank):
|
| 308 |
+
print(f'WARNING, excluding: {no_decay_names}')
|
| 309 |
+
|
| 310 |
+
# Optimizer and scheduler (Common setup)
|
| 311 |
+
if len(no_decay_names) and args.weight_decay!=0:
|
| 312 |
+
optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
|
| 313 |
+
{'params': no_decay_params, 'weight_decay':0}],
|
| 314 |
+
lr=args.lr,
|
| 315 |
+
eps=1e-8 if not args.use_amp else 1e-6)
|
| 316 |
+
else:
|
| 317 |
+
optimizer = torch.optim.AdamW(model.parameters(),
|
| 318 |
+
lr=args.lr,
|
| 319 |
+
eps=1e-8 if not args.use_amp else 1e-6,
|
| 320 |
+
weight_decay=args.weight_decay)
|
| 321 |
+
|
| 322 |
+
warmup_schedule = warmup(args.warmup_steps)
|
| 323 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
|
| 324 |
+
if args.use_scheduler:
|
| 325 |
+
if args.scheduler_type == 'multistep':
|
| 326 |
+
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
|
| 327 |
+
elif args.scheduler_type == 'cosine':
|
| 328 |
+
scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
|
| 329 |
+
else:
|
| 330 |
+
raise NotImplementedError
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# Metrics tracking (on Rank 0)
|
| 334 |
+
start_iter = 0
|
| 335 |
+
train_losses = []
|
| 336 |
+
test_losses = []
|
| 337 |
+
train_accuracies = [] # Placeholder for potential detailed accuracy
|
| 338 |
+
test_accuracies = [] # Placeholder for potential detailed accuracy
|
| 339 |
+
# Conditional metrics
|
| 340 |
+
train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None # Scalar accuracy list
|
| 341 |
+
test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None # Scalar accuracy list
|
| 342 |
+
train_accuracies_standard = [] if args.model == 'ff' else None # Standard accuracy list for FF
|
| 343 |
+
test_accuracies_standard = [] if args.model == 'ff' else None # Standard accuracy list for FF
|
| 344 |
+
iters = []
|
| 345 |
+
|
| 346 |
+
scaler = torch.amp.GradScaler("cuda" if device.type == 'cuda' else "cpu", enabled=args.use_amp)
|
| 347 |
+
# Reloading Logic
|
| 348 |
+
if args.reload:
|
| 349 |
+
map_location = device # Load directly onto the process's device
|
| 350 |
+
chkpt_path = f'{args.log_dir}/checkpoint.pt'
|
| 351 |
+
if os.path.isfile(chkpt_path):
|
| 352 |
+
print(f'Rank {rank}: Reloading from: {chkpt_path}')
|
| 353 |
+
checkpoint = torch.load(chkpt_path, map_location=map_location, weights_only=False)
|
| 354 |
+
|
| 355 |
+
# Determine underlying model based on whether DDP wrapping occurred
|
| 356 |
+
model_to_load = model.module if isinstance(model, DDP) else model
|
| 357 |
+
|
| 358 |
+
# Handle potential 'module.' prefix in saved state_dict
|
| 359 |
+
state_dict = checkpoint['model_state_dict']
|
| 360 |
+
has_module_prefix = all(k.startswith('module.') for k in state_dict)
|
| 361 |
+
is_wrapped = isinstance(model, DDP)
|
| 362 |
+
|
| 363 |
+
if has_module_prefix and not is_wrapped:
|
| 364 |
+
# Saved with DDP, loading into non-DDP model -> remove prefix
|
| 365 |
+
state_dict = {k.partition('module.')[2]: v for k,v in state_dict.items()}
|
| 366 |
+
elif not has_module_prefix and is_wrapped:
|
| 367 |
+
load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
|
| 368 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 369 |
+
state_dict = None # Prevent loading again
|
| 370 |
+
|
| 371 |
+
if state_dict is not None:
|
| 372 |
+
load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
|
| 373 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
if not args.reload_model_only:
|
| 377 |
+
print(f'Rank {rank}: Reloading optimizer, scheduler, scaler, iteration.')
|
| 378 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 379 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 380 |
+
scaler_state_dict = checkpoint['scaler_state_dict']
|
| 381 |
+
if scaler.is_enabled():
|
| 382 |
+
print("Loading non-empty GradScaler state dict.")
|
| 383 |
+
try:
|
| 384 |
+
scaler.load_state_dict(scaler_state_dict)
|
| 385 |
+
except Exception as e:
|
| 386 |
+
print(f"Error loading GradScaler state dict: {e}")
|
| 387 |
+
print("Continuing with a fresh GradScaler state.")
|
| 388 |
+
|
| 389 |
+
start_iter = checkpoint['iteration']
|
| 390 |
+
# Only rank 0 loads metric history
|
| 391 |
+
if is_main_process(rank) and not args.ignore_metrics_when_reloading:
|
| 392 |
+
print(f'Rank {rank}: Reloading metrics history.')
|
| 393 |
+
iters = checkpoint['iters']
|
| 394 |
+
train_losses = checkpoint['train_losses']
|
| 395 |
+
test_losses = checkpoint['test_losses']
|
| 396 |
+
train_accuracies = checkpoint['train_accuracies']
|
| 397 |
+
test_accuracies = checkpoint['test_accuracies']
|
| 398 |
+
if args.model in ['ctm', 'lstm']:
|
| 399 |
+
train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
|
| 400 |
+
test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
|
| 401 |
+
elif args.model == 'ff':
|
| 402 |
+
train_accuracies_standard = checkpoint['train_accuracies_standard']
|
| 403 |
+
test_accuracies_standard = checkpoint['test_accuracies_standard']
|
| 404 |
+
elif is_main_process(rank) and args.ignore_metrics_when_reloading:
|
| 405 |
+
print(f'Rank {rank}: Ignoring metrics history upon reload.')
|
| 406 |
+
|
| 407 |
+
else:
|
| 408 |
+
print(f'Rank {rank}: Only reloading model weights!')
|
| 409 |
+
|
| 410 |
+
# Load RNG states
|
| 411 |
+
if is_main_process(rank) and 'torch_rng_state' in checkpoint and not args.reload_model_only:
|
| 412 |
+
print(f'Rank {rank}: Loading RNG states (may need DDP adaptation for full reproducibility).')
|
| 413 |
+
torch.set_rng_state(checkpoint['torch_rng_state'].cpu()) # Load CPU state
|
| 414 |
+
# Add CUDA state loading if needed, ensuring correct device handling
|
| 415 |
+
np.random.set_state(checkpoint['numpy_rng_state'])
|
| 416 |
+
random.setstate(checkpoint['random_rng_state'])
|
| 417 |
+
|
| 418 |
+
del checkpoint
|
| 419 |
+
if torch.cuda.is_available(): torch.cuda.empty_cache()
|
| 420 |
+
print(f"Rank {rank}: Reload finished, starting from iteration {start_iter}")
|
| 421 |
+
else:
|
| 422 |
+
print(f"Rank {rank}: Checkpoint not found at {chkpt_path}, starting from scratch.")
|
| 423 |
+
if world_size > 1: dist.barrier() # Sync after loading
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
# Conditional Compilation
|
| 427 |
+
if args.do_compile:
|
| 428 |
+
if is_main_process(rank): print('Compiling model components...')
|
| 429 |
+
# Compile on the underlying model if wrapped
|
| 430 |
+
model_to_compile = model.module if isinstance(model, DDP) else model
|
| 431 |
+
if hasattr(model_to_compile, 'backbone'):
|
| 432 |
+
model_to_compile.backbone = torch.compile(model_to_compile.backbone, mode='reduce-overhead', fullgraph=True)
|
| 433 |
+
if args.model == 'ctm':
|
| 434 |
+
if hasattr(model_to_compile, 'synapses'):
|
| 435 |
+
model_to_compile.synapses = torch.compile(model_to_compile.synapses, mode='reduce-overhead', fullgraph=True)
|
| 436 |
+
if world_size > 1: dist.barrier() # Sync after compilation
|
| 437 |
+
if is_main_process(rank): print('Compilation finished.')
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# --- Training Loop ---
|
| 441 |
+
model.train() # Ensure model is in train mode
|
| 442 |
+
pbar = tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True, disable=not is_main_process(rank))
|
| 443 |
+
|
| 444 |
+
iterator = iter(trainloader)
|
| 445 |
+
|
| 446 |
+
for bi in range(start_iter, args.training_iterations):
|
| 447 |
+
|
| 448 |
+
# Set sampler epoch (important for shuffling in DistributedSampler)
|
| 449 |
+
if not args.use_custom_sampler and hasattr(train_sampler, 'set_epoch'):
|
| 450 |
+
train_sampler.set_epoch(bi)
|
| 451 |
+
|
| 452 |
+
current_lr = optimizer.param_groups[-1]['lr']
|
| 453 |
+
|
| 454 |
+
time_start_data = time.time()
|
| 455 |
+
try:
|
| 456 |
+
inputs, targets = next(iterator)
|
| 457 |
+
except StopIteration:
|
| 458 |
+
# Reset iterator - set_epoch handles shuffling if needed
|
| 459 |
+
iterator = iter(trainloader)
|
| 460 |
+
inputs, targets = next(iterator)
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 464 |
+
targets = targets.to(device, non_blocking=True)
|
| 465 |
+
time_end_data = time.time()
|
| 466 |
+
|
| 467 |
+
loss = None
|
| 468 |
+
# Model-specific forward and loss calculation
|
| 469 |
+
time_start_forward = time.time()
|
| 470 |
+
with torch.autocast(device_type="cuda" if device.type == 'cuda' else "cpu", dtype=torch.float16, enabled=args.use_amp):
|
| 471 |
+
if args.do_compile:
|
| 472 |
+
torch.compiler.cudagraph_mark_step_begin()
|
| 473 |
+
|
| 474 |
+
if args.model == 'ctm':
|
| 475 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 476 |
+
loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 477 |
+
elif args.model == 'lstm':
|
| 478 |
+
predictions, certainties, synchronisation = model(inputs)
|
| 479 |
+
loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 480 |
+
elif args.model == 'ff':
|
| 481 |
+
predictions = model(inputs) # FF returns only predictions
|
| 482 |
+
loss = nn.CrossEntropyLoss()(predictions, targets)
|
| 483 |
+
where_most_certain = None # Not applicable for FF standard loss
|
| 484 |
+
time_end_forward = time.time()
|
| 485 |
+
time_start_backward = time.time()
|
| 486 |
+
|
| 487 |
+
scaler.scale(loss).backward() # DDP handles gradient synchronization
|
| 488 |
+
time_end_backward = time.time()
|
| 489 |
+
|
| 490 |
+
if args.gradient_clipping!=-1:
|
| 491 |
+
scaler.unscale_(optimizer)
|
| 492 |
+
# Clip gradients across all parameters controlled by the optimizer
|
| 493 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
|
| 494 |
+
|
| 495 |
+
scaler.step(optimizer)
|
| 496 |
+
scaler.update()
|
| 497 |
+
optimizer.zero_grad(set_to_none=True)
|
| 498 |
+
scheduler.step()
|
| 499 |
+
|
| 500 |
+
# --- Aggregation and Logging (Rank 0) ---
|
| 501 |
+
# Aggregate loss for logging
|
| 502 |
+
loss_log = loss.detach() # Use detached loss for aggregation
|
| 503 |
+
if world_size > 1: dist.all_reduce(loss_log, op=dist.ReduceOp.AVG)
|
| 504 |
+
|
| 505 |
+
if is_main_process(rank):
|
| 506 |
+
# Calculate accuracy locally on rank 0 for description (approximate)
|
| 507 |
+
# Note: This uses rank 0's batch, not aggregated accuracy
|
| 508 |
+
accuracy_local = 0.0
|
| 509 |
+
if args.model in ['ctm', 'lstm']:
|
| 510 |
+
accuracy_local = (predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain] == targets).float().mean().item()
|
| 511 |
+
where_certain_tensor = where_most_certain.float() # Use rank 0's tensor for stats
|
| 512 |
+
pbar_desc = f'Timing; d={(time_end_data-time_start_data):0.3f}, f={(time_end_forward-time_start_forward):0.3f}, b={(time_end_backward-time_start_backward):0.3f}. Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_local:.3f} LR={current_lr:.6f} WhereCert(loc)={where_certain_tensor.mean().item():.2f}'
|
| 513 |
+
elif args.model == 'ff':
|
| 514 |
+
accuracy_local = (predictions.argmax(1) == targets).float().mean().item()
|
| 515 |
+
pbar_desc = f'Timing; d={(time_end_data-time_start_data):0.3f}, f={(time_end_forward-time_start_forward):0.3f}, b={(time_end_backward-time_start_backward):0.3f}. Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_local:.3f} LR={current_lr:.6f}'
|
| 516 |
+
|
| 517 |
+
pbar.set_description(f'{args.model.upper()} {pbar_desc}')
|
| 518 |
+
# --- End Aggregation and Logging ---
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
# --- Evaluation and Plotting (Rank 0 + Aggregation) ---
|
| 522 |
+
if bi % args.track_every == 0 and (bi != 0 or args.reload_model_only):
|
| 523 |
+
|
| 524 |
+
model.eval()
|
| 525 |
+
with torch.inference_mode():
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# --- Distributed Evaluation ---
|
| 529 |
+
iters.append(bi)
|
| 530 |
+
|
| 531 |
+
# TRAIN METRICS
|
| 532 |
+
total_train_loss = torch.tensor(0.0, device=device)
|
| 533 |
+
total_train_correct_certain = torch.tensor(0.0, device=device) # CTM/LSTM
|
| 534 |
+
total_train_correct_standard = torch.tensor(0.0, device=device) # FF
|
| 535 |
+
total_train_samples = torch.tensor(0.0, device=device)
|
| 536 |
+
|
| 537 |
+
# Use a sampler for evaluation to ensure non-overlapping data if needed
|
| 538 |
+
train_eval_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=False)
|
| 539 |
+
train_eval_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, sampler=train_eval_sampler, num_workers=1, pin_memory=True)
|
| 540 |
+
|
| 541 |
+
pbar_inner_desc = 'Eval Train (Rank 0)' if is_main_process(rank) else None
|
| 542 |
+
with tqdm(total=len(train_eval_loader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
|
| 543 |
+
for inferi, (inputs, targets) in enumerate(train_eval_loader):
|
| 544 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 545 |
+
targets = targets.to(device, non_blocking=True)
|
| 546 |
+
|
| 547 |
+
loss_eval = None
|
| 548 |
+
if args.model == 'ctm':
|
| 549 |
+
predictions, certainties, _ = model(inputs)
|
| 550 |
+
loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 551 |
+
preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
|
| 552 |
+
total_train_correct_certain += (preds_eval == targets).sum()
|
| 553 |
+
elif args.model == 'lstm':
|
| 554 |
+
predictions, certainties, _ = model(inputs)
|
| 555 |
+
loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 556 |
+
preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
|
| 557 |
+
total_train_correct_certain += (preds_eval == targets).sum()
|
| 558 |
+
elif args.model == 'ff':
|
| 559 |
+
predictions = model(inputs)
|
| 560 |
+
loss_eval = nn.CrossEntropyLoss()(predictions, targets)
|
| 561 |
+
preds_eval = predictions.argmax(1)
|
| 562 |
+
total_train_correct_standard += (preds_eval == targets).sum()
|
| 563 |
+
|
| 564 |
+
total_train_loss += loss_eval * inputs.size(0)
|
| 565 |
+
total_train_samples += inputs.size(0)
|
| 566 |
+
|
| 567 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 568 |
+
pbar_inner.update(1)
|
| 569 |
+
|
| 570 |
+
# Aggregate Train Metrics
|
| 571 |
+
if world_size > 1:
|
| 572 |
+
dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
|
| 573 |
+
dist.all_reduce(total_train_correct_certain, op=dist.ReduceOp.SUM)
|
| 574 |
+
dist.all_reduce(total_train_correct_standard, op=dist.ReduceOp.SUM)
|
| 575 |
+
dist.all_reduce(total_train_samples, op=dist.ReduceOp.SUM)
|
| 576 |
+
|
| 577 |
+
# Calculate final Train metrics on Rank 0
|
| 578 |
+
if is_main_process(rank) and total_train_samples > 0:
|
| 579 |
+
avg_train_loss = total_train_loss.item() / total_train_samples.item()
|
| 580 |
+
train_losses.append(avg_train_loss)
|
| 581 |
+
if args.model in ['ctm', 'lstm']:
|
| 582 |
+
avg_train_acc_certain = total_train_correct_certain.item() / total_train_samples.item()
|
| 583 |
+
train_accuracies_most_certain.append(avg_train_acc_certain)
|
| 584 |
+
elif args.model == 'ff':
|
| 585 |
+
avg_train_acc_standard = total_train_correct_standard.item() / total_train_samples.item()
|
| 586 |
+
train_accuracies_standard.append(avg_train_acc_standard)
|
| 587 |
+
print(f"Iter {bi} Train Metrics (Agg): Loss={avg_train_loss:.4f}")
|
| 588 |
+
|
| 589 |
+
# TEST METRICS
|
| 590 |
+
total_test_loss = torch.tensor(0.0, device=device)
|
| 591 |
+
total_test_correct_certain = torch.tensor(0.0, device=device) # CTM/LSTM
|
| 592 |
+
total_test_correct_standard = torch.tensor(0.0, device=device) # FF
|
| 593 |
+
total_test_samples = torch.tensor(0.0, device=device)
|
| 594 |
+
|
| 595 |
+
pbar_inner_desc = 'Eval Test (Rank 0)' if is_main_process(rank) else None
|
| 596 |
+
with tqdm(total=len(testloader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
|
| 597 |
+
for inferi, (inputs, targets) in enumerate(testloader): # Testloader already uses sampler
|
| 598 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 599 |
+
targets = targets.to(device, non_blocking=True)
|
| 600 |
+
|
| 601 |
+
loss_eval = None
|
| 602 |
+
if args.model == 'ctm':
|
| 603 |
+
predictions, certainties, _ = model(inputs)
|
| 604 |
+
loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 605 |
+
preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
|
| 606 |
+
total_test_correct_certain += (preds_eval == targets).sum()
|
| 607 |
+
elif args.model == 'lstm':
|
| 608 |
+
predictions, certainties, _ = model(inputs)
|
| 609 |
+
loss_eval, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
|
| 610 |
+
preds_eval = predictions.argmax(1)[torch.arange(predictions.size(0), device=device), where_most_certain]
|
| 611 |
+
total_test_correct_certain += (preds_eval == targets).sum()
|
| 612 |
+
elif args.model == 'ff':
|
| 613 |
+
predictions = model(inputs)
|
| 614 |
+
loss_eval = nn.CrossEntropyLoss()(predictions, targets)
|
| 615 |
+
preds_eval = predictions.argmax(1)
|
| 616 |
+
total_test_correct_standard += (preds_eval == targets).sum()
|
| 617 |
+
|
| 618 |
+
total_test_loss += loss_eval * inputs.size(0)
|
| 619 |
+
total_test_samples += inputs.size(0)
|
| 620 |
+
|
| 621 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 622 |
+
pbar_inner.update(1)
|
| 623 |
+
|
| 624 |
+
# Aggregate Test Metrics
|
| 625 |
+
if world_size > 1:
|
| 626 |
+
dist.all_reduce(total_test_loss, op=dist.ReduceOp.SUM)
|
| 627 |
+
dist.all_reduce(total_test_correct_certain, op=dist.ReduceOp.SUM)
|
| 628 |
+
dist.all_reduce(total_test_correct_standard, op=dist.ReduceOp.SUM)
|
| 629 |
+
dist.all_reduce(total_test_samples, op=dist.ReduceOp.SUM)
|
| 630 |
+
|
| 631 |
+
# Calculate and Plot final Test metrics on Rank 0
|
| 632 |
+
if is_main_process(rank) and total_test_samples > 0:
|
| 633 |
+
avg_test_loss = total_test_loss.item() / total_test_samples.item()
|
| 634 |
+
test_losses.append(avg_test_loss)
|
| 635 |
+
acc_label = ''
|
| 636 |
+
acc_val = 0.0
|
| 637 |
+
if args.model in ['ctm', 'lstm']:
|
| 638 |
+
avg_test_acc_certain = total_test_correct_certain.item() / total_test_samples.item()
|
| 639 |
+
test_accuracies_most_certain.append(avg_test_acc_certain)
|
| 640 |
+
acc_label = f'Most certain ({avg_test_acc_certain:.3f})'
|
| 641 |
+
acc_val = avg_test_acc_certain
|
| 642 |
+
elif args.model == 'ff':
|
| 643 |
+
avg_test_acc_standard = total_test_correct_standard.item() / total_test_samples.item()
|
| 644 |
+
test_accuracies_standard.append(avg_test_acc_standard)
|
| 645 |
+
acc_label = f'Standard Acc ({avg_test_acc_standard:.3f})'
|
| 646 |
+
acc_val = avg_test_acc_standard
|
| 647 |
+
print(f"Iter {bi} Test Metrics (Agg): Loss={avg_test_loss:.4f}, Acc={acc_val:.4f}\n")
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
# --- Plotting ---
|
| 651 |
+
figacc = plt.figure(figsize=(10, 10))
|
| 652 |
+
axacc_train = figacc.add_subplot(211)
|
| 653 |
+
axacc_test = figacc.add_subplot(212)
|
| 654 |
+
|
| 655 |
+
if args.model in ['ctm', 'lstm']:
|
| 656 |
+
axacc_train.plot(iters, train_accuracies_most_certain, 'k-', alpha=0.9, label=f'Most certain ({train_accuracies_most_certain[-1]:.3f})')
|
| 657 |
+
axacc_test.plot(iters, test_accuracies_most_certain, 'k-', alpha=0.9, label=acc_label)
|
| 658 |
+
elif args.model == 'ff':
|
| 659 |
+
axacc_train.plot(iters, train_accuracies_standard, 'k-', alpha=0.9, label=f'Standard Acc ({train_accuracies_standard[-1]:.3f})')
|
| 660 |
+
axacc_test.plot(iters, test_accuracies_standard, 'k-', alpha=0.9, label=acc_label)
|
| 661 |
+
|
| 662 |
+
axacc_train.set_title('Train Accuracy (Aggregated)')
|
| 663 |
+
axacc_test.set_title('Test Accuracy (Aggregated)')
|
| 664 |
+
axacc_train.legend(loc='lower right')
|
| 665 |
+
axacc_test.legend(loc='lower right')
|
| 666 |
+
axacc_train.set_xlim([0, args.training_iterations])
|
| 667 |
+
axacc_test.set_xlim([0, args.training_iterations])
|
| 668 |
+
|
| 669 |
+
# Keep dataset specific ylim adjustments if needed
|
| 670 |
+
if args.dataset == 'imagenet':
|
| 671 |
+
# For easy comparison when training
|
| 672 |
+
train_ylim_set = False
|
| 673 |
+
if args.model in ['ctm', 'lstm'] and len(train_accuracies_most_certain)>0 and np.any(np.array(train_accuracies_most_certain)>0.4): train_ylim_set=True; axacc_train.set_ylim([0.4, 1])
|
| 674 |
+
if args.model == 'ff' and len(train_accuracies_standard)>0 and np.any(np.array(train_accuracies_standard)>0.4): train_ylim_set=True; axacc_train.set_ylim([0.4, 1])
|
| 675 |
+
|
| 676 |
+
test_ylim_set = False
|
| 677 |
+
if args.model in ['ctm', 'lstm'] and len(test_accuracies_most_certain)>0 and np.any(np.array(test_accuracies_most_certain)>0.3): test_ylim_set=True; axacc_test.set_ylim([0.3, 0.8])
|
| 678 |
+
if args.model == 'ff' and len(test_accuracies_standard)>0 and np.any(np.array(test_accuracies_standard)>0.3): test_ylim_set=True; axacc_test.set_ylim([0.3, 0.8])
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
figacc.tight_layout()
|
| 682 |
+
figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
|
| 683 |
+
plt.close(figacc)
|
| 684 |
+
|
| 685 |
+
# Loss Plot
|
| 686 |
+
figloss = plt.figure(figsize=(10, 5))
|
| 687 |
+
axloss = figloss.add_subplot(111)
|
| 688 |
+
axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train (Aggregated): {train_losses[-1]:.4f}')
|
| 689 |
+
axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test (Aggregated): {test_losses[-1]:.4f}')
|
| 690 |
+
axloss.legend(loc='upper right')
|
| 691 |
+
axloss.set_xlabel("Iteration")
|
| 692 |
+
axloss.set_ylabel("Loss")
|
| 693 |
+
axloss.set_xlim([0, args.training_iterations])
|
| 694 |
+
axloss.set_ylim(bottom=0)
|
| 695 |
+
figloss.tight_layout()
|
| 696 |
+
figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
|
| 697 |
+
plt.close(figloss)
|
| 698 |
+
# --- End Plotting ---
|
| 699 |
+
|
| 700 |
+
# Visualization on Rank 0
|
| 701 |
+
if is_main_process(rank) and args.model in ['ctm', 'lstm']:
|
| 702 |
+
try:
|
| 703 |
+
model_module = model.module if isinstance(model, DDP) else model # Get underlying model
|
| 704 |
+
# Simplified viz: use first batch from testloader
|
| 705 |
+
inputs_viz, targets_viz = next(iter(testloader))
|
| 706 |
+
inputs_viz = inputs_viz.to(device)
|
| 707 |
+
targets_viz = targets_viz.to(device)
|
| 708 |
+
|
| 709 |
+
pbar.set_description('Tracking (Rank 0): Viz Fwd Pass')
|
| 710 |
+
predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model_module(inputs_viz, track=True)
|
| 711 |
+
|
| 712 |
+
att_shape = (model_module.kv_features.shape[2], model_module.kv_features.shape[3])
|
| 713 |
+
attention_tracking_viz = attention_tracking_viz.reshape(
|
| 714 |
+
attention_tracking_viz.shape[0],
|
| 715 |
+
attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
pbar.set_description('Tracking (Rank 0): Dynamics Plot')
|
| 719 |
+
plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
|
| 720 |
+
|
| 721 |
+
# Plot specific indices from test_data directly
|
| 722 |
+
pbar.set_description('Tracking (Rank 0): GIF Generation')
|
| 723 |
+
for plot_idx in args.plot_indices:
|
| 724 |
+
try:
|
| 725 |
+
if plot_idx < len(test_data):
|
| 726 |
+
inputs_plot, target_plot = test_data.__getitem__(plot_idx)
|
| 727 |
+
inputs_plot = inputs_plot.unsqueeze(0).to(device)
|
| 728 |
+
|
| 729 |
+
preds_plot, certs_plot, _, _, posts_plot, atts_plot = model_module(inputs_plot, track=True)
|
| 730 |
+
atts_plot = atts_plot.reshape(atts_plot.shape[0], atts_plot.shape[1], -1, att_shape[0], att_shape[1])
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
img_gif = np.moveaxis(np.clip(inputs_plot[0].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
|
| 734 |
+
|
| 735 |
+
make_classification_gif(img_gif, target_plot, preds_plot[0].detach().cpu().numpy(), certs_plot[0].detach().cpu().numpy(),
|
| 736 |
+
posts_plot[:,0], atts_plot[:,0] if atts_plot is not None else None, class_labels,
|
| 737 |
+
f'{args.log_dir}/idx{plot_idx}_attention.gif')
|
| 738 |
+
else:
|
| 739 |
+
print(f"Warning: Plot index {plot_idx} out of range for test dataset size {len(test_data)}.")
|
| 740 |
+
except Exception as e_gif:
|
| 741 |
+
print(f"Rank 0 GIF generation failed for index {plot_idx}: {e_gif}")
|
| 742 |
+
|
| 743 |
+
except Exception as e_viz:
|
| 744 |
+
print(f"Rank 0 visualization failed: {e_viz}")
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
if world_size > 1: dist.barrier() # Sync after evaluation block
|
| 749 |
+
model.train() # Set back to train mode
|
| 750 |
+
# --- End Evaluation Block ---
|
| 751 |
+
|
| 752 |
+
|
| 753 |
+
# --- Checkpointing (Rank 0) ---
|
| 754 |
+
if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter and is_main_process(rank):
|
| 755 |
+
pbar.set_description('Rank 0: Saving checkpoint...')
|
| 756 |
+
save_path = f'{args.log_dir}/checkpoint.pt'
|
| 757 |
+
# Access underlying model state dict if DDP is used
|
| 758 |
+
model_state_to_save = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
|
| 759 |
+
|
| 760 |
+
save_dict = {
|
| 761 |
+
'model_state_dict': model_state_to_save,
|
| 762 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 763 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 764 |
+
'scaler_state_dict':scaler.state_dict(),
|
| 765 |
+
'iteration': bi,
|
| 766 |
+
'train_losses': train_losses,
|
| 767 |
+
'test_losses': test_losses,
|
| 768 |
+
'iters': iters,
|
| 769 |
+
'args': args,
|
| 770 |
+
'torch_rng_state': torch.get_rng_state(), # CPU state
|
| 771 |
+
'numpy_rng_state': np.random.get_state(),
|
| 772 |
+
'random_rng_state': random.getstate(),
|
| 773 |
+
# Include conditional metrics
|
| 774 |
+
'train_accuracies': train_accuracies, # Placeholder
|
| 775 |
+
'test_accuracies': test_accuracies, # Placeholder
|
| 776 |
+
}
|
| 777 |
+
if args.model in ['ctm', 'lstm']:
|
| 778 |
+
save_dict['train_accuracies_most_certain'] = train_accuracies_most_certain
|
| 779 |
+
save_dict['test_accuracies_most_certain'] = test_accuracies_most_certain
|
| 780 |
+
elif args.model == 'ff':
|
| 781 |
+
save_dict['train_accuracies_standard'] = train_accuracies_standard
|
| 782 |
+
save_dict['test_accuracies_standard'] = test_accuracies_standard
|
| 783 |
+
|
| 784 |
+
torch.save(save_dict , save_path)
|
| 785 |
+
pbar.set_description(f"Rank 0: Checkpoint saved to {save_path}")
|
| 786 |
+
# --- End Checkpointing ---
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
if world_size > 1: dist.barrier() # Sync before next iteration
|
| 790 |
+
|
| 791 |
+
# Update pbar on Rank 0
|
| 792 |
+
if is_main_process(rank):
|
| 793 |
+
pbar.update(1)
|
| 794 |
+
# --- End Training Loop ---
|
| 795 |
+
|
| 796 |
+
if is_main_process(rank):
|
| 797 |
+
pbar.close()
|
| 798 |
+
|
| 799 |
+
cleanup_ddp() # Cleanup DDP resources
|
tasks/mazes/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Mazes
|
| 2 |
+
|
| 3 |
+
This folder contains code for training and analysing 2D maze solving experiments
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
## Training
|
| 7 |
+
To run the maze training that we used for the paper, run the following command from the parent directory:
|
| 8 |
+
```
|
| 9 |
+
python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50
|
| 10 |
+
```
|
tasks/mazes/analysis/README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Analysis
|
| 2 |
+
|
| 3 |
+
This folder contains analysis code for 2D maze experiments. To build GIFs for imagenet run (from the base directory):
|
| 4 |
+
|
| 5 |
+
To run maze analysis run the following command from the parent directory:
|
| 6 |
+
```
|
| 7 |
+
python -m tasks.mazes.analysis.run --actions viz viz --checkpoint checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
You will need to download the checkpoint from here: https://drive.google.com/file/d/1vGiMaQCxzKVT68SipxDCW0W5n5jjEQnC/view?usp=drive_link . Extract this to the appropriate directory: `checkpoints/mazes/...` . Otherwise, use your own after training.
|
tasks/mazes/analysis/run.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
np.seterr(divide='ignore', invalid='warn') # Keep specific numpy error settings
|
| 4 |
+
import matplotlib as mpl
|
| 5 |
+
mpl.use('Agg') # Use Agg backend for matplotlib (important to set before importing pyplot)
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
sns.set_style('darkgrid') # Keep seaborn style
|
| 9 |
+
import os
|
| 10 |
+
import argparse
|
| 11 |
+
import cv2
|
| 12 |
+
import imageio # Used for saving GIFs in viz
|
| 13 |
+
|
| 14 |
+
# Local imports
|
| 15 |
+
from data.custom_datasets import MazeImageFolder
|
| 16 |
+
from models.ctm import ContinuousThoughtMachine
|
| 17 |
+
from tasks.mazes.plotting import draw_path #
|
| 18 |
+
from tasks.image_classification.plotting import save_frames_to_mp4
|
| 19 |
+
|
| 20 |
+
def has_solved_checker(x_maze, route, valid_only=True, fault_tolerance=1, exclusions=[]):
|
| 21 |
+
"""Checks if a route solves a maze."""
|
| 22 |
+
maze = np.copy(x_maze)
|
| 23 |
+
H, W, _ = maze.shape
|
| 24 |
+
start_coords = np.argwhere((maze == [1, 0, 0]).all(axis=2))
|
| 25 |
+
end_coords = np.argwhere((maze == [0, 1, 0]).all(axis=2))
|
| 26 |
+
|
| 27 |
+
if len(start_coords) == 0:
|
| 28 |
+
return False, (-1, -1), 0 # Cannot start
|
| 29 |
+
|
| 30 |
+
current_pos = tuple(start_coords[0])
|
| 31 |
+
target_pos = tuple(end_coords[0]) if len(end_coords) > 0 else None
|
| 32 |
+
|
| 33 |
+
mistakes_made = 0
|
| 34 |
+
final_pos = current_pos
|
| 35 |
+
path_taken_len = 0
|
| 36 |
+
|
| 37 |
+
for step in route:
|
| 38 |
+
if mistakes_made > fault_tolerance:
|
| 39 |
+
break
|
| 40 |
+
|
| 41 |
+
next_pos_candidate = list(current_pos) # Use a list for mutable coordinate calculation
|
| 42 |
+
if step == 0: next_pos_candidate[0] -= 1
|
| 43 |
+
elif step == 1: next_pos_candidate[0] += 1
|
| 44 |
+
elif step == 2: next_pos_candidate[1] -= 1
|
| 45 |
+
elif step == 3: next_pos_candidate[1] += 1
|
| 46 |
+
elif step == 4: pass # Stay in place
|
| 47 |
+
else: continue # Invalid step action
|
| 48 |
+
next_pos = tuple(next_pos_candidate)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
is_invalid_step = False
|
| 52 |
+
# Check bounds first, then maze content if in bounds
|
| 53 |
+
if not (0 <= next_pos[0] < H and 0 <= next_pos[1] < W):
|
| 54 |
+
is_invalid_step = True
|
| 55 |
+
elif np.all(maze[next_pos] == [0, 0, 0]): # Wall
|
| 56 |
+
is_invalid_step = True
|
| 57 |
+
|
| 58 |
+
if is_invalid_step:
|
| 59 |
+
mistakes_made += 1
|
| 60 |
+
if valid_only:
|
| 61 |
+
continue
|
| 62 |
+
|
| 63 |
+
current_pos = next_pos
|
| 64 |
+
path_taken_len += 1
|
| 65 |
+
|
| 66 |
+
if target_pos and current_pos == target_pos:
|
| 67 |
+
if mistakes_made <= fault_tolerance:
|
| 68 |
+
return True, current_pos, path_taken_len
|
| 69 |
+
|
| 70 |
+
if mistakes_made <= fault_tolerance:
|
| 71 |
+
# Assuming exclusions is a list of tuples (as populated in the 'gen' action)
|
| 72 |
+
if current_pos not in exclusions:
|
| 73 |
+
final_pos = current_pos
|
| 74 |
+
|
| 75 |
+
if target_pos and final_pos == target_pos and mistakes_made <= fault_tolerance: # Added mistakes_made check here
|
| 76 |
+
return True, final_pos, path_taken_len
|
| 77 |
+
return False, final_pos, path_taken_len
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def parse_args():
|
| 81 |
+
"""Parses command-line arguments for maze analysis."""
|
| 82 |
+
parser = argparse.ArgumentParser(description="Analyze Asynchronous Thought Machine on Maze Tasks")
|
| 83 |
+
parser.add_argument('--actions', type=str, nargs='+', default=['gen'], help="Actions: 'viz', 'gen'")
|
| 84 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help="GPU device index or -1 for CPU")
|
| 85 |
+
parser.add_argument('--checkpoint', type=str, default='checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt', help="Path to CTM checkpoint")
|
| 86 |
+
parser.add_argument('--output_dir', type=str, default='tasks/mazes/analysis/outputs', help="Directory for analysis outputs")
|
| 87 |
+
parser.add_argument('--dataset_for_viz', type=str, default='large', help="Dataset for 'viz' action")
|
| 88 |
+
parser.add_argument('--dataset_for_gen', type=str, default='extralarge', help="Dataset for 'gen' action")
|
| 89 |
+
parser.add_argument('--batch_size_test', type=int, default=32, help="Batch size for loading test data for 'viz'")
|
| 90 |
+
parser.add_argument('--max_reapplications', type=int, default=20, help="When testing generalisation to extra large mazes")
|
| 91 |
+
parser.add_argument('--legacy_scaling', action=argparse.BooleanOptionalAction, default=True, help='Legacy checkpoints scale between 0 and 1, new ones can scale -1 to 1.')
|
| 92 |
+
return parser.parse_args()
|
| 93 |
+
|
| 94 |
+
def _load_ctm_model(checkpoint_path, device):
|
| 95 |
+
"""Loads the ContinuousThoughtMachine model from a checkpoint."""
|
| 96 |
+
print(f"Loading checkpoint: {checkpoint_path}")
|
| 97 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 98 |
+
model_args = checkpoint['args']
|
| 99 |
+
|
| 100 |
+
# Handle legacy arguments for model_args
|
| 101 |
+
if not hasattr(model_args, 'backbone_type') and hasattr(model_args, 'resnet_type'):
|
| 102 |
+
model_args.backbone_type = f'{model_args.resnet_type}-{getattr(model_args, "resnet_feature_scales", [4])[-1]}'
|
| 103 |
+
|
| 104 |
+
# Ensure prediction_reshaper is derived correctly
|
| 105 |
+
# Assuming out_dims exists and is used for this
|
| 106 |
+
prediction_reshaper = [model_args.out_dims // 5, 5] if hasattr(model_args, 'out_dims') else None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
if not hasattr(model_args, 'neuron_select_type'):
|
| 110 |
+
model_args.neuron_select_type = 'first-last'
|
| 111 |
+
if not hasattr(model_args, 'n_random_pairing_self'):
|
| 112 |
+
model_args.n_random_pairing_self = 0
|
| 113 |
+
|
| 114 |
+
print("Instantiating CTM model...")
|
| 115 |
+
model = ContinuousThoughtMachine(
|
| 116 |
+
iterations=model_args.iterations,
|
| 117 |
+
d_model=model_args.d_model,
|
| 118 |
+
d_input=model_args.d_input,
|
| 119 |
+
heads=model_args.heads,
|
| 120 |
+
n_synch_out=model_args.n_synch_out,
|
| 121 |
+
n_synch_action=model_args.n_synch_action,
|
| 122 |
+
synapse_depth=model_args.synapse_depth,
|
| 123 |
+
memory_length=model_args.memory_length,
|
| 124 |
+
deep_nlms=model_args.deep_memory, # Mapping from model_args.deep_memory
|
| 125 |
+
memory_hidden_dims=model_args.memory_hidden_dims,
|
| 126 |
+
do_layernorm_nlm=model_args.do_normalisation, # Mapping from model_args.do_normalisation
|
| 127 |
+
backbone_type=model_args.backbone_type,
|
| 128 |
+
positional_embedding_type=model_args.positional_embedding_type,
|
| 129 |
+
out_dims=model_args.out_dims,
|
| 130 |
+
prediction_reshaper=prediction_reshaper,
|
| 131 |
+
dropout=0, # Explicitly setting dropout to 0 as in original
|
| 132 |
+
neuron_select_type=model_args.neuron_select_type,
|
| 133 |
+
n_random_pairing_self=model_args.n_random_pairing_self,
|
| 134 |
+
).to(device)
|
| 135 |
+
|
| 136 |
+
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
|
| 137 |
+
print(f"Loaded state_dict. Missing keys: {load_result.missing_keys}, Unexpected keys: {load_result.unexpected_keys}")
|
| 138 |
+
model.eval()
|
| 139 |
+
return model
|
| 140 |
+
|
| 141 |
+
# --- Main Execution Block ---
|
| 142 |
+
if __name__=='__main__':
|
| 143 |
+
args = parse_args()
|
| 144 |
+
|
| 145 |
+
if args.device[0] != -1 and torch.cuda.is_available():
|
| 146 |
+
device = f'cuda:{args.device[0]}'
|
| 147 |
+
else:
|
| 148 |
+
device = 'cpu'
|
| 149 |
+
print(f"Using device: {device}")
|
| 150 |
+
|
| 151 |
+
palette = sns.color_palette("husl", 8)
|
| 152 |
+
cmap = plt.get_cmap('gist_rainbow')
|
| 153 |
+
|
| 154 |
+
# --- Generalisation Action ('gen') ---
|
| 155 |
+
if 'gen' in args.actions:
|
| 156 |
+
model = _load_ctm_model(args.checkpoint, device)
|
| 157 |
+
|
| 158 |
+
print(f"\n--- Running Generalisation Analysis ('gen'): {args.dataset_for_gen} ---")
|
| 159 |
+
target_dataset_name = f'{args.dataset_for_gen}'
|
| 160 |
+
data_root = f'data/mazes/{target_dataset_name}/test'
|
| 161 |
+
max_target_route_len = 50 # Specific to 'gen' action
|
| 162 |
+
|
| 163 |
+
test_data = MazeImageFolder(
|
| 164 |
+
root=data_root, which_set='test',
|
| 165 |
+
maze_route_length=max_target_route_len,
|
| 166 |
+
expand_range=not args.legacy_scaling, # Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
|
| 167 |
+
trunc=True
|
| 168 |
+
)
|
| 169 |
+
# Load a single large batch for 'gen'
|
| 170 |
+
testloader = torch.utils.data.DataLoader(
|
| 171 |
+
test_data, batch_size=min(len(test_data), 2000),
|
| 172 |
+
shuffle=False, num_workers=1
|
| 173 |
+
)
|
| 174 |
+
inputs, targets = next(iter(testloader))
|
| 175 |
+
|
| 176 |
+
actual_lengths = (targets != 4).sum(dim=-1)
|
| 177 |
+
sorted_indices = torch.argsort(actual_lengths, descending=True)
|
| 178 |
+
inputs, targets, actual_lengths = inputs[sorted_indices], targets[sorted_indices], actual_lengths[sorted_indices]
|
| 179 |
+
|
| 180 |
+
test_how_many = min(1000, len(inputs))
|
| 181 |
+
print(f"Processing {test_how_many} mazes sorted by length...")
|
| 182 |
+
|
| 183 |
+
results = {}
|
| 184 |
+
fault_tolerance = 2 # Specific to 'gen' analysis
|
| 185 |
+
output_gen_dir = os.path.join(args.output_dir, 'gen', args.dataset_for_gen)
|
| 186 |
+
os.makedirs(output_gen_dir, exist_ok=True)
|
| 187 |
+
|
| 188 |
+
for n_tested in range(test_how_many):
|
| 189 |
+
maze_actual_length = actual_lengths[n_tested].item()
|
| 190 |
+
maze_idx_display = n_tested + 1
|
| 191 |
+
print(f"Testing maze {maze_idx_display}/{test_how_many} (Len: {maze_actual_length})...")
|
| 192 |
+
|
| 193 |
+
initial_input_maze = inputs[n_tested:n_tested+1].clone().to(device)
|
| 194 |
+
maze_output_dir = os.path.join(output_gen_dir, f"maze_{maze_idx_display}")
|
| 195 |
+
|
| 196 |
+
re_applications = 0
|
| 197 |
+
has_solved = False
|
| 198 |
+
current_input_maze = initial_input_maze
|
| 199 |
+
exclusions = []
|
| 200 |
+
long_frames = []
|
| 201 |
+
ongoing_solution_img = None
|
| 202 |
+
|
| 203 |
+
while not has_solved and re_applications < args.max_reapplications:
|
| 204 |
+
re_applications += 1
|
| 205 |
+
with torch.no_grad():
|
| 206 |
+
predictions, certainties, _, _, _, attention_tracking = model(current_input_maze, track=True)
|
| 207 |
+
|
| 208 |
+
h_feat, w_feat = model.kv_features.shape[-2:]
|
| 209 |
+
attention_tracking = attention_tracking.reshape(attention_tracking.shape[0], -1, h_feat, w_feat)
|
| 210 |
+
|
| 211 |
+
n_steps_viz = predictions.shape[-1] # Use a different name to avoid conflict if n_steps is used elsewhere
|
| 212 |
+
step_linspace = np.linspace(0, 1, n_steps_viz)
|
| 213 |
+
current_maze_np = current_input_maze[0].permute(1,2,0).detach().cpu().numpy()
|
| 214 |
+
|
| 215 |
+
for stepi in range(n_steps_viz):
|
| 216 |
+
pred_route = predictions[0, :, stepi].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
|
| 217 |
+
frame = draw_path(current_maze_np, pred_route)
|
| 218 |
+
if attention_tracking is not None and stepi < attention_tracking.shape[0]:
|
| 219 |
+
try:
|
| 220 |
+
attn = attention_tracking[stepi].mean(0)
|
| 221 |
+
attn_resized = cv2.resize(attn, (current_maze_np.shape[1], current_maze_np.shape[0]), interpolation=cv2.INTER_LINEAR)
|
| 222 |
+
if attn_resized.max() > attn_resized.min():
|
| 223 |
+
attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min())
|
| 224 |
+
attn_norm[attn_norm < np.percentile(attn_norm, 80)] = 0.0
|
| 225 |
+
frame = np.clip((np.copy(frame)*(1-attn_norm[:,:,np.newaxis])*1 + (attn_norm[:,:,np.newaxis]*0.8 * np.reshape(np.array(cmap(step_linspace[stepi]))[:3], (1, 1, 3)))), 0, 1)
|
| 226 |
+
except Exception: # Keep broad except for visualization robustness
|
| 227 |
+
pass
|
| 228 |
+
frame_resized = cv2.resize(frame, (int(current_maze_np.shape[1]*4), int(current_maze_np.shape[0]*4)), interpolation=cv2.INTER_NEAREST) # Corrected shape[1]*4 for height
|
| 229 |
+
long_frames.append((np.clip(frame_resized, 0, 1) * 255).astype(np.uint8))
|
| 230 |
+
|
| 231 |
+
where_most_certain = certainties[0, 1].argmax().item()
|
| 232 |
+
chosen_pred_route = predictions[0, :, where_most_certain].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
|
| 233 |
+
current_start_loc_list = np.argwhere((current_maze_np == [1, 0, 0]).all(axis=2)).tolist()
|
| 234 |
+
|
| 235 |
+
# Ensure current_start_loc_list is not empty before trying to access its elements
|
| 236 |
+
if not current_start_loc_list:
|
| 237 |
+
print(f"Warning: Could not find start location in maze {maze_idx_display} during reapplication {re_applications}. Stopping reapplication.")
|
| 238 |
+
break # Cannot proceed without a start location
|
| 239 |
+
|
| 240 |
+
solved_now, final_pos, _ = has_solved_checker(current_maze_np, chosen_pred_route, True, fault_tolerance, exclusions)
|
| 241 |
+
|
| 242 |
+
path_img = draw_path(current_maze_np, chosen_pred_route, cmap=cmap, valid_only=True)
|
| 243 |
+
if ongoing_solution_img is None:
|
| 244 |
+
ongoing_solution_img = path_img
|
| 245 |
+
else:
|
| 246 |
+
mask = (np.any(ongoing_solution_img!=path_img, -1))&(~np.all(path_img==[1,1,1], -1))&(~np.all(ongoing_solution_img==[1,0,0], -1))
|
| 247 |
+
ongoing_solution_img[mask] = path_img[mask]
|
| 248 |
+
|
| 249 |
+
if solved_now:
|
| 250 |
+
has_solved = True
|
| 251 |
+
break
|
| 252 |
+
|
| 253 |
+
if tuple(current_start_loc_list[0]) == final_pos:
|
| 254 |
+
exclusions.append(tuple(current_start_loc_list[0]))
|
| 255 |
+
|
| 256 |
+
next_input = current_input_maze.clone()
|
| 257 |
+
old_start_idx = tuple(current_start_loc_list[0])
|
| 258 |
+
next_input[0, :, old_start_idx[0], old_start_idx[1]] = 1.0 # Reset old start to path
|
| 259 |
+
|
| 260 |
+
if 0 <= final_pos[0] < next_input.shape[2] and 0 <= final_pos[1] < next_input.shape[3]:
|
| 261 |
+
next_input[0, :, final_pos[0], final_pos[1]] = torch.tensor([1,0,0], device=device, dtype=next_input.dtype) # New start
|
| 262 |
+
else:
|
| 263 |
+
print(f"Warning: final_pos {final_pos} out of bounds for maze {maze_idx_display}. Stopping reapplication.")
|
| 264 |
+
break
|
| 265 |
+
current_input_maze = next_input
|
| 266 |
+
|
| 267 |
+
if has_solved:
|
| 268 |
+
print(f'Solved maze of length {maze_actual_length}! Saving...')
|
| 269 |
+
os.makedirs(maze_output_dir, exist_ok=True)
|
| 270 |
+
if ongoing_solution_img is not None:
|
| 271 |
+
cv2.imwrite(os.path.join(maze_output_dir, 'ongoing_solution.png'), (ongoing_solution_img * 255).astype(np.uint8)[:,:,::-1])
|
| 272 |
+
if long_frames:
|
| 273 |
+
save_frames_to_mp4([fm[:,:,::-1] for fm in long_frames], os.path.join(maze_output_dir, f'combined_process.mp4'), fps=45, gop_size=10, preset='veryslow', crf=20)
|
| 274 |
+
else:
|
| 275 |
+
print(f'Failed maze of length {maze_actual_length} after {re_applications} reapplications. Not saving visuals for this maze.')
|
| 276 |
+
|
| 277 |
+
if maze_actual_length not in results: results[maze_actual_length] = []
|
| 278 |
+
results[maze_actual_length].append((has_solved, re_applications))
|
| 279 |
+
|
| 280 |
+
fig_success, ax_success = plt.subplots()
|
| 281 |
+
fig_reapp, ax_reapp = plt.subplots()
|
| 282 |
+
sorted_lengths = sorted(results.keys())
|
| 283 |
+
if sorted_lengths:
|
| 284 |
+
success_rates = [np.mean([r[0] for r in results[l]]) * 100 for l in sorted_lengths]
|
| 285 |
+
reapps_mean = [np.mean([r[1] for r in results[l] if r[0]]) if any(r[0] for r in results[l]) else np.nan for l in sorted_lengths]
|
| 286 |
+
ax_success.plot(sorted_lengths, success_rates, linestyle='-', color=palette[0])
|
| 287 |
+
ax_reapp.plot(sorted_lengths, reapps_mean, linestyle='-', color=palette[5])
|
| 288 |
+
ax_success.set_xlabel('Route Length'); ax_success.set_ylabel('Success (%)')
|
| 289 |
+
ax_reapp.set_xlabel('Route Length'); ax_reapp.set_ylabel('Re-applications (Avg on Success)')
|
| 290 |
+
fig_success.tight_layout(pad=0.1); fig_reapp.tight_layout(pad=0.1)
|
| 291 |
+
fig_success.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-success_rate.png'), dpi=200)
|
| 292 |
+
fig_success.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-success_rate.pdf'), dpi=200)
|
| 293 |
+
fig_reapp.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-re-applications.png'), dpi=200)
|
| 294 |
+
fig_reapp.savefig(os.path.join(output_gen_dir, f'{args.dataset_for_gen}-re-applications.pdf'), dpi=200)
|
| 295 |
+
plt.close(fig_success); plt.close(fig_reapp)
|
| 296 |
+
np.savez(os.path.join(output_gen_dir, f'{args.dataset_for_gen}_results.npz'), results=results)
|
| 297 |
+
|
| 298 |
+
print("\n--- Generalisation Analysis ('gen') Complete ---")
|
| 299 |
+
|
| 300 |
+
# --- Visualization Action ('viz') ---
|
| 301 |
+
if 'viz' in args.actions:
|
| 302 |
+
model = _load_ctm_model(args.checkpoint, device)
|
| 303 |
+
|
| 304 |
+
print(f"\n--- Running Visualization ('viz'): {args.dataset_for_viz} ---")
|
| 305 |
+
output_viz_dir = os.path.join(args.output_dir, 'viz')
|
| 306 |
+
os.makedirs(output_viz_dir, exist_ok=True)
|
| 307 |
+
|
| 308 |
+
target_dataset_name = f'{args.dataset_for_viz}'
|
| 309 |
+
data_root = f'data/mazes/{target_dataset_name}/test'
|
| 310 |
+
test_data = MazeImageFolder(
|
| 311 |
+
root=data_root, which_set='test',
|
| 312 |
+
maze_route_length=100, # Max route length for viz data
|
| 313 |
+
expand_range=not args.legacy_scaling, # # Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
|
| 314 |
+
trunc=True
|
| 315 |
+
)
|
| 316 |
+
testloader = torch.utils.data.DataLoader(
|
| 317 |
+
test_data, batch_size=args.batch_size_test,
|
| 318 |
+
shuffle=False, num_workers=1
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
all_inputs, all_targets, all_lengths = [], [], []
|
| 322 |
+
for b_in, b_tgt in testloader:
|
| 323 |
+
all_inputs.append(b_in)
|
| 324 |
+
all_targets.append(b_tgt)
|
| 325 |
+
all_lengths.append((b_tgt != 4).sum(dim=-1))
|
| 326 |
+
|
| 327 |
+
if not all_inputs:
|
| 328 |
+
print("Error: No data in visualization loader. Exiting 'viz' action.")
|
| 329 |
+
exit()
|
| 330 |
+
|
| 331 |
+
all_inputs, all_targets, all_lengths = torch.cat(all_inputs), torch.cat(all_targets), torch.cat(all_lengths)
|
| 332 |
+
|
| 333 |
+
num_viz_mazes = 10
|
| 334 |
+
num_viz_mazes = min(num_viz_mazes, len(all_lengths))
|
| 335 |
+
|
| 336 |
+
if num_viz_mazes == 0:
|
| 337 |
+
print("Error: No mazes found to visualize. Exiting 'viz' action.")
|
| 338 |
+
exit()
|
| 339 |
+
|
| 340 |
+
top_indices = torch.argsort(all_lengths, descending=True)[:num_viz_mazes]
|
| 341 |
+
inputs_viz, targets_viz = all_inputs[top_indices].to(device), all_targets[top_indices]
|
| 342 |
+
|
| 343 |
+
print(f"Visualizing {len(inputs_viz)} longest mazes...")
|
| 344 |
+
|
| 345 |
+
with torch.no_grad():
|
| 346 |
+
predictions, _, _, _, _, attention_tracking = model(inputs_viz, track=True)
|
| 347 |
+
|
| 348 |
+
# Reshape attention: (Steps, Batch, Heads, H_feat, W_feat) assuming model.kv_features has H_feat, W_feat
|
| 349 |
+
# The original reshape was slightly different, this tries to match the likely intended dimensions for per-step, per-batch item attention
|
| 350 |
+
if attention_tracking is not None and hasattr(model, 'kv_features') and model.kv_features is not None:
|
| 351 |
+
attention_tracking = attention_tracking.reshape(
|
| 352 |
+
attention_tracking.shape[0], # Iterations/Steps
|
| 353 |
+
inputs_viz.size(0), # Batch size (num_viz_mazes)
|
| 354 |
+
-1, # Heads (inferred)
|
| 355 |
+
model.kv_features.shape[-2], # H_feat
|
| 356 |
+
model.kv_features.shape[-1] # W_feat
|
| 357 |
+
)
|
| 358 |
+
else:
|
| 359 |
+
attention_tracking = None # Ensure it's None if it can't be reshaped
|
| 360 |
+
print("Warning: Could not reshape attention_tracking. Visualizations may not include attention overlays.")
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
for maze_i in range(inputs_viz.size(0)):
|
| 364 |
+
maze_idx_display = maze_i + 1
|
| 365 |
+
maze_output_dir = os.path.join(output_viz_dir, f"maze_{maze_idx_display}")
|
| 366 |
+
os.makedirs(maze_output_dir, exist_ok=True)
|
| 367 |
+
|
| 368 |
+
current_input_np_original = inputs_viz[maze_i].permute(1,2,0).detach().cpu().numpy()
|
| 369 |
+
# Apply scaling for visualization based on legacy_scaling: Legacy checkpoints need a [0, 1] range, but it might be better to default to [-1, 1] in the future
|
| 370 |
+
current_input_np_display = (current_input_np_original + 1) / 2 if not args.legacy_scaling else current_input_np_original
|
| 371 |
+
|
| 372 |
+
current_target_route = targets_viz[maze_i].detach().cpu().numpy()
|
| 373 |
+
print(f"Generating viz for maze {maze_idx_display}...")
|
| 374 |
+
|
| 375 |
+
try:
|
| 376 |
+
solution_maze_img = draw_path(current_input_np_display, current_target_route, gt=True)
|
| 377 |
+
cv2.imwrite(os.path.join(maze_output_dir, 'solution_ground_truth.png'), (solution_maze_img * 255).astype(np.uint8)[:,:,::-1])
|
| 378 |
+
except Exception: # Keep broad except for visualization robustness
|
| 379 |
+
print(f"Could not save ground truth solution for maze {maze_idx_display}")
|
| 380 |
+
pass
|
| 381 |
+
|
| 382 |
+
frames = []
|
| 383 |
+
n_steps_viz = predictions.shape[-1] # Use a different name
|
| 384 |
+
step_linspace = np.linspace(0, 1, n_steps_viz)
|
| 385 |
+
|
| 386 |
+
for stepi in range(n_steps_viz):
|
| 387 |
+
pred_route = predictions[maze_i, :, stepi].reshape(-1, 5).argmax(-1).detach().cpu().numpy()
|
| 388 |
+
frame = draw_path(current_input_np_display, pred_route)
|
| 389 |
+
|
| 390 |
+
if attention_tracking is not None and stepi < attention_tracking.shape[0] and maze_i < attention_tracking.shape[1]:
|
| 391 |
+
|
| 392 |
+
# Attention for current step (stepi) and current maze in batch (maze_i), average over heads
|
| 393 |
+
attn = attention_tracking[stepi, maze_i].mean(0)
|
| 394 |
+
attn_resized = cv2.resize(attn, (current_input_np_display.shape[1], current_input_np_display.shape[0]), interpolation=cv2.INTER_LINEAR)
|
| 395 |
+
if attn_resized.max() > attn_resized.min():
|
| 396 |
+
attn_norm = (attn_resized - attn_resized.min()) / (attn_resized.max() - attn_resized.min())
|
| 397 |
+
attn_norm[attn_norm < np.percentile(attn_norm, 80)] = 0.0
|
| 398 |
+
frame = np.clip((np.copy(frame)*(1-attn_norm[:,:,np.newaxis])*0.9 + (attn_norm[:,:,np.newaxis]*1.2 * np.reshape(np.array(cmap(step_linspace[stepi]))[:3], (1, 1, 3)))), 0, 1)
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
frame_resized = cv2.resize(frame, (256, 256), interpolation=cv2.INTER_NEAREST)
|
| 402 |
+
frames.append((np.clip(frame_resized, 0, 1) * 255).astype(np.uint8))
|
| 403 |
+
|
| 404 |
+
if frames:
|
| 405 |
+
imageio.mimsave(os.path.join(maze_output_dir, 'attention_overlay.gif'), frames, fps=15, loop=0)
|
| 406 |
+
|
| 407 |
+
print("\n--- Visualization Action ('viz') Complete ---")
|
tasks/mazes/plotting.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
import torch
|
| 5 |
+
import os
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import imageio
|
| 8 |
+
|
| 9 |
+
from tqdm.auto import tqdm
|
| 10 |
+
|
| 11 |
+
def find_center_of_mass(array_2d):
|
| 12 |
+
"""
|
| 13 |
+
Alternative implementation using np.average and meshgrid.
|
| 14 |
+
This version is generally faster and more concise.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
array_2d: A 2D numpy array of values between 0 and 1.
|
| 18 |
+
|
| 19 |
+
Returns:
|
| 20 |
+
A tuple (x, y) representing the coordinates of the center of mass.
|
| 21 |
+
"""
|
| 22 |
+
total_mass = np.sum(array_2d)
|
| 23 |
+
if total_mass == 0:
|
| 24 |
+
return (np.nan, np.nan)
|
| 25 |
+
|
| 26 |
+
y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]]
|
| 27 |
+
x_center = np.average(x_coords, weights=array_2d)
|
| 28 |
+
y_center = np.average(y_coords, weights=array_2d)
|
| 29 |
+
return (round(y_center, 4), round(x_center, 4))
|
| 30 |
+
|
| 31 |
+
def draw_path(x, route, valid_only=False, gt=False, cmap=None):
|
| 32 |
+
"""
|
| 33 |
+
Draws a path on a maze image based on a given route.
|
| 34 |
+
|
| 35 |
+
Args:
|
| 36 |
+
maze: A numpy array representing the maze image.
|
| 37 |
+
route: A list of integers representing the route, where 0 is up, 1 is down, 2 is left, and 3 is right.
|
| 38 |
+
valid_only: A boolean indicating whether to only draw valid steps (i.e., steps that don't go into walls).
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
A numpy array representing the maze image with the path drawn in blue.
|
| 42 |
+
"""
|
| 43 |
+
x = np.copy(x)
|
| 44 |
+
start = np.argwhere((x == [1, 0, 0]).all(axis=2))
|
| 45 |
+
end = np.argwhere((x == [0, 1, 0]).all(axis=2))
|
| 46 |
+
if cmap is None:
|
| 47 |
+
cmap = plt.get_cmap('winter') if not valid_only else plt.get_cmap('summer')
|
| 48 |
+
|
| 49 |
+
# Initialize the current position
|
| 50 |
+
current_pos = start[0]
|
| 51 |
+
|
| 52 |
+
# Draw the path
|
| 53 |
+
colors = cmap(np.linspace(0, 1, len(route)))
|
| 54 |
+
si = 0
|
| 55 |
+
for step in route:
|
| 56 |
+
new_pos = current_pos
|
| 57 |
+
if step == 0: # Up
|
| 58 |
+
new_pos = (current_pos[0] - 1, current_pos[1])
|
| 59 |
+
elif step == 1: # Down
|
| 60 |
+
new_pos = (current_pos[0] + 1, current_pos[1])
|
| 61 |
+
elif step == 2: # Left
|
| 62 |
+
new_pos = (current_pos[0], current_pos[1] - 1)
|
| 63 |
+
elif step == 3: # Right
|
| 64 |
+
new_pos = (current_pos[0], current_pos[1] + 1)
|
| 65 |
+
elif step == 4: # Do nothing
|
| 66 |
+
pass
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError("Invalid step: {}".format(step))
|
| 69 |
+
|
| 70 |
+
# Check if the new position is valid
|
| 71 |
+
if valid_only:
|
| 72 |
+
try:
|
| 73 |
+
if np.all(x[new_pos] == [0,0,0]): # Check if it's a wall
|
| 74 |
+
continue # Skip this step if it's invalid
|
| 75 |
+
except IndexError:
|
| 76 |
+
continue # Skip this step if it's out of bounds
|
| 77 |
+
|
| 78 |
+
# Draw the step
|
| 79 |
+
if new_pos[0] >= 0 and new_pos[0] < x.shape[0] and new_pos[1] >= 0 and new_pos[1] < x.shape[1]:
|
| 80 |
+
if not ((x[new_pos] == [1,0,0]).all() or (x[new_pos] == [0,1,0]).all()):
|
| 81 |
+
colour = colors[si][:3]
|
| 82 |
+
si += 1
|
| 83 |
+
x[new_pos] = x[new_pos]*0.5 + colour*0.5
|
| 84 |
+
|
| 85 |
+
# Update the current position
|
| 86 |
+
current_pos = new_pos
|
| 87 |
+
# cv2.imwrite('maze2.png', x[:,:,::-1]*255)
|
| 88 |
+
|
| 89 |
+
return x
|
| 90 |
+
|
| 91 |
+
def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location):
|
| 92 |
+
"""
|
| 93 |
+
Expect inputs, predictions, targets as numpy arrays
|
| 94 |
+
"""
|
| 95 |
+
route_steps = []
|
| 96 |
+
route_colours = []
|
| 97 |
+
solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets)
|
| 98 |
+
|
| 99 |
+
# cv2.imwrite(f'{save_location}/ground_truth.png', solution_maze[:,:,::-1]*255)
|
| 100 |
+
mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 101 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 102 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 103 |
+
['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'],
|
| 104 |
+
['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'],
|
| 105 |
+
['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'],
|
| 106 |
+
]
|
| 107 |
+
img_aspect = 1
|
| 108 |
+
figscale = 1
|
| 109 |
+
aspect_ratio = (8 * figscale, 6 * figscale * img_aspect) # W, H
|
| 110 |
+
|
| 111 |
+
route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point
|
| 112 |
+
frames = []
|
| 113 |
+
cmap = plt.get_cmap('gist_rainbow')
|
| 114 |
+
cmap_viridis = plt.get_cmap('viridis')
|
| 115 |
+
step_linspace = np.linspace(0, 1, predictions.shape[-1]) # For sampling colours
|
| 116 |
+
with tqdm(total=predictions.shape[-1], initial=0, leave=True, position=1, dynamic_ncols=True) as pbar:
|
| 117 |
+
pbar.set_description('Processing frames for maze plotting')
|
| 118 |
+
for stepi in np.arange(0, predictions.shape[-1], 1):
|
| 119 |
+
fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio)
|
| 120 |
+
for ax in axes.values():
|
| 121 |
+
ax.axis('off')
|
| 122 |
+
guess_maze = draw_path(np.moveaxis(inputs, 0, -1), predictions.argmax(1)[:,stepi], cmap=cmap)
|
| 123 |
+
attention_now = attention_tracking[stepi]
|
| 124 |
+
for hi in range(min((attention_tracking.shape[1], 16))):
|
| 125 |
+
ax = axes[f'head_{hi}']
|
| 126 |
+
attn = attention_tracking[stepi, hi]
|
| 127 |
+
attn = (attn - attn.min())/(np.ptp(attn))
|
| 128 |
+
ax.imshow(attn, cmap=cmap_viridis)
|
| 129 |
+
# Upsample attention just for visualisation
|
| 130 |
+
aggregated_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy()
|
| 131 |
+
|
| 132 |
+
# Get approximate center of mass
|
| 133 |
+
com_attn = np.copy(aggregated_attention)
|
| 134 |
+
com_attn[com_attn < np.percentile(com_attn, 96)] = 0.0
|
| 135 |
+
aggregated_attention[aggregated_attention < np.percentile(aggregated_attention, 80)] = 0.0
|
| 136 |
+
route_steps.append(find_center_of_mass(com_attn))
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
colour = list(cmap(step_linspace[stepi]))
|
| 140 |
+
route_colours.append(colour)
|
| 141 |
+
|
| 142 |
+
mapped_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy()
|
| 143 |
+
mapped_attention = (mapped_attention - mapped_attention.min())/np.ptp(mapped_attention)
|
| 144 |
+
# np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.5) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.3, 0, 1)
|
| 145 |
+
overlay_img = np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.6) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.1, 0, 1)#np.clip((np.copy(guess_maze)*(1-aggregated_attention[:,:,np.newaxis])*0.7 + (aggregated_attention[:,:,np.newaxis]*3 * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1)
|
| 146 |
+
axes['overlay'].imshow(overlay_img)
|
| 147 |
+
|
| 148 |
+
y_coords, x_coords = zip(*route_steps)
|
| 149 |
+
y_coords = inputs.shape[-1] - np.array(list(y_coords))-1
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
axes['route'].imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower')
|
| 153 |
+
# ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
|
| 154 |
+
arrow_scale = 2
|
| 155 |
+
for i in range(len(route_steps)-1):
|
| 156 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 157 |
+
dy = y_coords[i+1] - y_coords[i]
|
| 158 |
+
axes['route'].arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
|
| 159 |
+
|
| 160 |
+
fig.tight_layout(pad=0.1) # Adjust spacing
|
| 161 |
+
|
| 162 |
+
# Render the plot to a numpy array
|
| 163 |
+
canvas = fig.canvas
|
| 164 |
+
canvas.draw()
|
| 165 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 166 |
+
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB
|
| 167 |
+
|
| 168 |
+
frames.append(image_numpy) # Add to list for GIF
|
| 169 |
+
|
| 170 |
+
# fig.savefig(f'{save_location}/frame.png', dpi=200)
|
| 171 |
+
|
| 172 |
+
plt.close(fig)
|
| 173 |
+
|
| 174 |
+
# # frame = np.clip((np.copy(guess_maze)*0.5 + (aggregated_attention[:,:,np.newaxis] * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1)
|
| 175 |
+
# frame = torch.nn.functional.interpolate(torch.from_numpy(frame).permute(2,0,1).unsqueeze(0), 256)[0].permute(1,2,0).detach().cpu().numpy()
|
| 176 |
+
# frames.append((frame*255).astype(np.uint8))
|
| 177 |
+
pbar.update(1)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
y_coords, x_coords = zip(*route_steps)
|
| 181 |
+
y_coords = inputs.shape[-1] - np.array(list(y_coords))-1
|
| 182 |
+
|
| 183 |
+
fig = plt.figure(figsize=(5,5))
|
| 184 |
+
ax = fig.add_subplot(111)
|
| 185 |
+
|
| 186 |
+
ax.imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower')
|
| 187 |
+
# ax.imshow(np.flip(solution_maze, axis=0), origin='lower')
|
| 188 |
+
arrow_scale = 2
|
| 189 |
+
for i in range(len(route_steps)-1):
|
| 190 |
+
dx = x_coords[i+1] - x_coords[i]
|
| 191 |
+
dy = y_coords[i+1] - y_coords[i]
|
| 192 |
+
plt.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True)
|
| 193 |
+
|
| 194 |
+
ax.axis('off')
|
| 195 |
+
fig.tight_layout(pad=0)
|
| 196 |
+
fig.savefig(f'{save_location}/route_approximation.png', dpi=200)
|
| 197 |
+
imageio.mimsave(f'{save_location}/prediction.gif', frames, fps=15, loop=100)
|
| 198 |
+
plt.close(fig)
|
tasks/mazes/scripts/train_ctm.sh
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python -m tasks.mazes.train \
|
| 2 |
+
--model ctm \
|
| 3 |
+
--log_dir logs/mazes/ctm/d=2048--i=512--heads=16--sd=8--nlm=32--synch=64-32-h=32-first-last--iters=75x25--backbone=34-2 \
|
| 4 |
+
--neuron_select_type first-last \
|
| 5 |
+
--dataset mazes-large \
|
| 6 |
+
--synapse_depth 8 \
|
| 7 |
+
--heads 16 \
|
| 8 |
+
--iterations 75 \
|
| 9 |
+
--memory_length 25 \
|
| 10 |
+
--d_model 2048 \
|
| 11 |
+
--d_input 512 \
|
| 12 |
+
--backbone_type resnet34-2 \
|
| 13 |
+
--n_synch_out 64 \
|
| 14 |
+
--n_synch_action 32 \
|
| 15 |
+
--memory_hidden_dims 32 \
|
| 16 |
+
--deep_memory \
|
| 17 |
+
--weight_decay 0.000 \
|
| 18 |
+
--batch_size 64 \
|
| 19 |
+
--batch_size_test 128 \
|
| 20 |
+
--n_test_batches 20 \
|
| 21 |
+
--gradient_clipping -1 \
|
| 22 |
+
--use_scheduler \
|
| 23 |
+
--scheduler_type cosine \
|
| 24 |
+
--warmup_steps 10000 \
|
| 25 |
+
--training_iterations 1000001 \
|
| 26 |
+
--no-do_normalisation \
|
| 27 |
+
--track_every 1000 \
|
| 28 |
+
--lr 1e-4 \
|
| 29 |
+
--no-reload \
|
| 30 |
+
--dropout 0.1 \
|
| 31 |
+
--positional_embedding_type none \
|
| 32 |
+
--maze_route_length 100 \
|
| 33 |
+
--cirriculum_lookahead 5 \
|
| 34 |
+
--device 0 \
|
| 35 |
+
--no-expand_range
|
tasks/mazes/train.py
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
import seaborn as sns
|
| 8 |
+
sns.set_style('darkgrid')
|
| 9 |
+
import torch
|
| 10 |
+
if torch.cuda.is_available():
|
| 11 |
+
# For faster
|
| 12 |
+
torch.set_float32_matmul_precision('high')
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
|
| 15 |
+
from data.custom_datasets import MazeImageFolder
|
| 16 |
+
from models.ctm import ContinuousThoughtMachine
|
| 17 |
+
from models.lstm import LSTMBaseline
|
| 18 |
+
from models.ff import FFBaseline
|
| 19 |
+
from tasks.mazes.plotting import make_maze_gif
|
| 20 |
+
from tasks.image_classification.plotting import plot_neural_dynamics
|
| 21 |
+
from utils.housekeeping import set_seed, zip_python_code
|
| 22 |
+
from utils.losses import maze_loss
|
| 23 |
+
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
|
| 24 |
+
|
| 25 |
+
import torchvision
|
| 26 |
+
torchvision.disable_beta_transforms_warning()
|
| 27 |
+
|
| 28 |
+
import warnings
|
| 29 |
+
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
|
| 30 |
+
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
|
| 31 |
+
warnings.filterwarnings(
|
| 32 |
+
"ignore",
|
| 33 |
+
"Corrupt EXIF data",
|
| 34 |
+
UserWarning,
|
| 35 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 36 |
+
)
|
| 37 |
+
warnings.filterwarnings(
|
| 38 |
+
"ignore",
|
| 39 |
+
"UserWarning: Metadata Warning",
|
| 40 |
+
UserWarning,
|
| 41 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 42 |
+
)
|
| 43 |
+
warnings.filterwarnings(
|
| 44 |
+
"ignore",
|
| 45 |
+
"UserWarning: Truncated File Read",
|
| 46 |
+
UserWarning,
|
| 47 |
+
r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def parse_args():
|
| 52 |
+
parser = argparse.ArgumentParser()
|
| 53 |
+
|
| 54 |
+
# Model Selection
|
| 55 |
+
parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
|
| 56 |
+
|
| 57 |
+
# Model Architecture
|
| 58 |
+
# Common across all or most
|
| 59 |
+
parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
|
| 60 |
+
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
|
| 61 |
+
parser.add_argument('--backbone_type', type=str, default='resnet34-2', help='Type of backbone featureiser.') # Default changed from original script
|
| 62 |
+
# CTM / LSTM specific
|
| 63 |
+
parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
|
| 64 |
+
parser.add_argument('--heads', type=int, default=8, help='Number of attention heads (CTM, LSTM).') # Default changed
|
| 65 |
+
parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
|
| 66 |
+
parser.add_argument('--positional_embedding_type', type=str, default='none',
|
| 67 |
+
help='Type of positional embedding (CTM, LSTM).', choices=['none',
|
| 68 |
+
'learnable-fourier',
|
| 69 |
+
'multi-learnable-fourier',
|
| 70 |
+
'custom-rotational'])
|
| 71 |
+
|
| 72 |
+
# CTM specific
|
| 73 |
+
parser.add_argument('--synapse_depth', type=int, default=8, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).') # Default changed
|
| 74 |
+
parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).') # Default changed
|
| 75 |
+
parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).') # Default changed
|
| 76 |
+
parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
|
| 77 |
+
parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
|
| 78 |
+
parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
|
| 79 |
+
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True,
|
| 80 |
+
help='Use deep memory (CTM only).')
|
| 81 |
+
parser.add_argument('--memory_hidden_dims', type=int, default=32, help='Hidden dimensions of the memory if using deep memory (CTM only).') # Default changed
|
| 82 |
+
parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
|
| 83 |
+
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
|
| 84 |
+
# LSTM specific
|
| 85 |
+
parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).') # Added LSTM arg
|
| 86 |
+
|
| 87 |
+
# Task Specific Args (Common to all models for this task)
|
| 88 |
+
parser.add_argument('--maze_route_length', type=int, default=100, help='Length to truncate targets.')
|
| 89 |
+
parser.add_argument('--cirriculum_lookahead', type=int, default=5, help='How far to look ahead for cirriculum.')
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Training
|
| 93 |
+
parser.add_argument('--expand_range', action=argparse.BooleanOptionalAction, default=True, help='Mazes between 0 and 1 = False. Between -1 and 1 = True. Legacy checkpoints use 0 and 1.')
|
| 94 |
+
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training.') # Default changed
|
| 95 |
+
parser.add_argument('--batch_size_test', type=int, default=64, help='Batch size for testing.') # Default changed
|
| 96 |
+
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate for the model.') # Default changed
|
| 97 |
+
parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
|
| 98 |
+
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
|
| 99 |
+
parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
|
| 100 |
+
parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
|
| 101 |
+
parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
|
| 102 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
|
| 103 |
+
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
|
| 104 |
+
parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
|
| 105 |
+
parser.add_argument('--num_workers_train', type=int, default=0, help='Num workers training.') # Renamed from num_workers, kept default
|
| 106 |
+
parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
|
| 107 |
+
parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
|
| 108 |
+
|
| 109 |
+
# Logging and Saving
|
| 110 |
+
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 111 |
+
parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large'])
|
| 112 |
+
parser.add_argument('--data_root', type=str, default='data/mazes', help='Data root.')
|
| 113 |
+
|
| 114 |
+
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
| 115 |
+
parser.add_argument('--seed', type=int, default=412, help='Random seed.')
|
| 116 |
+
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
|
| 117 |
+
parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?')
|
| 118 |
+
parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=True, help='Should use strict reload for model weights.') # Added back
|
| 119 |
+
parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading (for debugging)?') # Added back
|
| 120 |
+
|
| 121 |
+
# Tracking
|
| 122 |
+
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
|
| 123 |
+
parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval') # Default changed
|
| 124 |
+
|
| 125 |
+
# Device
|
| 126 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
|
| 127 |
+
parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
args = parser.parse_args()
|
| 131 |
+
return args
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
if __name__=='__main__':
|
| 135 |
+
|
| 136 |
+
# Hosuekeeping
|
| 137 |
+
args = parse_args()
|
| 138 |
+
|
| 139 |
+
set_seed(args.seed, False)
|
| 140 |
+
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 141 |
+
|
| 142 |
+
assert args.dataset in ['mazes-medium', 'mazes-large']
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
prediction_reshaper = [args.maze_route_length, 5] # Problem specific
|
| 147 |
+
args.out_dims = args.maze_route_length * 5 # Output dimension before reshaping
|
| 148 |
+
|
| 149 |
+
# For total reproducibility
|
| 150 |
+
zip_python_code(f'{args.log_dir}/repo_state.zip')
|
| 151 |
+
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 152 |
+
print(args, file=f)
|
| 153 |
+
|
| 154 |
+
# Configure device string
|
| 155 |
+
device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
|
| 156 |
+
print(f'Running model {args.model} on {device} for dataset {args.dataset}')
|
| 157 |
+
|
| 158 |
+
# Build model conditionally
|
| 159 |
+
model = None
|
| 160 |
+
if args.model == 'ctm':
|
| 161 |
+
model = ContinuousThoughtMachine(
|
| 162 |
+
iterations=args.iterations,
|
| 163 |
+
d_model=args.d_model,
|
| 164 |
+
d_input=args.d_input,
|
| 165 |
+
heads=args.heads,
|
| 166 |
+
n_synch_out=args.n_synch_out,
|
| 167 |
+
n_synch_action=args.n_synch_action,
|
| 168 |
+
synapse_depth=args.synapse_depth,
|
| 169 |
+
memory_length=args.memory_length,
|
| 170 |
+
deep_nlms=args.deep_memory,
|
| 171 |
+
memory_hidden_dims=args.memory_hidden_dims,
|
| 172 |
+
do_layernorm_nlm=args.do_normalisation,
|
| 173 |
+
backbone_type=args.backbone_type,
|
| 174 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 175 |
+
out_dims=args.out_dims,
|
| 176 |
+
prediction_reshaper=prediction_reshaper,
|
| 177 |
+
dropout=args.dropout,
|
| 178 |
+
dropout_nlm=args.dropout_nlm,
|
| 179 |
+
neuron_select_type=args.neuron_select_type,
|
| 180 |
+
n_random_pairing_self=args.n_random_pairing_self,
|
| 181 |
+
).to(device)
|
| 182 |
+
elif args.model == 'lstm':
|
| 183 |
+
model = LSTMBaseline(
|
| 184 |
+
num_layers=args.num_layers,
|
| 185 |
+
iterations=args.iterations,
|
| 186 |
+
d_model=args.d_model,
|
| 187 |
+
d_input=args.d_input,
|
| 188 |
+
heads=args.heads,
|
| 189 |
+
backbone_type=args.backbone_type,
|
| 190 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 191 |
+
out_dims=args.out_dims,
|
| 192 |
+
prediction_reshaper=prediction_reshaper,
|
| 193 |
+
dropout=args.dropout,
|
| 194 |
+
).to(device)
|
| 195 |
+
elif args.model == 'ff':
|
| 196 |
+
model = FFBaseline(
|
| 197 |
+
d_model=args.d_model,
|
| 198 |
+
backbone_type=args.backbone_type,
|
| 199 |
+
out_dims=args.out_dims,
|
| 200 |
+
dropout=args.dropout,
|
| 201 |
+
).to(device)
|
| 202 |
+
else:
|
| 203 |
+
raise ValueError(f"Unknown model type: {args.model}")
|
| 204 |
+
|
| 205 |
+
try:
|
| 206 |
+
# Determine pseudo input shape based on dataset
|
| 207 |
+
h_w = 39 if args.dataset in ['mazes-small', 'mazes-medium'] else 99 # Example dimensions
|
| 208 |
+
pseudo_inputs = torch.zeros((1, 3, h_w, h_w), device=device).float()
|
| 209 |
+
model(pseudo_inputs)
|
| 210 |
+
except Exception as e:
|
| 211 |
+
print(f"Warning: Pseudo forward pass failed: {e}")
|
| 212 |
+
|
| 213 |
+
print(f'Total params: {sum(p.numel() for p in model.parameters())}')
|
| 214 |
+
|
| 215 |
+
# Data
|
| 216 |
+
dataset_mean = [0,0,0] # For plotting later
|
| 217 |
+
dataset_std = [1,1,1]
|
| 218 |
+
|
| 219 |
+
which_maze = args.dataset.split('-')[-1]
|
| 220 |
+
data_root = f'{args.data_root}/{which_maze}'
|
| 221 |
+
|
| 222 |
+
train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=args.maze_route_length, expand_range=args.expand_range)
|
| 223 |
+
test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=args.maze_route_length, expand_range=args.expand_range)
|
| 224 |
+
|
| 225 |
+
num_workers_test = 1 # Defaulting to 1, can be changed
|
| 226 |
+
trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train, drop_last=True)
|
| 227 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
|
| 228 |
+
|
| 229 |
+
# For lazy modules so that we can get param count
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
model.train()
|
| 233 |
+
|
| 234 |
+
# Optimizer and scheduler
|
| 235 |
+
decay_params = []
|
| 236 |
+
no_decay_params = []
|
| 237 |
+
no_decay_names = []
|
| 238 |
+
for name, param in model.named_parameters():
|
| 239 |
+
if not param.requires_grad:
|
| 240 |
+
continue # Skip parameters that don't require gradients
|
| 241 |
+
if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
|
| 242 |
+
no_decay_params.append(param)
|
| 243 |
+
no_decay_names.append(name)
|
| 244 |
+
else:
|
| 245 |
+
decay_params.append(param)
|
| 246 |
+
if len(no_decay_names):
|
| 247 |
+
print(f'WARNING, excluding: {no_decay_names}')
|
| 248 |
+
|
| 249 |
+
# Optimizer and scheduler (Common setup)
|
| 250 |
+
if len(no_decay_names) and args.weight_decay!=0:
|
| 251 |
+
optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
|
| 252 |
+
{'params': no_decay_params, 'weight_decay':0}],
|
| 253 |
+
lr=args.lr,
|
| 254 |
+
eps=1e-8 if not args.use_amp else 1e-6)
|
| 255 |
+
else:
|
| 256 |
+
optimizer = torch.optim.AdamW(model.parameters(),
|
| 257 |
+
lr=args.lr,
|
| 258 |
+
eps=1e-8 if not args.use_amp else 1e-6,
|
| 259 |
+
weight_decay=args.weight_decay)
|
| 260 |
+
|
| 261 |
+
warmup_schedule = warmup(args.warmup_steps)
|
| 262 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
|
| 263 |
+
if args.use_scheduler:
|
| 264 |
+
if args.scheduler_type == 'multistep':
|
| 265 |
+
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
|
| 266 |
+
elif args.scheduler_type == 'cosine':
|
| 267 |
+
scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
|
| 268 |
+
else:
|
| 269 |
+
raise NotImplementedError
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
# Metrics tracking
|
| 273 |
+
start_iter = 0
|
| 274 |
+
train_losses = []
|
| 275 |
+
test_losses = []
|
| 276 |
+
train_accuracies = [] # Per tick/step accuracy list
|
| 277 |
+
test_accuracies = []
|
| 278 |
+
train_accuracies_most_certain = [] # Accuracy, fine-grained
|
| 279 |
+
test_accuracies_most_certain = []
|
| 280 |
+
train_accuracies_most_certain_permaze = [] # Full maze accuracy
|
| 281 |
+
test_accuracies_most_certain_permaze = []
|
| 282 |
+
iters = []
|
| 283 |
+
|
| 284 |
+
scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
|
| 285 |
+
if args.reload:
|
| 286 |
+
checkpoint_path = f'{args.log_dir}/checkpoint.pt'
|
| 287 |
+
if os.path.isfile(checkpoint_path):
|
| 288 |
+
print(f'Reloading from: {checkpoint_path}')
|
| 289 |
+
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 290 |
+
if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
|
| 291 |
+
load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
|
| 292 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 293 |
+
|
| 294 |
+
if not args.reload_model_only:
|
| 295 |
+
print('Reloading optimizer etc.')
|
| 296 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 297 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 298 |
+
scaler.load_state_dict(checkpoint['scaler_state_dict']) # Load scaler state
|
| 299 |
+
start_iter = checkpoint['iteration']
|
| 300 |
+
|
| 301 |
+
if not args.ignore_metrics_when_reloading:
|
| 302 |
+
train_losses = checkpoint['train_losses']
|
| 303 |
+
test_losses = checkpoint['test_losses']
|
| 304 |
+
train_accuracies = checkpoint['train_accuracies']
|
| 305 |
+
test_accuracies = checkpoint['test_accuracies']
|
| 306 |
+
iters = checkpoint['iters']
|
| 307 |
+
train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
|
| 308 |
+
test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
|
| 309 |
+
train_accuracies_most_certain_permaze = checkpoint['train_accuracies_most_certain_permaze']
|
| 310 |
+
test_accuracies_most_certain_permaze = checkpoint['test_accuracies_most_certain_permaze']
|
| 311 |
+
else:
|
| 312 |
+
print("Ignoring metrics history upon reload.")
|
| 313 |
+
|
| 314 |
+
else:
|
| 315 |
+
print('Only reloading model!')
|
| 316 |
+
|
| 317 |
+
if 'torch_rng_state' in checkpoint:
|
| 318 |
+
# Reset seeds
|
| 319 |
+
torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
|
| 320 |
+
np.random.set_state(checkpoint['numpy_rng_state'])
|
| 321 |
+
random.setstate(checkpoint['random_rng_state'])
|
| 322 |
+
|
| 323 |
+
del checkpoint
|
| 324 |
+
import gc
|
| 325 |
+
gc.collect()
|
| 326 |
+
if torch.cuda.is_available():
|
| 327 |
+
torch.cuda.empty_cache()
|
| 328 |
+
|
| 329 |
+
if args.do_compile:
|
| 330 |
+
print('Compiling...')
|
| 331 |
+
if hasattr(model, 'backbone'):
|
| 332 |
+
model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
|
| 333 |
+
# Compile synapses only for CTM
|
| 334 |
+
if args.model == 'ctm':
|
| 335 |
+
model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
|
| 336 |
+
|
| 337 |
+
# Training
|
| 338 |
+
iterator = iter(trainloader)
|
| 339 |
+
with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
|
| 340 |
+
for bi in range(start_iter, args.training_iterations):
|
| 341 |
+
current_lr = optimizer.param_groups[-1]['lr']
|
| 342 |
+
|
| 343 |
+
try:
|
| 344 |
+
inputs, targets = next(iterator)
|
| 345 |
+
except StopIteration:
|
| 346 |
+
iterator = iter(trainloader)
|
| 347 |
+
inputs, targets = next(iterator)
|
| 348 |
+
|
| 349 |
+
inputs = inputs.to(device)
|
| 350 |
+
targets = targets.to(device) # Shape (B, SeqLength)
|
| 351 |
+
|
| 352 |
+
# All for nice metric printing:
|
| 353 |
+
loss = None
|
| 354 |
+
accuracy_finegrained = None # Per-step accuracy at chosen tick
|
| 355 |
+
where_most_certain_val = -1.0 # Default value
|
| 356 |
+
where_most_certain_std = 0.0
|
| 357 |
+
where_most_certain_min = -1
|
| 358 |
+
where_most_certain_max = -1
|
| 359 |
+
upto_where_mean = -1.0
|
| 360 |
+
upto_where_std = 0.0
|
| 361 |
+
upto_where_min = -1
|
| 362 |
+
upto_where_max = -1
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# Model-specific forward, reshape, and loss calculation
|
| 366 |
+
with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
|
| 367 |
+
if args.do_compile: # CUDAGraph marking applied if compiling any model
|
| 368 |
+
torch.compiler.cudagraph_mark_step_begin()
|
| 369 |
+
|
| 370 |
+
if args.model == 'ctm':
|
| 371 |
+
# CTM output: (B, SeqLength*5, Ticks), Certainties: (B, Ticks)
|
| 372 |
+
predictions_raw, certainties, synchronisation = model(inputs)
|
| 373 |
+
# Reshape predictions: (B, SeqLength, 5, Ticks)
|
| 374 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))
|
| 375 |
+
loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=True)
|
| 376 |
+
# Accuracy uses predictions[B, S, C, T] indexed at where_most_certain[B] -> gives (B, S, C) -> argmax(2) -> (B,S)
|
| 377 |
+
accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()
|
| 378 |
+
|
| 379 |
+
elif args.model == 'lstm':
|
| 380 |
+
# LSTM output: (B, SeqLength*5, Ticks), Certainties: (B, Ticks)
|
| 381 |
+
predictions_raw, certainties, synchronisation = model(inputs)
|
| 382 |
+
# Reshape predictions: (B, SeqLength, 5, Ticks)
|
| 383 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1))
|
| 384 |
+
loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False)
|
| 385 |
+
# where_most_certain should be -1 (last tick) here. Accuracy calc follows same logic.
|
| 386 |
+
accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] == targets).float().mean().item()
|
| 387 |
+
|
| 388 |
+
elif args.model == 'ff':
|
| 389 |
+
# Assume FF output: (B, SeqLength*5)
|
| 390 |
+
predictions_raw = model(inputs)
|
| 391 |
+
# Reshape predictions: (B, SeqLength, 5)
|
| 392 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5)
|
| 393 |
+
# FF has no certainties, pass None. maze_loss must handle this.
|
| 394 |
+
# Unsqueeze predictions for compatibility with maze loss calcluation
|
| 395 |
+
loss, where_most_certain, upto_where = maze_loss(predictions.unsqueeze(-1), None, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False)
|
| 396 |
+
# where_most_certain should be -1 here. Accuracy uses 3D prediction tensor.
|
| 397 |
+
accuracy_finegrained = (predictions.argmax(2) == targets).float().mean().item()
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
# Extract stats from loss outputs if they are tensors
|
| 401 |
+
if torch.is_tensor(where_most_certain):
|
| 402 |
+
where_most_certain_val = where_most_certain.float().mean().item()
|
| 403 |
+
where_most_certain_std = where_most_certain.float().std().item()
|
| 404 |
+
where_most_certain_min = where_most_certain.min().item()
|
| 405 |
+
where_most_certain_max = where_most_certain.max().item()
|
| 406 |
+
elif isinstance(where_most_certain, int): # Handle case where it might return -1 directly
|
| 407 |
+
where_most_certain_val = float(where_most_certain)
|
| 408 |
+
where_most_certain_min = where_most_certain
|
| 409 |
+
where_most_certain_max = where_most_certain
|
| 410 |
+
|
| 411 |
+
if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0: # Check if it's a list/array
|
| 412 |
+
upto_where_mean = np.mean(upto_where)
|
| 413 |
+
upto_where_std = np.std(upto_where)
|
| 414 |
+
upto_where_min = np.min(upto_where)
|
| 415 |
+
upto_where_max = np.max(upto_where)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
scaler.scale(loss).backward()
|
| 419 |
+
|
| 420 |
+
if args.gradient_clipping!=-1:
|
| 421 |
+
scaler.unscale_(optimizer)
|
| 422 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
|
| 423 |
+
|
| 424 |
+
scaler.step(optimizer)
|
| 425 |
+
scaler.update()
|
| 426 |
+
optimizer.zero_grad(set_to_none=True)
|
| 427 |
+
scheduler.step()
|
| 428 |
+
|
| 429 |
+
# Conditional Tqdm Description
|
| 430 |
+
pbar_desc = f'Loss={loss.item():0.3f}. Acc(step)={accuracy_finegrained:0.3f}. LR={current_lr:0.6f}.'
|
| 431 |
+
if args.model in ['ctm', 'lstm'] or torch.is_tensor(where_most_certain): # Show stats if available
|
| 432 |
+
pbar_desc += f' Where_certain={where_most_certain_val:0.2f}+-{where_most_certain_std:0.2f} ({where_most_certain_min:d}<->{where_most_certain_max:d}).'
|
| 433 |
+
if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
|
| 434 |
+
pbar_desc += f' Path pred stats: {upto_where_mean:0.2f}+-{upto_where_std:0.2f} ({upto_where_min:d} --> {upto_where_max:d})'
|
| 435 |
+
|
| 436 |
+
pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
# Metrics tracking and plotting
|
| 440 |
+
if bi%args.track_every==0 and (bi != 0 or args.reload_model_only):
|
| 441 |
+
model.eval() # Use eval mode for consistency during tracking
|
| 442 |
+
with torch.inference_mode(): # Use inference mode for tracking
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
# --- Quantitative Metrics ---
|
| 448 |
+
iters.append(bi)
|
| 449 |
+
# Re-initialize metric lists for this evaluation step
|
| 450 |
+
current_train_losses_eval = []
|
| 451 |
+
current_test_losses_eval = []
|
| 452 |
+
current_train_accuracies_eval = []
|
| 453 |
+
current_test_accuracies_eval = []
|
| 454 |
+
current_train_accuracies_most_certain_eval = []
|
| 455 |
+
current_test_accuracies_most_certain_eval = []
|
| 456 |
+
current_train_accuracies_most_certain_permaze_eval = []
|
| 457 |
+
current_test_accuracies_most_certain_permaze_eval = []
|
| 458 |
+
|
| 459 |
+
# TRAIN METRICS
|
| 460 |
+
pbar.set_description('Tracking: Computing TRAIN metrics')
|
| 461 |
+
loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test) # Use consistent num_workers
|
| 462 |
+
all_targets_list = []
|
| 463 |
+
all_predictions_list = [] # Per step/tick predictions argmax (N, S, T) or (N, S)
|
| 464 |
+
all_predictions_most_certain_list = [] # Predictions at chosen step/tick argmax (N, S)
|
| 465 |
+
all_losses = []
|
| 466 |
+
|
| 467 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 468 |
+
for inferi, (inputs, targets) in enumerate(loader):
|
| 469 |
+
inputs = inputs.to(device)
|
| 470 |
+
targets = targets.to(device)
|
| 471 |
+
all_targets_list.append(targets.detach().cpu().numpy()) # N x S
|
| 472 |
+
|
| 473 |
+
# Model-specific forward, reshape, loss for evaluation
|
| 474 |
+
if args.model == 'ctm':
|
| 475 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 476 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 477 |
+
loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
|
| 478 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,C,T -> argmax class -> B,S,T
|
| 479 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S
|
| 480 |
+
all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
|
| 481 |
+
|
| 482 |
+
elif args.model == 'lstm':
|
| 483 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 484 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 485 |
+
loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
|
| 486 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,C,T
|
| 487 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S (at last tick)
|
| 488 |
+
all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
|
| 489 |
+
|
| 490 |
+
elif args.model == 'ff':
|
| 491 |
+
predictions_raw = model(inputs) # B, S*C
|
| 492 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
|
| 493 |
+
loss, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
|
| 494 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S
|
| 495 |
+
all_predictions_most_certain_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S (same as above for FF)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
all_losses.append(loss.item())
|
| 499 |
+
|
| 500 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break
|
| 501 |
+
pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
|
| 502 |
+
pbar_inner.update(1)
|
| 503 |
+
|
| 504 |
+
all_targets = np.concatenate(all_targets_list) # N, S
|
| 505 |
+
all_predictions = np.concatenate(all_predictions_list) # N, S, T or N, S
|
| 506 |
+
all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list) # N, S
|
| 507 |
+
|
| 508 |
+
train_losses.append(np.mean(all_losses))
|
| 509 |
+
# Calculate per step/tick accuracy averaged over batches
|
| 510 |
+
if args.model in ['ctm', 'lstm']:
|
| 511 |
+
# all_predictions shape (N, S, T), all_targets shape (N, S) -> compare targets to each tick prediction
|
| 512 |
+
train_accuracies.append(np.mean(all_predictions == all_targets[:,:,np.newaxis], axis=0)) # Mean over N -> (S, T)
|
| 513 |
+
else: # FF
|
| 514 |
+
# all_predictions shape (N, S), all_targets shape (N, S)
|
| 515 |
+
train_accuracies.append(np.mean(all_predictions == all_targets, axis=0)) # Mean over N -> (S,)
|
| 516 |
+
|
| 517 |
+
# Calculate accuracy at chosen step/tick ("most certain") averaged over all steps and batches
|
| 518 |
+
train_accuracies_most_certain.append((all_targets == all_predictions_most_certain).mean()) # Scalar
|
| 519 |
+
# Calculate full maze accuracy at chosen step/tick averaged over batches
|
| 520 |
+
train_accuracies_most_certain_permaze.append((all_targets == all_predictions_most_certain).reshape(all_targets.shape[0], -1).all(-1).mean()) # Scalar
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
# TEST METRICS
|
| 524 |
+
pbar.set_description('Tracking: Computing TEST metrics')
|
| 525 |
+
loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
|
| 526 |
+
all_targets_list = []
|
| 527 |
+
all_predictions_list = []
|
| 528 |
+
all_predictions_most_certain_list = []
|
| 529 |
+
all_losses = []
|
| 530 |
+
|
| 531 |
+
with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
|
| 532 |
+
for inferi, (inputs, targets) in enumerate(loader):
|
| 533 |
+
inputs = inputs.to(device)
|
| 534 |
+
targets = targets.to(device)
|
| 535 |
+
all_targets_list.append(targets.detach().cpu().numpy())
|
| 536 |
+
|
| 537 |
+
# Model-specific forward, reshape, loss for evaluation
|
| 538 |
+
if args.model == 'ctm':
|
| 539 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 540 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 541 |
+
loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
|
| 542 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,T
|
| 543 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S
|
| 544 |
+
all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
|
| 545 |
+
|
| 546 |
+
elif args.model == 'lstm':
|
| 547 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 548 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 549 |
+
loss, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
|
| 550 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S,T
|
| 551 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device), :, where_most_certain] # B,S (at last tick)
|
| 552 |
+
all_predictions_most_certain_list.append(pred_at_certain.detach().cpu().numpy())
|
| 553 |
+
|
| 554 |
+
elif args.model == 'ff':
|
| 555 |
+
predictions_raw = model(inputs) # B, S*C
|
| 556 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
|
| 557 |
+
loss, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
|
| 558 |
+
all_predictions_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S
|
| 559 |
+
all_predictions_most_certain_list.append(predictions.argmax(2).detach().cpu().numpy()) # B,S (same as above for FF)
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
all_losses.append(loss.item())
|
| 563 |
+
|
| 564 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 565 |
+
pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
|
| 566 |
+
pbar_inner.update(1)
|
| 567 |
+
|
| 568 |
+
all_targets = np.concatenate(all_targets_list)
|
| 569 |
+
all_predictions = np.concatenate(all_predictions_list)
|
| 570 |
+
all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
|
| 571 |
+
|
| 572 |
+
test_losses.append(np.mean(all_losses))
|
| 573 |
+
# Calculate per step/tick accuracy
|
| 574 |
+
if args.model in ['ctm', 'lstm']:
|
| 575 |
+
test_accuracies.append(np.mean(all_predictions == all_targets[:,:,np.newaxis], axis=0)) # -> (S, T)
|
| 576 |
+
else: # FF
|
| 577 |
+
test_accuracies.append(np.mean(all_predictions == all_targets, axis=0)) # -> (S,)
|
| 578 |
+
|
| 579 |
+
# Calculate "most certain" accuracy
|
| 580 |
+
test_accuracies_most_certain.append((all_targets == all_predictions_most_certain).mean()) # Scalar
|
| 581 |
+
# Calculate full maze accuracy
|
| 582 |
+
test_accuracies_most_certain_permaze.append((all_targets == all_predictions_most_certain).reshape(all_targets.shape[0], -1).all(-1).mean()) # Scalar
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
# --- Plotting ---
|
| 586 |
+
# Accuracy Plot (Handling different dimensions)
|
| 587 |
+
figacc = plt.figure(figsize=(10, 10))
|
| 588 |
+
axacc_train = figacc.add_subplot(211)
|
| 589 |
+
axacc_test = figacc.add_subplot(212)
|
| 590 |
+
cm = sns.color_palette("viridis", as_cmap=True)
|
| 591 |
+
|
| 592 |
+
# Plot per step/tick accuracy
|
| 593 |
+
# train_accuracies is List[(S, T)] or List[(S,)]
|
| 594 |
+
# We need to average over S dimension for plotting
|
| 595 |
+
train_acc_plot = [np.mean(acc_s) for acc_s in train_accuracies] # List[Scalar] or List[Scalar] after mean
|
| 596 |
+
test_acc_plot = [np.mean(acc_s) for acc_s in test_accuracies] # List[Scalar] or List[Scalar] after mean
|
| 597 |
+
|
| 598 |
+
axacc_train.plot(iters, train_acc_plot, 'g-', alpha=0.5, label='Avg Step Acc')
|
| 599 |
+
axacc_test.plot(iters, test_acc_plot, 'g-', alpha=0.5, label='Avg Step Acc')
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
# Plot most certain accuracy
|
| 603 |
+
axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most Certain (Avg Step)')
|
| 604 |
+
axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most Certain (Avg Step)')
|
| 605 |
+
# Plot full maze accuracy
|
| 606 |
+
axacc_train.plot(iters, train_accuracies_most_certain_permaze, 'r-', alpha=0.6, label='Full Maze')
|
| 607 |
+
axacc_test.plot(iters, test_accuracies_most_certain_permaze, 'r-', alpha=0.6, label='Full Maze')
|
| 608 |
+
|
| 609 |
+
axacc_train.set_title('Train Accuracy')
|
| 610 |
+
axacc_test.set_title('Test Accuracy')
|
| 611 |
+
axacc_train.legend(loc='lower right')
|
| 612 |
+
axacc_test.legend(loc='lower right')
|
| 613 |
+
axacc_train.set_xlim([0, args.training_iterations])
|
| 614 |
+
axacc_test.set_xlim([0, args.training_iterations])
|
| 615 |
+
axacc_train.set_ylim([0, 1]) # Set Ylim for accuracy
|
| 616 |
+
axacc_test.set_ylim([0, 1])
|
| 617 |
+
|
| 618 |
+
figacc.tight_layout()
|
| 619 |
+
figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
|
| 620 |
+
plt.close(figacc)
|
| 621 |
+
|
| 622 |
+
# Loss Plot
|
| 623 |
+
figloss = plt.figure(figsize=(10, 5))
|
| 624 |
+
axloss = figloss.add_subplot(111)
|
| 625 |
+
axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
|
| 626 |
+
axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
|
| 627 |
+
axloss.legend(loc='upper right')
|
| 628 |
+
axloss.set_xlim([0, args.training_iterations])
|
| 629 |
+
axloss.set_ylim(bottom=0)
|
| 630 |
+
|
| 631 |
+
figloss.tight_layout()
|
| 632 |
+
figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
|
| 633 |
+
plt.close(figloss)
|
| 634 |
+
|
| 635 |
+
# --- Visualization Section (Conditional) ---
|
| 636 |
+
if args.model in ['ctm', 'lstm']:
|
| 637 |
+
# try:
|
| 638 |
+
inputs_viz, targets_viz = next(iter(testloader))
|
| 639 |
+
inputs_viz = inputs_viz.to(device)
|
| 640 |
+
targets_viz = targets_viz.to(device)
|
| 641 |
+
# Find longest path in batch for potentially better visualization
|
| 642 |
+
longest_index = (targets_viz!=4).sum(-1).argmax() # Action 4 assumed padding/end
|
| 643 |
+
|
| 644 |
+
# Track internal states
|
| 645 |
+
predictions_viz_raw, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
|
| 646 |
+
|
| 647 |
+
# Reshape predictions (assuming raw is B, D, T)
|
| 648 |
+
predictions_viz = predictions_viz_raw.reshape(predictions_viz_raw.size(0), -1, 5, predictions_viz_raw.size(-1)) # B, S, C, T
|
| 649 |
+
|
| 650 |
+
att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
|
| 651 |
+
attention_tracking_viz = attention_tracking_viz.reshape(
|
| 652 |
+
attention_tracking_viz.shape[0],
|
| 653 |
+
attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
|
| 654 |
+
|
| 655 |
+
# Plot dynamics (common plotting function)
|
| 656 |
+
plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
|
| 657 |
+
|
| 658 |
+
# Create maze GIF (task-specific plotting)
|
| 659 |
+
make_maze_gif((inputs_viz[longest_index].detach().cpu().numpy()+1)/2,
|
| 660 |
+
predictions_viz[longest_index].detach().cpu().numpy(), # Pass reshaped B,S,C,T -> S,C,T
|
| 661 |
+
targets_viz[longest_index].detach().cpu().numpy(), # S
|
| 662 |
+
attention_tracking_viz[:, longest_index], # Pass T, (H), H, W
|
| 663 |
+
args.log_dir)
|
| 664 |
+
# except Exception as e:
|
| 665 |
+
# print(f"Visualization failed for model {args.model}: {e}")
|
| 666 |
+
# --- End Visualization ---
|
| 667 |
+
|
| 668 |
+
model.train() # Switch back to train mode
|
| 669 |
+
|
| 670 |
+
|
| 671 |
+
# Save model checkpoint
|
| 672 |
+
if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
|
| 673 |
+
pbar.set_description('Saving model checkpoint...')
|
| 674 |
+
checkpoint_data = {
|
| 675 |
+
'model_state_dict': model.state_dict(),
|
| 676 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 677 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 678 |
+
'scaler_state_dict': scaler.state_dict(), # Save scaler state
|
| 679 |
+
'iteration': bi,
|
| 680 |
+
# Save all tracked metrics
|
| 681 |
+
'train_losses': train_losses,
|
| 682 |
+
'test_losses': test_losses,
|
| 683 |
+
'train_accuracies': train_accuracies, # List of (S, T) or (S,) arrays
|
| 684 |
+
'test_accuracies': test_accuracies, # List of (S, T) or (S,) arrays
|
| 685 |
+
'train_accuracies_most_certain': train_accuracies_most_certain, # List of scalars
|
| 686 |
+
'test_accuracies_most_certain': test_accuracies_most_certain, # List of scalars
|
| 687 |
+
'train_accuracies_most_certain_permaze': train_accuracies_most_certain_permaze, # List of scalars
|
| 688 |
+
'test_accuracies_most_certain_permaze': test_accuracies_most_certain_permaze, # List of scalars
|
| 689 |
+
'iters': iters,
|
| 690 |
+
'args': args, # Save args used for this run
|
| 691 |
+
# RNG states
|
| 692 |
+
'torch_rng_state': torch.get_rng_state(),
|
| 693 |
+
'numpy_rng_state': np.random.get_state(),
|
| 694 |
+
'random_rng_state': random.getstate(),
|
| 695 |
+
}
|
| 696 |
+
torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
|
| 697 |
+
|
| 698 |
+
pbar.update(1)
|
tasks/mazes/train_distributed.py
ADDED
|
@@ -0,0 +1,782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import gc
|
| 5 |
+
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import numpy as np
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
sns.set_style('darkgrid')
|
| 10 |
+
import torch
|
| 11 |
+
if torch.cuda.is_available():
|
| 12 |
+
# For faster
|
| 13 |
+
torch.set_float32_matmul_precision('high')
|
| 14 |
+
import torch.distributed as dist
|
| 15 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 16 |
+
from torch.utils.data.distributed import DistributedSampler
|
| 17 |
+
from utils.samplers import FastRandomDistributedSampler
|
| 18 |
+
from tqdm.auto import tqdm
|
| 19 |
+
|
| 20 |
+
# Data/Task Specific Imports
|
| 21 |
+
from data.custom_datasets import MazeImageFolder
|
| 22 |
+
|
| 23 |
+
# Model Imports
|
| 24 |
+
from models.ctm import ContinuousThoughtMachine
|
| 25 |
+
from models.lstm import LSTMBaseline
|
| 26 |
+
from models.ff import FFBaseline
|
| 27 |
+
|
| 28 |
+
# Plotting/Utils Imports
|
| 29 |
+
from tasks.mazes.plotting import make_maze_gif
|
| 30 |
+
from tasks.image_classification.plotting import plot_neural_dynamics
|
| 31 |
+
from utils.housekeeping import set_seed, zip_python_code
|
| 32 |
+
from utils.losses import maze_loss
|
| 33 |
+
from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
|
| 34 |
+
|
| 35 |
+
import torchvision
|
| 36 |
+
torchvision.disable_beta_transforms_warning()
|
| 37 |
+
|
| 38 |
+
import warnings
|
| 39 |
+
warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
|
| 40 |
+
warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
|
| 41 |
+
warnings.filterwarnings(
|
| 42 |
+
"ignore",
|
| 43 |
+
"Corrupt EXIF data",
|
| 44 |
+
UserWarning,
|
| 45 |
+
r"^PIL\.TiffImagePlugin$"
|
| 46 |
+
)
|
| 47 |
+
warnings.filterwarnings(
|
| 48 |
+
"ignore",
|
| 49 |
+
"UserWarning: Metadata Warning",
|
| 50 |
+
UserWarning,
|
| 51 |
+
r"^PIL\.TiffImagePlugin$"
|
| 52 |
+
)
|
| 53 |
+
warnings.filterwarnings(
|
| 54 |
+
"ignore",
|
| 55 |
+
"UserWarning: Truncated File Read",
|
| 56 |
+
UserWarning,
|
| 57 |
+
r"^PIL\.TiffImagePlugin$"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def parse_args():
|
| 62 |
+
parser = argparse.ArgumentParser()
|
| 63 |
+
|
| 64 |
+
# Model Selection
|
| 65 |
+
parser.add_argument('--model', type=str, required=True, choices=['ctm', 'lstm', 'ff'], help='Model type to train.')
|
| 66 |
+
|
| 67 |
+
# Model Architecture
|
| 68 |
+
parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
|
| 69 |
+
parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
|
| 70 |
+
parser.add_argument('--backbone_type', type=str, default='resnet34-2', help='Type of backbone featureiser.')
|
| 71 |
+
# CTM / LSTM specific
|
| 72 |
+
parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
|
| 73 |
+
parser.add_argument('--heads', type=int, default=8, help='Number of attention heads (CTM, LSTM).')
|
| 74 |
+
parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
|
| 75 |
+
parser.add_argument('--positional_embedding_type', type=str, default='none',
|
| 76 |
+
help='Type of positional embedding (CTM, LSTM).', choices=['none',
|
| 77 |
+
'learnable-fourier',
|
| 78 |
+
'multi-learnable-fourier',
|
| 79 |
+
'custom-rotational'])
|
| 80 |
+
# CTM specific
|
| 81 |
+
parser.add_argument('--synapse_depth', type=int, default=8, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
|
| 82 |
+
parser.add_argument('--n_synch_out', type=int, default=32, help='Number of neurons to use for output synch (CTM only).')
|
| 83 |
+
parser.add_argument('--n_synch_action', type=int, default=32, help='Number of neurons to use for observation/action synch (CTM only).')
|
| 84 |
+
parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
|
| 85 |
+
parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
|
| 86 |
+
parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
|
| 87 |
+
parser.add_argument('--deep_memory', action=argparse.BooleanOptionalAction, default=True, help='Use deep memory (CTM only).')
|
| 88 |
+
parser.add_argument('--memory_hidden_dims', type=int, default=32, help='Hidden dimensions of the memory if using deep memory (CTM only).')
|
| 89 |
+
parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
|
| 90 |
+
parser.add_argument('--do_normalisation', action=argparse.BooleanOptionalAction, default=False, help='Apply normalization in NLMs (CTM only).')
|
| 91 |
+
# LSTM specific
|
| 92 |
+
parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')
|
| 93 |
+
|
| 94 |
+
# Task Specific Args
|
| 95 |
+
parser.add_argument('--maze_route_length', type=int, default=100, help='Length to truncate targets.')
|
| 96 |
+
parser.add_argument('--cirriculum_lookahead', type=int, default=5, help='How far to look ahead for cirriculum.')
|
| 97 |
+
|
| 98 |
+
# Training
|
| 99 |
+
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for training (per GPU).')
|
| 100 |
+
parser.add_argument('--batch_size_test', type=int, default=64, help='Batch size for testing (per GPU).')
|
| 101 |
+
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate for the model.')
|
| 102 |
+
parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
|
| 103 |
+
parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
|
| 104 |
+
parser.add_argument('--use_scheduler', action=argparse.BooleanOptionalAction, default=True, help='Use a learning rate scheduler.')
|
| 105 |
+
parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
|
| 106 |
+
parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
|
| 107 |
+
parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
|
| 108 |
+
parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
|
| 109 |
+
parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
|
| 110 |
+
parser.add_argument('--num_workers_train', type=int, default=0, help='Num workers training.')
|
| 111 |
+
parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
|
| 112 |
+
parser.add_argument('--use_custom_sampler', action=argparse.BooleanOptionalAction, default=False, help='Use custom fast sampler to avoid reshuffling.')
|
| 113 |
+
parser.add_argument('--do_compile', action=argparse.BooleanOptionalAction, default=False, help='Try to compile model components.')
|
| 114 |
+
|
| 115 |
+
# Logging and Saving
|
| 116 |
+
parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
|
| 117 |
+
parser.add_argument('--dataset', type=str, default='mazes-medium', help='Dataset to use.', choices=['mazes-medium', 'mazes-large'])
|
| 118 |
+
parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
|
| 119 |
+
parser.add_argument('--seed', type=int, default=412, help='Random seed.')
|
| 120 |
+
parser.add_argument('--reload', action=argparse.BooleanOptionalAction, default=False, help='Reload from disk?')
|
| 121 |
+
parser.add_argument('--reload_model_only', action=argparse.BooleanOptionalAction, default=False, help='Reload only the model from disk?') # Default False based on user edit
|
| 122 |
+
parser.add_argument('--strict_reload', action=argparse.BooleanOptionalAction, default=False, help='Should use strict reload for model weights.')
|
| 123 |
+
parser.add_argument('--ignore_metrics_when_reloading', action=argparse.BooleanOptionalAction, default=False, help='Ignore metrics when reloading (for debugging)?')
|
| 124 |
+
|
| 125 |
+
# Tracking
|
| 126 |
+
parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
|
| 127 |
+
parser.add_argument('--n_test_batches', type=int, default=2, help='How many minibatches to approx metrics. Set to -1 for full eval')
|
| 128 |
+
|
| 129 |
+
# Precision
|
| 130 |
+
parser.add_argument('--use_amp', action=argparse.BooleanOptionalAction, default=False, help='AMP autocast.')
|
| 131 |
+
|
| 132 |
+
args = parser.parse_args()
|
| 133 |
+
return args
|
| 134 |
+
|
| 135 |
+
# --- DDP Setup Functions ---
|
| 136 |
+
def setup_ddp():
|
| 137 |
+
if 'RANK' not in os.environ:
|
| 138 |
+
os.environ['RANK'] = '0'
|
| 139 |
+
os.environ['WORLD_SIZE'] = '1'
|
| 140 |
+
os.environ['MASTER_ADDR'] = 'localhost'
|
| 141 |
+
os.environ['MASTER_PORT'] = '12356' # Different port from image classification
|
| 142 |
+
os.environ['LOCAL_RANK'] = '0'
|
| 143 |
+
print("Running in non-distributed mode (simulated DDP setup).")
|
| 144 |
+
if not torch.cuda.is_available() or int(os.environ['WORLD_SIZE']) == 1:
|
| 145 |
+
dist.init_process_group(backend='gloo')
|
| 146 |
+
print("Initialized process group with Gloo backend for single/CPU process.")
|
| 147 |
+
rank = int(os.environ['RANK'])
|
| 148 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 149 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 150 |
+
return rank, world_size, local_rank
|
| 151 |
+
|
| 152 |
+
dist.init_process_group(backend='nccl')
|
| 153 |
+
rank = int(os.environ['RANK'])
|
| 154 |
+
world_size = int(os.environ['WORLD_SIZE'])
|
| 155 |
+
local_rank = int(os.environ['LOCAL_RANK'])
|
| 156 |
+
if torch.cuda.is_available():
|
| 157 |
+
torch.cuda.set_device(local_rank)
|
| 158 |
+
print(f"Rank {rank} setup on GPU {local_rank}")
|
| 159 |
+
else:
|
| 160 |
+
print(f"Rank {rank} setup on CPU")
|
| 161 |
+
return rank, world_size, local_rank
|
| 162 |
+
|
| 163 |
+
def cleanup_ddp():
|
| 164 |
+
if dist.is_initialized():
|
| 165 |
+
dist.destroy_process_group()
|
| 166 |
+
print("DDP cleanup complete.")
|
| 167 |
+
|
| 168 |
+
def is_main_process(rank):
|
| 169 |
+
return rank == 0
|
| 170 |
+
# --- End DDP Setup ---
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
if __name__=='__main__':
|
| 174 |
+
|
| 175 |
+
args = parse_args()
|
| 176 |
+
|
| 177 |
+
rank, world_size, local_rank = setup_ddp()
|
| 178 |
+
|
| 179 |
+
set_seed(args.seed + rank, False)
|
| 180 |
+
|
| 181 |
+
# Rank 0 handles directory creation and initial logging
|
| 182 |
+
if is_main_process(rank):
|
| 183 |
+
if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
|
| 184 |
+
zip_python_code(f'{args.log_dir}/repo_state.zip')
|
| 185 |
+
with open(f'{args.log_dir}/args.txt', 'w') as f:
|
| 186 |
+
print(args, file=f)
|
| 187 |
+
if world_size > 1: dist.barrier()
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
assert args.dataset in ['mazes-medium', 'mazes-large']
|
| 191 |
+
|
| 192 |
+
# Setup Device
|
| 193 |
+
if torch.cuda.is_available():
|
| 194 |
+
device = torch.device(f'cuda:{local_rank}')
|
| 195 |
+
else:
|
| 196 |
+
device = torch.device('cpu')
|
| 197 |
+
if world_size > 1: warnings.warn("Running DDP on CPU is not recommended.")
|
| 198 |
+
|
| 199 |
+
if is_main_process(rank):
|
| 200 |
+
print(f'Main process (Rank {rank}): Using device {device}. World size: {world_size}. Model: {args.model}')
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
prediction_reshaper = [args.maze_route_length, 5]
|
| 204 |
+
args.out_dims = args.maze_route_length * 5
|
| 205 |
+
|
| 206 |
+
# --- Model Definition (Conditional) ---
|
| 207 |
+
model_base = None # Base model before DDP wrapping
|
| 208 |
+
if args.model == 'ctm':
|
| 209 |
+
model_base = ContinuousThoughtMachine(
|
| 210 |
+
iterations=args.iterations,
|
| 211 |
+
d_model=args.d_model,
|
| 212 |
+
d_input=args.d_input,
|
| 213 |
+
heads=args.heads,
|
| 214 |
+
n_synch_out=args.n_synch_out,
|
| 215 |
+
n_synch_action=args.n_synch_action,
|
| 216 |
+
synapse_depth=args.synapse_depth,
|
| 217 |
+
memory_length=args.memory_length,
|
| 218 |
+
deep_nlms=args.deep_memory,
|
| 219 |
+
memory_hidden_dims=args.memory_hidden_dims,
|
| 220 |
+
do_layernorm_nlm=args.do_normalisation,
|
| 221 |
+
backbone_type=args.backbone_type,
|
| 222 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 223 |
+
out_dims=args.out_dims,
|
| 224 |
+
prediction_reshaper=prediction_reshaper,
|
| 225 |
+
dropout=args.dropout,
|
| 226 |
+
dropout_nlm=args.dropout_nlm,
|
| 227 |
+
neuron_select_type=args.neuron_select_type,
|
| 228 |
+
n_random_pairing_self=args.n_random_pairing_self,
|
| 229 |
+
).to(device)
|
| 230 |
+
elif args.model == 'lstm':
|
| 231 |
+
model_base = LSTMBaseline(
|
| 232 |
+
num_layers=args.num_layers,
|
| 233 |
+
iterations=args.iterations,
|
| 234 |
+
d_model=args.d_model,
|
| 235 |
+
d_input=args.d_input,
|
| 236 |
+
heads=args.heads,
|
| 237 |
+
backbone_type=args.backbone_type,
|
| 238 |
+
positional_embedding_type=args.positional_embedding_type,
|
| 239 |
+
out_dims=args.out_dims,
|
| 240 |
+
prediction_reshaper=prediction_reshaper,
|
| 241 |
+
dropout=args.dropout,
|
| 242 |
+
).to(device)
|
| 243 |
+
elif args.model == 'ff':
|
| 244 |
+
model_base = FFBaseline(
|
| 245 |
+
d_model=args.d_model,
|
| 246 |
+
backbone_type=args.backbone_type,
|
| 247 |
+
out_dims=args.out_dims,
|
| 248 |
+
dropout=args.dropout,
|
| 249 |
+
).to(device)
|
| 250 |
+
else:
|
| 251 |
+
raise ValueError(f"Unknown model type: {args.model}")
|
| 252 |
+
|
| 253 |
+
# Use pseudo-input *before* DDP wrapping
|
| 254 |
+
try:
|
| 255 |
+
# Determine pseudo input shape based on dataset
|
| 256 |
+
h_w = 39 if args.dataset in ['mazes-small', 'mazes-medium'] else 99 # Example dimensions
|
| 257 |
+
pseudo_inputs = torch.zeros((1, 3, h_w, h_w), device=device).float()
|
| 258 |
+
model_base(pseudo_inputs)
|
| 259 |
+
except Exception as e:
|
| 260 |
+
print(f"Warning: Pseudo forward pass failed: {e}")
|
| 261 |
+
|
| 262 |
+
if is_main_process(rank):
|
| 263 |
+
print(f'Total params: {sum(p.numel() for p in model_base.parameters() if p.requires_grad)}')
|
| 264 |
+
|
| 265 |
+
# Wrap model with DDP
|
| 266 |
+
if device.type == 'cuda' and world_size > 1:
|
| 267 |
+
model = DDP(model_base, device_ids=[local_rank], output_device=local_rank)
|
| 268 |
+
elif device.type == 'cpu' and world_size > 1:
|
| 269 |
+
model = DDP(model_base)
|
| 270 |
+
else:
|
| 271 |
+
model = model_base
|
| 272 |
+
# --- End Model Definition ---
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
# Data Loading (After model setup to allow pseudo pass first)
|
| 276 |
+
dataset_mean = [0,0,0]
|
| 277 |
+
dataset_std = [1,1,1]
|
| 278 |
+
which_maze = args.dataset.split('-')[-1]
|
| 279 |
+
data_root = f'data/mazes/{which_maze}'
|
| 280 |
+
|
| 281 |
+
train_data = MazeImageFolder(root=f'{data_root}/train/', which_set='train', maze_route_length=args.maze_route_length)
|
| 282 |
+
test_data = MazeImageFolder(root=f'{data_root}/test/', which_set='test', maze_route_length=args.maze_route_length)
|
| 283 |
+
|
| 284 |
+
train_sampler = (FastRandomDistributedSampler(train_data, num_replicas=world_size, rank=rank, seed=args.seed, epoch_steps=int(10e10))
|
| 285 |
+
if args.use_custom_sampler else
|
| 286 |
+
DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=True, seed=args.seed))
|
| 287 |
+
test_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=rank, shuffle=False, seed=args.seed)
|
| 288 |
+
|
| 289 |
+
num_workers_test = 1
|
| 290 |
+
trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, sampler=train_sampler,
|
| 291 |
+
num_workers=args.num_workers_train, pin_memory=True, drop_last=True)
|
| 292 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, sampler=test_sampler,
|
| 293 |
+
num_workers=num_workers_test, pin_memory=True, drop_last=False)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# Optimizer and scheduler
|
| 297 |
+
decay_params = []
|
| 298 |
+
no_decay_params = []
|
| 299 |
+
no_decay_names = []
|
| 300 |
+
for name, param in model.named_parameters():
|
| 301 |
+
if not param.requires_grad:
|
| 302 |
+
continue # Skip parameters that don't require gradients
|
| 303 |
+
if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
|
| 304 |
+
no_decay_params.append(param)
|
| 305 |
+
no_decay_names.append(name)
|
| 306 |
+
else:
|
| 307 |
+
decay_params.append(param)
|
| 308 |
+
if len(no_decay_names) and is_main_process(rank):
|
| 309 |
+
print(f'WARNING, excluding: {no_decay_names}')
|
| 310 |
+
|
| 311 |
+
# Optimizer and scheduler (Common setup)
|
| 312 |
+
if len(no_decay_names) and args.weight_decay!=0:
|
| 313 |
+
optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
|
| 314 |
+
{'params': no_decay_params, 'weight_decay':0}],
|
| 315 |
+
lr=args.lr,
|
| 316 |
+
eps=1e-8 if not args.use_amp else 1e-6)
|
| 317 |
+
else:
|
| 318 |
+
optimizer = torch.optim.AdamW(model.parameters(),
|
| 319 |
+
lr=args.lr,
|
| 320 |
+
eps=1e-8 if not args.use_amp else 1e-6,
|
| 321 |
+
weight_decay=args.weight_decay)
|
| 322 |
+
|
| 323 |
+
warmup_schedule = warmup(args.warmup_steps)
|
| 324 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
|
| 325 |
+
if args.use_scheduler:
|
| 326 |
+
if args.scheduler_type == 'multistep':
|
| 327 |
+
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
|
| 328 |
+
elif args.scheduler_type == 'cosine':
|
| 329 |
+
scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
|
| 330 |
+
else:
|
| 331 |
+
raise NotImplementedError
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# Metrics tracking (Rank 0 stores history)
|
| 335 |
+
start_iter = 0
|
| 336 |
+
iters = []
|
| 337 |
+
train_losses, test_losses = [], []
|
| 338 |
+
train_accuracies, test_accuracies = [], [] # Avg Step Acc (scalar list)
|
| 339 |
+
train_accuracies_most_certain, test_accuracies_most_certain = [], [] # Avg Step Acc @ Certain tick (scalar list)
|
| 340 |
+
train_accuracies_most_certain_permaze, test_accuracies_most_certain_permaze = [], [] # Full Maze Acc @ Certain tick (scalar list)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
scaler = torch.amp.GradScaler("cuda" if device.type == 'cuda' else "cpu", enabled=args.use_amp)
|
| 344 |
+
|
| 345 |
+
# Reloading Logic
|
| 346 |
+
if args.reload:
|
| 347 |
+
map_location = device
|
| 348 |
+
chkpt_path = f'{args.log_dir}/checkpoint.pt'
|
| 349 |
+
if os.path.isfile(chkpt_path):
|
| 350 |
+
print(f'Rank {rank}: Reloading from: {chkpt_path}')
|
| 351 |
+
if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
|
| 352 |
+
|
| 353 |
+
checkpoint = torch.load(chkpt_path, map_location=map_location, weights_only=False)
|
| 354 |
+
|
| 355 |
+
model_to_load = model.module if isinstance(model, DDP) else model
|
| 356 |
+
state_dict = checkpoint['model_state_dict']
|
| 357 |
+
has_module_prefix = all(k.startswith('module.') for k in state_dict)
|
| 358 |
+
is_wrapped = isinstance(model, DDP)
|
| 359 |
+
|
| 360 |
+
if has_module_prefix and not is_wrapped:
|
| 361 |
+
state_dict = {k.partition('module.')[2]: v for k,v in state_dict.items()}
|
| 362 |
+
elif not has_module_prefix and is_wrapped:
|
| 363 |
+
load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
|
| 364 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 365 |
+
state_dict = None # Prevent loading again
|
| 366 |
+
|
| 367 |
+
if state_dict is not None:
|
| 368 |
+
load_result = model_to_load.load_state_dict(state_dict, strict=args.strict_reload)
|
| 369 |
+
print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
if not args.reload_model_only:
|
| 374 |
+
print(f'Rank {rank}: Reloading optimizer, scheduler, scaler, iteration.')
|
| 375 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 376 |
+
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
| 377 |
+
scaler.load_state_dict(checkpoint['scaler_state_dict'])
|
| 378 |
+
start_iter = checkpoint['iteration']
|
| 379 |
+
|
| 380 |
+
if is_main_process(rank) and not args.ignore_metrics_when_reloading:
|
| 381 |
+
print(f'Rank {rank}: Reloading metrics history.')
|
| 382 |
+
iters = checkpoint['iters']
|
| 383 |
+
train_losses = checkpoint['train_losses']
|
| 384 |
+
test_losses = checkpoint['test_losses']
|
| 385 |
+
train_accuracies = checkpoint['train_accuracies'] # Reloading simplified avg step acc list
|
| 386 |
+
test_accuracies = checkpoint['test_accuracies']
|
| 387 |
+
train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
|
| 388 |
+
test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
|
| 389 |
+
train_accuracies_most_certain_permaze = checkpoint['train_accuracies_most_certain_permaze']
|
| 390 |
+
test_accuracies_most_certain_permaze = checkpoint['test_accuracies_most_certain_permaze']
|
| 391 |
+
elif is_main_process(rank) and args.ignore_metrics_when_reloading:
|
| 392 |
+
print(f'Rank {rank}: Ignoring metrics history upon reload.')
|
| 393 |
+
else:
|
| 394 |
+
print(f'Rank {rank}: Only reloading model weights!')
|
| 395 |
+
|
| 396 |
+
if is_main_process(rank) and 'torch_rng_state' in checkpoint and not args.reload_model_only:
|
| 397 |
+
print(f'Rank {rank}: Loading RNG states.')
|
| 398 |
+
torch.set_rng_state(checkpoint['torch_rng_state'].cpu())
|
| 399 |
+
np.random.set_state(checkpoint['numpy_rng_state'])
|
| 400 |
+
random.setstate(checkpoint['random_rng_state'])
|
| 401 |
+
|
| 402 |
+
del checkpoint
|
| 403 |
+
gc.collect()
|
| 404 |
+
if torch.cuda.is_available():
|
| 405 |
+
torch.cuda.empty_cache()
|
| 406 |
+
print(f"Rank {rank}: Reload finished, starting from iteration {start_iter}")
|
| 407 |
+
else:
|
| 408 |
+
print(f"Rank {rank}: Checkpoint not found at {chkpt_path}, starting from scratch.")
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
if world_size > 1: dist.barrier()
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# Conditional Compilation
|
| 415 |
+
if args.do_compile:
|
| 416 |
+
if is_main_process(rank): print('Compiling model components...')
|
| 417 |
+
model_to_compile = model.module if isinstance(model, DDP) else model
|
| 418 |
+
if hasattr(model_to_compile, 'backbone'):
|
| 419 |
+
model_to_compile.backbone = torch.compile(model_to_compile.backbone, mode='reduce-overhead', fullgraph=True)
|
| 420 |
+
if args.model == 'ctm':
|
| 421 |
+
model_to_compile.synapses = torch.compile(model_to_compile.synapses, mode='reduce-overhead', fullgraph=True)
|
| 422 |
+
if world_size > 1: dist.barrier()
|
| 423 |
+
if is_main_process(rank): print('Compilation finished.')
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
# --- Training Loop ---
|
| 427 |
+
model.train()
|
| 428 |
+
pbar = tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True, disable=not is_main_process(rank))
|
| 429 |
+
|
| 430 |
+
iterator = iter(trainloader)
|
| 431 |
+
|
| 432 |
+
for bi in range(start_iter, args.training_iterations):
|
| 433 |
+
|
| 434 |
+
# --- Evaluation and Plotting (Rank 0 + Aggregation) ---
|
| 435 |
+
if bi % args.track_every == 0 and (bi != 0 or args.reload_model_only):
|
| 436 |
+
model.eval()
|
| 437 |
+
with torch.inference_mode():
|
| 438 |
+
|
| 439 |
+
# --- Distributed Evaluation ---
|
| 440 |
+
if is_main_process(rank): iters.append(bi) # Track iterations on rank 0
|
| 441 |
+
|
| 442 |
+
# Initialize accumulators on device
|
| 443 |
+
total_train_loss = torch.tensor(0.0, device=device)
|
| 444 |
+
total_train_correct_certain = torch.tensor(0.0, device=device) # Sum correct steps @ certain tick
|
| 445 |
+
total_train_mazes_solved = torch.tensor(0.0, device=device) # Sum solved mazes @ certain tick
|
| 446 |
+
total_train_steps = torch.tensor(0.0, device=device) # Total steps evaluated (B * S)
|
| 447 |
+
total_train_mazes = torch.tensor(0.0, device=device) # Total mazes evaluated (B)
|
| 448 |
+
|
| 449 |
+
# TRAIN METRICS
|
| 450 |
+
train_eval_sampler = DistributedSampler(train_data, num_replicas=world_size, rank=rank, shuffle=False)
|
| 451 |
+
train_eval_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, sampler=train_eval_sampler, num_workers=num_workers_test, pin_memory=True)
|
| 452 |
+
|
| 453 |
+
pbar_inner_desc = 'Eval Train (Rank 0)' if is_main_process(rank) else None
|
| 454 |
+
with tqdm(total=len(train_eval_loader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
|
| 455 |
+
for inferi, (inputs, targets) in enumerate(train_eval_loader):
|
| 456 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 457 |
+
targets = targets.to(device, non_blocking=True) # B, S
|
| 458 |
+
batch_size = inputs.size(0)
|
| 459 |
+
seq_len = targets.size(1)
|
| 460 |
+
|
| 461 |
+
loss_eval = None
|
| 462 |
+
pred_at_certain = None # Shape B, S
|
| 463 |
+
if args.model == 'ctm':
|
| 464 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 465 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 466 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
|
| 467 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
|
| 468 |
+
elif args.model == 'lstm':
|
| 469 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 470 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 471 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False) # where = -1
|
| 472 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
|
| 473 |
+
elif args.model == 'ff':
|
| 474 |
+
predictions_raw = model(inputs) # B, S*C
|
| 475 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5) # B,S,C
|
| 476 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False) # where = -1
|
| 477 |
+
pred_at_certain = predictions.argmax(2)
|
| 478 |
+
|
| 479 |
+
# Accumulate metrics
|
| 480 |
+
total_train_loss += loss_eval * batch_size # Sum losses
|
| 481 |
+
correct_steps = (pred_at_certain == targets) # B, S boolean
|
| 482 |
+
total_train_correct_certain += correct_steps.sum() # Sum correct steps across batch
|
| 483 |
+
total_train_mazes_solved += correct_steps.all(dim=-1).sum() # Sum mazes where all steps are correct
|
| 484 |
+
total_train_steps += batch_size * seq_len
|
| 485 |
+
total_train_mazes += batch_size
|
| 486 |
+
|
| 487 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 488 |
+
pbar_inner.update(1)
|
| 489 |
+
|
| 490 |
+
# Aggregate Train Metrics
|
| 491 |
+
if world_size > 1:
|
| 492 |
+
dist.all_reduce(total_train_loss, op=dist.ReduceOp.SUM)
|
| 493 |
+
dist.all_reduce(total_train_correct_certain, op=dist.ReduceOp.SUM)
|
| 494 |
+
dist.all_reduce(total_train_mazes_solved, op=dist.ReduceOp.SUM)
|
| 495 |
+
dist.all_reduce(total_train_steps, op=dist.ReduceOp.SUM)
|
| 496 |
+
dist.all_reduce(total_train_mazes, op=dist.ReduceOp.SUM)
|
| 497 |
+
|
| 498 |
+
# Calculate final Train metrics on Rank 0
|
| 499 |
+
if is_main_process(rank) and total_train_mazes > 0:
|
| 500 |
+
avg_train_loss = total_train_loss.item() / total_train_mazes.item() # Avg loss per maze/sample
|
| 501 |
+
avg_train_acc_step = total_train_correct_certain.item() / total_train_steps.item() # Avg correct step %
|
| 502 |
+
avg_train_acc_maze = total_train_mazes_solved.item() / total_train_mazes.item() # Avg full maze solved %
|
| 503 |
+
train_losses.append(avg_train_loss)
|
| 504 |
+
train_accuracies_most_certain.append(avg_train_acc_step)
|
| 505 |
+
train_accuracies_most_certain_permaze.append(avg_train_acc_maze)
|
| 506 |
+
# train_accuracies list remains unused/placeholder for this simplified metric structure
|
| 507 |
+
print(f"Iter {bi} Train Metrics (Agg): Loss={avg_train_loss:.4f}, StepAcc={avg_train_acc_step:.4f}, MazeAcc={avg_train_acc_maze:.4f}")
|
| 508 |
+
|
| 509 |
+
# TEST METRICS
|
| 510 |
+
total_test_loss = torch.tensor(0.0, device=device)
|
| 511 |
+
total_test_correct_certain = torch.tensor(0.0, device=device)
|
| 512 |
+
total_test_mazes_solved = torch.tensor(0.0, device=device)
|
| 513 |
+
total_test_steps = torch.tensor(0.0, device=device)
|
| 514 |
+
total_test_mazes = torch.tensor(0.0, device=device)
|
| 515 |
+
|
| 516 |
+
pbar_inner_desc = 'Eval Test (Rank 0)' if is_main_process(rank) else None
|
| 517 |
+
with tqdm(total=len(testloader), desc=pbar_inner_desc, leave=False, position=1, dynamic_ncols=True, disable=not is_main_process(rank)) as pbar_inner:
|
| 518 |
+
for inferi, (inputs, targets) in enumerate(testloader):
|
| 519 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 520 |
+
targets = targets.to(device, non_blocking=True)
|
| 521 |
+
batch_size = inputs.size(0)
|
| 522 |
+
seq_len = targets.size(1)
|
| 523 |
+
|
| 524 |
+
loss_eval = None
|
| 525 |
+
pred_at_certain = None
|
| 526 |
+
if args.model == 'ctm':
|
| 527 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 528 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1))
|
| 529 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=True)
|
| 530 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
|
| 531 |
+
elif args.model == 'lstm':
|
| 532 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 533 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5, predictions_raw.size(-1))
|
| 534 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions, certainties, targets, use_most_certain=False)
|
| 535 |
+
pred_at_certain = predictions.argmax(2)[torch.arange(batch_size, device=device), :, where_most_certain]
|
| 536 |
+
elif args.model == 'ff':
|
| 537 |
+
predictions_raw = model(inputs)
|
| 538 |
+
predictions = predictions_raw.reshape(batch_size, -1, 5)
|
| 539 |
+
loss_eval, where_most_certain, _ = maze_loss(predictions.unsqueeze(-1), None, targets, use_most_certain=False)
|
| 540 |
+
pred_at_certain = predictions.argmax(2)
|
| 541 |
+
|
| 542 |
+
total_test_loss += loss_eval * batch_size
|
| 543 |
+
correct_steps = (pred_at_certain == targets)
|
| 544 |
+
total_test_correct_certain += correct_steps.sum()
|
| 545 |
+
total_test_mazes_solved += correct_steps.all(dim=-1).sum()
|
| 546 |
+
total_test_steps += batch_size * seq_len
|
| 547 |
+
total_test_mazes += batch_size
|
| 548 |
+
|
| 549 |
+
if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
|
| 550 |
+
pbar_inner.update(1)
|
| 551 |
+
|
| 552 |
+
# Aggregate Test Metrics
|
| 553 |
+
if world_size > 1:
|
| 554 |
+
dist.all_reduce(total_test_loss, op=dist.ReduceOp.SUM)
|
| 555 |
+
dist.all_reduce(total_test_correct_certain, op=dist.ReduceOp.SUM)
|
| 556 |
+
dist.all_reduce(total_test_mazes_solved, op=dist.ReduceOp.SUM)
|
| 557 |
+
dist.all_reduce(total_test_steps, op=dist.ReduceOp.SUM)
|
| 558 |
+
dist.all_reduce(total_test_mazes, op=dist.ReduceOp.SUM)
|
| 559 |
+
|
| 560 |
+
# Calculate and Plot final Test metrics on Rank 0
|
| 561 |
+
if is_main_process(rank) and total_test_mazes > 0:
|
| 562 |
+
avg_test_loss = total_test_loss.item() / total_test_mazes.item()
|
| 563 |
+
avg_test_acc_step = total_test_correct_certain.item() / total_test_steps.item()
|
| 564 |
+
avg_test_acc_maze = total_test_mazes_solved.item() / total_test_mazes.item()
|
| 565 |
+
test_losses.append(avg_test_loss)
|
| 566 |
+
test_accuracies_most_certain.append(avg_test_acc_step)
|
| 567 |
+
test_accuracies_most_certain_permaze.append(avg_test_acc_maze)
|
| 568 |
+
print(f"Iter {bi} Test Metrics (Agg): Loss={avg_test_loss:.4f}, StepAcc={avg_test_acc_step:.4f}, MazeAcc={avg_test_acc_maze:.4f}\n")
|
| 569 |
+
|
| 570 |
+
# --- Plotting ---
|
| 571 |
+
figacc = plt.figure(figsize=(10, 10))
|
| 572 |
+
axacc_train = figacc.add_subplot(211)
|
| 573 |
+
axacc_test = figacc.add_subplot(212)
|
| 574 |
+
|
| 575 |
+
# Plot Avg Step Accuracy
|
| 576 |
+
axacc_train.plot(iters, train_accuracies_most_certain, 'k-', alpha=0.7, label=f'Avg Step Acc ({train_accuracies_most_certain[-1]:.3f})')
|
| 577 |
+
axacc_test.plot(iters, test_accuracies_most_certain, 'k-', alpha=0.7, label=f'Avg Step Acc ({test_accuracies_most_certain[-1]:.3f})')
|
| 578 |
+
# Plot Full Maze Accuracy
|
| 579 |
+
axacc_train.plot(iters, train_accuracies_most_certain_permaze, 'r-', alpha=0.6, label=f'Full Maze Acc ({train_accuracies_most_certain_permaze[-1]:.3f})')
|
| 580 |
+
axacc_test.plot(iters, test_accuracies_most_certain_permaze, 'r-', alpha=0.6, label=f'Full Maze Acc ({test_accuracies_most_certain_permaze[-1]:.3f})')
|
| 581 |
+
|
| 582 |
+
axacc_train.set_title('Train Accuracy (Aggregated)')
|
| 583 |
+
axacc_test.set_title('Test Accuracy (Aggregated)')
|
| 584 |
+
axacc_train.legend(loc='lower right')
|
| 585 |
+
axacc_test.legend(loc='lower right')
|
| 586 |
+
axacc_train.set_xlim([0, args.training_iterations])
|
| 587 |
+
axacc_test.set_xlim([0, args.training_iterations])
|
| 588 |
+
axacc_train.set_ylim([0, 1])
|
| 589 |
+
axacc_test.set_ylim([0, 1])
|
| 590 |
+
|
| 591 |
+
figacc.tight_layout()
|
| 592 |
+
figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
|
| 593 |
+
plt.close(figacc)
|
| 594 |
+
|
| 595 |
+
# Loss Plot
|
| 596 |
+
figloss = plt.figure(figsize=(10, 5))
|
| 597 |
+
axloss = figloss.add_subplot(111)
|
| 598 |
+
axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train (Agg): {train_losses[-1]:.4f}')
|
| 599 |
+
axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test (Agg): {test_losses[-1]:.4f}')
|
| 600 |
+
axloss.legend(loc='upper right')
|
| 601 |
+
axloss.set_xlabel("Iteration")
|
| 602 |
+
axloss.set_ylabel("Loss")
|
| 603 |
+
axloss.set_xlim([0, args.training_iterations])
|
| 604 |
+
axloss.set_ylim(bottom=0)
|
| 605 |
+
figloss.tight_layout()
|
| 606 |
+
figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
|
| 607 |
+
plt.close(figloss)
|
| 608 |
+
# --- End Plotting ---
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# --- Visualization (Rank 0, Conditional) ---
|
| 612 |
+
if is_main_process(rank) and args.model in ['ctm', 'lstm']:
|
| 613 |
+
# try:
|
| 614 |
+
model_module = model.module if isinstance(model, DDP) else model
|
| 615 |
+
# Use a consistent batch for viz if possible, or just next batch
|
| 616 |
+
inputs_viz, targets_viz = next(iter(testloader))
|
| 617 |
+
inputs_viz = inputs_viz.to(device)
|
| 618 |
+
targets_viz = targets_viz.to(device)
|
| 619 |
+
longest_index = (targets_viz!=4).sum(-1).argmax() # 4 assumed padding
|
| 620 |
+
|
| 621 |
+
pbar.set_description('Tracking (Rank 0): Viz Fwd Pass')
|
| 622 |
+
predictions_viz_raw, _, _, _, post_activations_viz, attention_tracking_viz = model_module(inputs_viz, track=True)
|
| 623 |
+
predictions_viz = predictions_viz_raw.reshape(predictions_viz_raw.size(0), -1, 5, predictions_viz_raw.size(-1))
|
| 624 |
+
|
| 625 |
+
att_shape = (model.module.kv_features.shape[2], model.module.kv_features.shape[3])
|
| 626 |
+
attention_tracking_viz = attention_tracking_viz.reshape(
|
| 627 |
+
attention_tracking_viz.shape[0],
|
| 628 |
+
attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
|
| 629 |
+
|
| 630 |
+
pbar.set_description('Tracking (Rank 0): Dynamics Plot')
|
| 631 |
+
plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
|
| 632 |
+
|
| 633 |
+
pbar.set_description('Tracking (Rank 0): Maze GIF')
|
| 634 |
+
if attention_tracking_viz is not None:
|
| 635 |
+
make_maze_gif((inputs_viz[longest_index].detach().cpu().numpy()+1)/2,
|
| 636 |
+
predictions_viz[longest_index].detach().cpu().numpy(),
|
| 637 |
+
targets_viz[longest_index].detach().cpu().numpy(),
|
| 638 |
+
attention_tracking_viz[:, longest_index],
|
| 639 |
+
args.log_dir)
|
| 640 |
+
# else:
|
| 641 |
+
# print("Skipping maze GIF due to attention shape issue.")
|
| 642 |
+
|
| 643 |
+
# except Exception as e_viz:
|
| 644 |
+
# print(f"Rank 0 visualization failed: {e_viz}")
|
| 645 |
+
# --- End Visualization ---
|
| 646 |
+
|
| 647 |
+
gc.collect()
|
| 648 |
+
if torch.cuda.is_available():
|
| 649 |
+
torch.cuda.empty_cache()
|
| 650 |
+
if world_size > 1: dist.barrier()
|
| 651 |
+
model.train()
|
| 652 |
+
# --- End Evaluation Block ---
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
if hasattr(train_sampler, 'set_epoch'): # Check if sampler has set_epoch
|
| 658 |
+
train_sampler.set_epoch(bi)
|
| 659 |
+
|
| 660 |
+
current_lr = optimizer.param_groups[-1]['lr']
|
| 661 |
+
|
| 662 |
+
try:
|
| 663 |
+
inputs, targets = next(iterator)
|
| 664 |
+
except StopIteration:
|
| 665 |
+
iterator = iter(trainloader)
|
| 666 |
+
inputs, targets = next(iterator)
|
| 667 |
+
|
| 668 |
+
inputs = inputs.to(device, non_blocking=True)
|
| 669 |
+
targets = targets.to(device, non_blocking=True)
|
| 670 |
+
|
| 671 |
+
# Defaults for logging
|
| 672 |
+
loss = torch.tensor(0.0, device=device) # Need loss defined for logging scope
|
| 673 |
+
accuracy_finegrained = 0.0
|
| 674 |
+
where_most_certain_val = -1.0
|
| 675 |
+
where_most_certain_std = 0.0
|
| 676 |
+
where_most_certain_min = -1
|
| 677 |
+
where_most_certain_max = -1
|
| 678 |
+
upto_where_mean = -1.0
|
| 679 |
+
upto_where_std = 0.0
|
| 680 |
+
upto_where_min = -1
|
| 681 |
+
upto_where_max = -1
|
| 682 |
+
|
| 683 |
+
with torch.autocast(device_type="cuda" if device.type == 'cuda' else "cpu", dtype=torch.float16, enabled=args.use_amp):
|
| 684 |
+
if args.do_compile: torch.compiler.cudagraph_mark_step_begin()
|
| 685 |
+
|
| 686 |
+
if args.model == 'ctm':
|
| 687 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 688 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 689 |
+
loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=True)
|
| 690 |
+
with torch.no_grad(): # Calculate local accuracy for logging
|
| 691 |
+
accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=device), :, where_most_certain] == targets).float().mean().item()
|
| 692 |
+
elif args.model == 'lstm':
|
| 693 |
+
predictions_raw, certainties, _ = model(inputs)
|
| 694 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5, predictions_raw.size(-1)) # B,S,C,T
|
| 695 |
+
loss, where_most_certain, upto_where = maze_loss(predictions, certainties, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False) # where = -1
|
| 696 |
+
with torch.no_grad():
|
| 697 |
+
accuracy_finegrained = (predictions.argmax(2)[torch.arange(predictions.size(0), device=device), :, where_most_certain] == targets).float().mean().item()
|
| 698 |
+
elif args.model == 'ff':
|
| 699 |
+
predictions_raw = model(inputs) # B, S*C
|
| 700 |
+
predictions = predictions_raw.reshape(predictions_raw.size(0), -1, 5) # B,S,C
|
| 701 |
+
loss, where_most_certain, upto_where = maze_loss(predictions.unsqueeze(-1), None, targets, cirriculum_lookahead=args.cirriculum_lookahead, use_most_certain=False) # where = -1
|
| 702 |
+
with torch.no_grad():
|
| 703 |
+
accuracy_finegrained = (predictions.argmax(2) == targets).float().mean().item()
|
| 704 |
+
|
| 705 |
+
# Extract stats from loss outputs
|
| 706 |
+
if torch.is_tensor(where_most_certain):
|
| 707 |
+
where_most_certain_val = where_most_certain.float().mean().item()
|
| 708 |
+
where_most_certain_std = where_most_certain.float().std().item()
|
| 709 |
+
where_most_certain_min = where_most_certain.min().item()
|
| 710 |
+
where_most_certain_max = where_most_certain.max().item()
|
| 711 |
+
elif isinstance(where_most_certain, int):
|
| 712 |
+
where_most_certain_val = float(where_most_certain); where_most_certain_min = where_most_certain; where_most_certain_max = where_most_certain
|
| 713 |
+
if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
|
| 714 |
+
upto_where_mean = np.mean(upto_where); upto_where_std = np.std(upto_where); upto_where_min = np.min(upto_where); upto_where_max = np.max(upto_where)
|
| 715 |
+
|
| 716 |
+
# Backprop / Step
|
| 717 |
+
scaler.scale(loss).backward()
|
| 718 |
+
if args.gradient_clipping!=-1:
|
| 719 |
+
scaler.unscale_(optimizer)
|
| 720 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
|
| 721 |
+
scaler.step(optimizer)
|
| 722 |
+
scaler.update()
|
| 723 |
+
optimizer.zero_grad(set_to_none=True)
|
| 724 |
+
scheduler.step()
|
| 725 |
+
|
| 726 |
+
# --- Aggregation and Logging (Rank 0) ---
|
| 727 |
+
loss_log = loss.detach()
|
| 728 |
+
if world_size > 1: dist.all_reduce(loss_log, op=dist.ReduceOp.AVG)
|
| 729 |
+
|
| 730 |
+
if is_main_process(rank):
|
| 731 |
+
pbar_desc = f'Loss(avg)={loss_log.item():.3f} Acc(loc)={accuracy_finegrained:.3f} LR={current_lr:.6f}'
|
| 732 |
+
if args.model in ['ctm', 'lstm'] or torch.is_tensor(where_most_certain):
|
| 733 |
+
pbar_desc += f' Cert={where_most_certain_val:.2f}'#+-{where_most_certain_std:.2f}' # Removed std for brevity
|
| 734 |
+
if isinstance(upto_where, (np.ndarray, list)) and len(upto_where) > 0:
|
| 735 |
+
pbar_desc += f' Path={upto_where_mean:.1f}'#+-{upto_where_std:.1f}'
|
| 736 |
+
pbar.set_description(f'{args.model.upper()} {pbar_desc}')
|
| 737 |
+
# --- End Aggregation and Logging ---
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
# --- Checkpointing (Rank 0) ---
|
| 744 |
+
if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter and is_main_process(rank):
|
| 745 |
+
pbar.set_description('Rank 0: Saving checkpoint...')
|
| 746 |
+
save_path = f'{args.log_dir}/checkpoint.pt'
|
| 747 |
+
model_state_to_save = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
|
| 748 |
+
|
| 749 |
+
checkpoint_data = {
|
| 750 |
+
'model_state_dict': model_state_to_save,
|
| 751 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 752 |
+
'scheduler_state_dict': scheduler.state_dict(),
|
| 753 |
+
'scaler_state_dict': scaler.state_dict(),
|
| 754 |
+
'iteration': bi,
|
| 755 |
+
'train_losses': train_losses,
|
| 756 |
+
'test_losses': test_losses,
|
| 757 |
+
'train_accuracies': train_accuracies, # Saving simplified scalar list
|
| 758 |
+
'test_accuracies': test_accuracies, # Saving simplified scalar list
|
| 759 |
+
'train_accuracies_most_certain': train_accuracies_most_certain,
|
| 760 |
+
'test_accuracies_most_certain': test_accuracies_most_certain,
|
| 761 |
+
'train_accuracies_most_certain_permaze': train_accuracies_most_certain_permaze,
|
| 762 |
+
'test_accuracies_most_certain_permaze': test_accuracies_most_certain_permaze,
|
| 763 |
+
'iters': iters,
|
| 764 |
+
'args': args,
|
| 765 |
+
'torch_rng_state': torch.get_rng_state(),
|
| 766 |
+
'numpy_rng_state': np.random.get_state(),
|
| 767 |
+
'random_rng_state': random.getstate(),
|
| 768 |
+
}
|
| 769 |
+
torch.save(checkpoint_data, save_path)
|
| 770 |
+
# --- End Checkpointing ---
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
if world_size > 1: dist.barrier()
|
| 774 |
+
|
| 775 |
+
if is_main_process(rank):
|
| 776 |
+
pbar.update(1)
|
| 777 |
+
# --- End Training Loop ---
|
| 778 |
+
|
| 779 |
+
if is_main_process(rank):
|
| 780 |
+
pbar.close()
|
| 781 |
+
|
| 782 |
+
cleanup_ddp()
|
tasks/parity/README.md
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Parity
|
| 2 |
+
|
| 3 |
+
## Training
|
| 4 |
+
To run the parity training that we used for the paper, run bash scripts from the root level of the repository. For example, to train the 75-iteration, 25-memory-length CTM, run:
|
| 5 |
+
|
| 6 |
+
```
|
| 7 |
+
bash tasks/parity/scripts/train_ctm_75_25.sh
|
| 8 |
+
```
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
## Analysis
|
| 12 |
+
To run the analysis, first make sure the checkpoints are saved in the log directory (specified by the `log_dir` argument). The checkpoints can be obtained by either running the training code, or downloading them from [this link](https://drive.google.com/file/d/1itUS5_i9AyUo_7awllTx8X0PXYw9fnaG/view?usp=drive_link).
|
| 13 |
+
|
| 14 |
+
```
|
| 15 |
+
python -m tasks.parity.analysis.run --log_dir <PATH_TO_LOG_DIR>
|
| 16 |
+
```
|
tasks/parity/analysis/make_blog_gifs.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import os
|
| 4 |
+
import math
|
| 5 |
+
import imageio
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
from matplotlib.patches import FancyArrowPatch
|
| 9 |
+
from scipy.special import softmax
|
| 10 |
+
import matplotlib.cm as cm
|
| 11 |
+
from data.custom_datasets import ParityDataset
|
| 12 |
+
import umap
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from models.utils import reshape_predictions
|
| 17 |
+
from tasks.parity.utils import reshape_inputs
|
| 18 |
+
from tasks.parity.analysis.run import build_model_from_checkpoint_path
|
| 19 |
+
|
| 20 |
+
from tasks.image_classification.analysis.build_imagenet_viz_blog import save_frames_to_mp4
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def make_parity_gif(
|
| 24 |
+
predictions,
|
| 25 |
+
targets,
|
| 26 |
+
post_activations,
|
| 27 |
+
attention_weights,
|
| 28 |
+
inputs_to_model,
|
| 29 |
+
save_path,
|
| 30 |
+
umap_positions,
|
| 31 |
+
umap_point_scaler=1.0,
|
| 32 |
+
):
|
| 33 |
+
batch_index = 0
|
| 34 |
+
figscale = 0.32
|
| 35 |
+
n_steps, n_heads, seqLen = attention_weights.shape[:3]
|
| 36 |
+
grid_side = int(np.sqrt(seqLen))
|
| 37 |
+
frames = []
|
| 38 |
+
|
| 39 |
+
inputs_this_batch = inputs_to_model[:, batch_index]
|
| 40 |
+
preds_this_batch = predictions[batch_index]
|
| 41 |
+
targets_this_batch = targets[batch_index]
|
| 42 |
+
post_act_this_batch = post_activations[:, batch_index]
|
| 43 |
+
|
| 44 |
+
# build a flexible mosaic
|
| 45 |
+
mosaic = [
|
| 46 |
+
[f"att_0", f"in_0", "probs", "probs", "target", "target"],
|
| 47 |
+
[f"att_1", f"in_1", "probs", "probs", "target", "target"],
|
| 48 |
+
]
|
| 49 |
+
for h in range(2, n_heads):
|
| 50 |
+
mosaic.append(
|
| 51 |
+
[f"att_{h}", f"in_{h}", "umap", "umap",
|
| 52 |
+
"umap", "umap"]
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
for t in range(n_steps):
|
| 56 |
+
rows = len(mosaic)
|
| 57 |
+
cell_size = figscale * 4
|
| 58 |
+
fig_h = rows * cell_size
|
| 59 |
+
|
| 60 |
+
fig, ax = plt.subplot_mosaic(
|
| 61 |
+
mosaic,
|
| 62 |
+
figsize=(6 * cell_size, fig_h),
|
| 63 |
+
constrained_layout=False,
|
| 64 |
+
gridspec_kw={'wspace': 0.05, 'hspace': 0.05}, # small gaps
|
| 65 |
+
)
|
| 66 |
+
# restore a little margin
|
| 67 |
+
fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02)
|
| 68 |
+
|
| 69 |
+
# probabilities heatmap
|
| 70 |
+
logits_t = preds_this_batch[:, :, t]
|
| 71 |
+
probs_t = softmax(logits_t, axis=1)[:, 0].reshape(grid_side, grid_side)
|
| 72 |
+
ax["probs"].imshow(probs_t, cmap="gray", vmin=0, vmax=1)
|
| 73 |
+
ax["probs"].axis("off")
|
| 74 |
+
|
| 75 |
+
# target overlay
|
| 76 |
+
ax["target"].imshow(
|
| 77 |
+
targets_this_batch.reshape(grid_side, grid_side),
|
| 78 |
+
cmap="gray_r", vmin=0, vmax=1
|
| 79 |
+
)
|
| 80 |
+
ax["target"].axis("off")
|
| 81 |
+
ax["target"].grid(which="minor", color="black", linestyle="-", linewidth=0.5)
|
| 82 |
+
|
| 83 |
+
z = post_act_this_batch[t]
|
| 84 |
+
low, high = np.percentile(z, 5), np.percentile(z, 95)
|
| 85 |
+
z_norm = np.clip((z - low) / (high - low), 0, 1)
|
| 86 |
+
point_sizes = (np.abs(z_norm - 0.5) * 100 + 5) * umap_point_scaler
|
| 87 |
+
cmap = plt.get_cmap("Spectral")
|
| 88 |
+
ax["umap"].scatter(
|
| 89 |
+
umap_positions[:, 0],
|
| 90 |
+
umap_positions[:, 1],
|
| 91 |
+
s=point_sizes,
|
| 92 |
+
c=cmap(z_norm),
|
| 93 |
+
alpha=0.8
|
| 94 |
+
)
|
| 95 |
+
ax["umap"].axis("off")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# normalize attention
|
| 99 |
+
att_t = attention_weights[t, :, :]
|
| 100 |
+
a_min, a_max = att_t.min(), att_t.max()
|
| 101 |
+
if not np.isclose(a_min, a_max):
|
| 102 |
+
att_t = (att_t - a_min) / (a_max - a_min + 1e-8)
|
| 103 |
+
else:
|
| 104 |
+
att_t = np.zeros_like(att_t)
|
| 105 |
+
|
| 106 |
+
# input image for arrows
|
| 107 |
+
img_t = inputs_this_batch[t].transpose(1, 2, 0)
|
| 108 |
+
|
| 109 |
+
if t == 0:
|
| 110 |
+
route_history = [[] for _ in range(n_heads)]
|
| 111 |
+
|
| 112 |
+
img_h, img_w = img_t.shape[:2]
|
| 113 |
+
cell_h = img_h // grid_side
|
| 114 |
+
cell_w = img_w // grid_side
|
| 115 |
+
|
| 116 |
+
for h in range(n_heads):
|
| 117 |
+
head_map = att_t[h].reshape(grid_side, grid_side)
|
| 118 |
+
ax[f"att_{h}"].imshow(head_map, cmap="viridis", vmin=0, vmax=1)
|
| 119 |
+
ax[f"att_{h}"].axis("off")
|
| 120 |
+
ax[f"in_{h}"].imshow(img_t, cmap="gray", vmin=0, vmax=1)
|
| 121 |
+
ax[f"in_{h}"].axis("off")
|
| 122 |
+
|
| 123 |
+
# track argmax center
|
| 124 |
+
flat_idx = np.argmax(head_map)
|
| 125 |
+
gy, gx = divmod(flat_idx, grid_side)
|
| 126 |
+
cx = int((gx + 0.5) * cell_w)
|
| 127 |
+
cy = int((gy + 0.5) * cell_h)
|
| 128 |
+
route_history[h].append((cx, cy))
|
| 129 |
+
|
| 130 |
+
cmap_steps = plt.colormaps.get_cmap("Spectral")
|
| 131 |
+
colors = [cmap_steps(i / (n_steps - 1)) for i in range(n_steps)]
|
| 132 |
+
for i in range(len(route_history[h]) - 1):
|
| 133 |
+
x0, y0 = route_history[h][i]
|
| 134 |
+
x1, y1 = route_history[h][i + 1]
|
| 135 |
+
color = colors[i]
|
| 136 |
+
is_last = (i == len(route_history[h]) - 2)
|
| 137 |
+
style = '->' if is_last else '-'
|
| 138 |
+
lw = 2.0 if is_last else 1.6
|
| 139 |
+
alpha = 1.0 if is_last else 0.9
|
| 140 |
+
scale = 10 if is_last else 1
|
| 141 |
+
|
| 142 |
+
# draw arrow
|
| 143 |
+
arr = FancyArrowPatch(
|
| 144 |
+
(x0, y0), (x1, y1),
|
| 145 |
+
arrowstyle=style,
|
| 146 |
+
linewidth=lw,
|
| 147 |
+
mutation_scale=scale,
|
| 148 |
+
alpha=alpha,
|
| 149 |
+
facecolor=color,
|
| 150 |
+
edgecolor=color,
|
| 151 |
+
shrinkA=0, shrinkB=0,
|
| 152 |
+
capstyle='round', joinstyle='round',
|
| 153 |
+
zorder=3 if is_last else 2,
|
| 154 |
+
clip_on=False,
|
| 155 |
+
)
|
| 156 |
+
ax[f"in_{h}"].add_patch(arr)
|
| 157 |
+
|
| 158 |
+
ax[f"in_{h}"].scatter(
|
| 159 |
+
x1, y1,
|
| 160 |
+
marker='x',
|
| 161 |
+
s=40,
|
| 162 |
+
color=color,
|
| 163 |
+
linewidths=lw,
|
| 164 |
+
zorder=4
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
canvas = fig.canvas
|
| 168 |
+
canvas.draw()
|
| 169 |
+
frame = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
|
| 170 |
+
w, h = canvas.get_width_height()
|
| 171 |
+
frames.append(frame.reshape(h, w, 4)[..., :3])
|
| 172 |
+
plt.close(fig)
|
| 173 |
+
|
| 174 |
+
# save gif
|
| 175 |
+
imageio.mimsave(f"{save_path}/activation.gif", frames, fps=15, loop=0)
|
| 176 |
+
|
| 177 |
+
# save mp4
|
| 178 |
+
save_frames_to_mp4(
|
| 179 |
+
[fm[:, :, ::-1] for fm in frames], # RGB→BGR
|
| 180 |
+
f"{save_path}/activation.mp4",
|
| 181 |
+
fps=15,
|
| 182 |
+
gop_size=1,
|
| 183 |
+
preset="slow"
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
def run_umap(model, testloader):
|
| 187 |
+
all_post_activations = []
|
| 188 |
+
point_counts = 150
|
| 189 |
+
sampled = 0
|
| 190 |
+
with tqdm(total=point_counts, desc="Collecting UMAP data") as pbar:
|
| 191 |
+
for inputs, _ in testloader:
|
| 192 |
+
for i in range(inputs.size(0)):
|
| 193 |
+
if sampled >= point_counts:
|
| 194 |
+
break
|
| 195 |
+
input_i = inputs[i].unsqueeze(0).to(device)
|
| 196 |
+
_, _, _, _, post_activations, _ = model(input_i, track=True)
|
| 197 |
+
all_post_activations.append(post_activations)
|
| 198 |
+
sampled += 1
|
| 199 |
+
pbar.update(1)
|
| 200 |
+
if sampled >= point_counts:
|
| 201 |
+
break
|
| 202 |
+
|
| 203 |
+
stacked = np.stack(all_post_activations, 1)
|
| 204 |
+
umap_features = stacked.reshape(-1, stacked.shape[-1])
|
| 205 |
+
reducer = umap.UMAP(
|
| 206 |
+
n_components=2,
|
| 207 |
+
n_neighbors=20,
|
| 208 |
+
min_dist=1,
|
| 209 |
+
spread=1,
|
| 210 |
+
metric='cosine',
|
| 211 |
+
local_connectivity=1
|
| 212 |
+
)
|
| 213 |
+
positions = reducer.fit_transform(umap_features.T)
|
| 214 |
+
return positions
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def run_model_and_make_gif(checkpoint_path, save_path, device):
|
| 218 |
+
|
| 219 |
+
parity_sequence_length = 64
|
| 220 |
+
iterations = 75
|
| 221 |
+
|
| 222 |
+
test_data = ParityDataset(sequence_length=parity_sequence_length, length=10000)
|
| 223 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=0, drop_last=False)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
model, _ = build_model_from_checkpoint_path(checkpoint_path, "ctm", device=device)
|
| 227 |
+
|
| 228 |
+
input = torch.randint(0, 2, (64,), dtype=torch.float32, device=device) * 2 - 1
|
| 229 |
+
input = input.unsqueeze(0)
|
| 230 |
+
|
| 231 |
+
target = torch.cumsum((input == -1).to(torch.long), dim=1) % 2
|
| 232 |
+
target = target.unsqueeze(0)
|
| 233 |
+
|
| 234 |
+
positions = run_umap(model, testloader)
|
| 235 |
+
|
| 236 |
+
model.eval()
|
| 237 |
+
with torch.inference_mode():
|
| 238 |
+
predictions, _, _, _, post_activations, attention = model(input, track=True)
|
| 239 |
+
predictons = reshape_predictions(predictions, prediction_reshaper=[parity_sequence_length, 2])
|
| 240 |
+
input_images = reshape_inputs(input, iterations, grid_size=int(math.sqrt(parity_sequence_length)))
|
| 241 |
+
|
| 242 |
+
make_parity_gif(
|
| 243 |
+
predictions=predictons.detach().cpu().numpy(),
|
| 244 |
+
targets=target.detach().cpu().numpy(),
|
| 245 |
+
post_activations=post_activations,
|
| 246 |
+
attention_weights=attention.squeeze(1).squeeze(2),
|
| 247 |
+
inputs_to_model=input_images,
|
| 248 |
+
save_path=save_path,
|
| 249 |
+
umap_positions=positions,
|
| 250 |
+
umap_point_scaler=1.0,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
if __name__ == "__main__":
|
| 256 |
+
|
| 257 |
+
CHECKPOINT_PATH = "checkpoints/parity/run1/ctm_75_25/checkpoint_200000.pt"
|
| 258 |
+
SAVE_PATH = f"tasks/parity/analysis/outputs/blog_gifs/"
|
| 259 |
+
os.makedirs(SAVE_PATH, exist_ok=True)
|
| 260 |
+
|
| 261 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 262 |
+
|
| 263 |
+
run_model_and_make_gif(CHECKPOINT_PATH, SAVE_PATH, device)
|
tasks/parity/analysis/run.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import argparse
|
| 4 |
+
import multiprocessing
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import csv
|
| 9 |
+
from utils.housekeeping import set_seed
|
| 10 |
+
from data.custom_datasets import ParityDataset
|
| 11 |
+
from tasks.parity.utils import prepare_model, reshape_attention_weights, reshape_inputs, get_where_most_certain
|
| 12 |
+
from tasks.parity.plotting import plot_attention_trajectory, plot_input, plot_target, plot_probabilities, plot_prediction, plot_accuracy_training, create_attentions_heatmap_gif, create_accuracies_heatmap_gif, create_stacked_gif, plot_training_curve_all_runs, plot_accuracy_thinking_time, make_parity_gif, plot_lstm_last_and_certain_accuracy
|
| 13 |
+
from models.utils import compute_normalized_entropy, reshape_predictions, get_latest_checkpoint_file, get_checkpoint_files, load_checkpoint, get_model_args_from_checkpoint, get_all_log_dirs
|
| 14 |
+
from tasks.image_classification.plotting import plot_neural_dynamics
|
| 15 |
+
|
| 16 |
+
import seaborn as sns
|
| 17 |
+
sns.set_palette("hls")
|
| 18 |
+
sns.set_style('darkgrid')
|
| 19 |
+
|
| 20 |
+
def parse_args():
|
| 21 |
+
parser = argparse.ArgumentParser(description='Parity Analysis')
|
| 22 |
+
parser.add_argument('--log_dir', type=str, default='checkpoints/parity', help='Directory to save logs.')
|
| 23 |
+
parser.add_argument('--batch_size_test', type=int, default=128, help='batch size for testing')
|
| 24 |
+
parser.add_argument('--scale_training_curve', type=float, default=0.6, help='Scaling factor for plots.')
|
| 25 |
+
parser.add_argument('--scale_heatmap', type=float, default=0.4, help='Scaling factor for heatmap plots.')
|
| 26 |
+
parser.add_argument('--scale_training_index_accuracy', type=float, default=0.4, help='Scaling factor for training index accuracy plots.')
|
| 27 |
+
parser.add_argument('--seed', type=int, default=0, help='Random seed for reproducibility.')
|
| 28 |
+
parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use CPU.')
|
| 29 |
+
parser.add_argument('--model_type', type=str, choices=['ctm', 'lstm'], default='ctm', help='Type of model to analyze (ctm or lstm).')
|
| 30 |
+
return parser.parse_args()
|
| 31 |
+
|
| 32 |
+
def calculate_corrects(predictions, targets):
|
| 33 |
+
predicted_labels = predictions.argmax(2)
|
| 34 |
+
accuracy = (predicted_labels == targets.unsqueeze(-1))
|
| 35 |
+
return accuracy.detach().cpu().numpy()
|
| 36 |
+
|
| 37 |
+
def get_corrects_per_element_at_most_certain_time(predictions, certainty, targets):
|
| 38 |
+
where_most_certain = get_where_most_certain(certainty)
|
| 39 |
+
corrects = (predictions.argmax(2)[torch.arange(predictions.size(0), device=predictions.device),:,where_most_certain] == targets).float()
|
| 40 |
+
return corrects.detach().cpu().numpy()
|
| 41 |
+
|
| 42 |
+
def calculate_entropy_average_over_batch(normalized_entropy_per_elements):
|
| 43 |
+
normalized_entropy_per_elements_avg_batch = normalized_entropy_per_elements.mean(axis=1)
|
| 44 |
+
return normalized_entropy_per_elements_avg_batch
|
| 45 |
+
|
| 46 |
+
def calculate_thinking_time_average_over_batch(normalized_entropy_per_elements):
|
| 47 |
+
first_occurrence = calculate_thinking_time(normalized_entropy_per_elements)
|
| 48 |
+
average_thinking_time = np.mean(first_occurrence, axis=0)
|
| 49 |
+
return average_thinking_time
|
| 50 |
+
|
| 51 |
+
def calculate_thinking_time(normalized_entropy_per_elements, finish_type="min", entropy_threshold=0.1):
|
| 52 |
+
if finish_type == "min":
|
| 53 |
+
min_entropy_time = np.argmin(normalized_entropy_per_elements, axis=0)
|
| 54 |
+
return min_entropy_time
|
| 55 |
+
elif finish_type == "threshold":
|
| 56 |
+
T, B, S = normalized_entropy_per_elements.shape
|
| 57 |
+
below_threshold = normalized_entropy_per_elements < entropy_threshold
|
| 58 |
+
first_occurrence = np.argmax(below_threshold, axis=0)
|
| 59 |
+
no_true = ~np.any(below_threshold, axis=0)
|
| 60 |
+
first_occurrence[no_true] = T
|
| 61 |
+
return first_occurrence
|
| 62 |
+
|
| 63 |
+
def test_handcrafted_examples(model, args, run_model_spefic_save_dir, device):
|
| 64 |
+
test_cases = []
|
| 65 |
+
all_even_input = torch.full((args.parity_sequence_length,), 1.0, dtype=torch.float32, device=device)
|
| 66 |
+
all_even_target = torch.zeros_like(all_even_input, dtype=torch.long)
|
| 67 |
+
test_cases.append((all_even_input, all_even_target))
|
| 68 |
+
|
| 69 |
+
all_odd_input = torch.full((args.parity_sequence_length,), -1.0, dtype=torch.float32, device=device)
|
| 70 |
+
all_odd_target = torch.cumsum((all_odd_input == -1).to(torch.long), dim=0) % 2
|
| 71 |
+
test_cases.append((all_odd_input, all_odd_target))
|
| 72 |
+
|
| 73 |
+
random_input = torch.randint(0, 2, (args.parity_sequence_length,), dtype=torch.float32, device=device) * 2 - 1
|
| 74 |
+
random_target = torch.cumsum((random_input == -1).to(torch.long), dim=0) % 2
|
| 75 |
+
test_cases.append((random_input, random_target))
|
| 76 |
+
|
| 77 |
+
for i, (inputs, targets) in enumerate(test_cases):
|
| 78 |
+
inputs = inputs.unsqueeze(0)
|
| 79 |
+
targets = targets.unsqueeze(0)
|
| 80 |
+
filename = f"eval_handcrafted_{i}"
|
| 81 |
+
extend_inference_time = False
|
| 82 |
+
handcraft_dir = f"{run_model_spefic_save_dir}/handcrafted_examples/{i}"
|
| 83 |
+
os.makedirs(handcraft_dir, exist_ok=True)
|
| 84 |
+
|
| 85 |
+
model.eval()
|
| 86 |
+
with torch.inference_mode():
|
| 87 |
+
if extend_inference_time:
|
| 88 |
+
model.iterations = model.iterations * 2
|
| 89 |
+
predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
|
| 90 |
+
predictions = reshape_predictions(predictions, prediction_reshaper=[args.parity_sequence_length, 2])
|
| 91 |
+
input_images = reshape_inputs(inputs, args.iterations, grid_size=int(math.sqrt(args.parity_sequence_length)))
|
| 92 |
+
|
| 93 |
+
plot_neural_dynamics(post_activations, 100, handcraft_dir, axis_snap=False)
|
| 94 |
+
|
| 95 |
+
process = multiprocessing.Process(
|
| 96 |
+
target=make_parity_gif,
|
| 97 |
+
args=(
|
| 98 |
+
predictions.detach().cpu().numpy(),
|
| 99 |
+
certainties.detach().cpu().numpy(),
|
| 100 |
+
targets.detach().cpu().numpy(),
|
| 101 |
+
pre_activations,
|
| 102 |
+
post_activations,
|
| 103 |
+
reshape_attention_weights(attention),
|
| 104 |
+
input_images,
|
| 105 |
+
f"{handcraft_dir}/eval_output_val_{0}_iter_{0}.gif",
|
| 106 |
+
))
|
| 107 |
+
process.start()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
input_images = input_images.squeeze(1).squeeze(1)
|
| 111 |
+
attention = attention.squeeze(1)
|
| 112 |
+
|
| 113 |
+
for h in range(args.heads):
|
| 114 |
+
plot_attention_trajectory(attention[:, h, :, :], certainties, input_images, handcraft_dir, filename + f"_head_{h}", args)
|
| 115 |
+
|
| 116 |
+
plot_attention_trajectory(attention.mean(1), certainties, input_images, handcraft_dir, filename, args)
|
| 117 |
+
plot_input(input_images, handcraft_dir, filename)
|
| 118 |
+
plot_target(targets, handcraft_dir, filename, args)
|
| 119 |
+
plot_probabilities(predictions, certainties, handcraft_dir, filename, args)
|
| 120 |
+
plot_prediction(predictions, certainties,handcraft_dir, filename, args)
|
| 121 |
+
|
| 122 |
+
if extend_inference_time:
|
| 123 |
+
model.iterations = model.iterations // 2
|
| 124 |
+
model.train()
|
| 125 |
+
pass
|
| 126 |
+
|
| 127 |
+
def build_model_from_checkpoint_path(checkpoint_path, model_type, device="cpu"):
|
| 128 |
+
checkpoint = load_checkpoint(checkpoint_path, device)
|
| 129 |
+
model_args = get_model_args_from_checkpoint(checkpoint)
|
| 130 |
+
model = prepare_model([model_args.parity_sequence_length, 2], model_args, device)
|
| 131 |
+
model.load_state_dict(checkpoint["model_state_dict"], strict=False)
|
| 132 |
+
return model, model_args
|
| 133 |
+
|
| 134 |
+
def analyze_trained_model(run_model_spefic_save_dir, args, device):
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
|
| 137 |
+
latest_checkpoint_path = get_latest_checkpoint_file(args.log_dir)
|
| 138 |
+
model, model_args = build_model_from_checkpoint_path(latest_checkpoint_path, args.model_type, device=device)
|
| 139 |
+
model.eval()
|
| 140 |
+
model_args.log_dir = args.log_dir
|
| 141 |
+
test_data = ParityDataset(sequence_length=model_args.parity_sequence_length, length=10000)
|
| 142 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=0, drop_last=False)
|
| 143 |
+
|
| 144 |
+
corrects, corrects_at_most_certain_times, entropys, attentions = [], [], [], []
|
| 145 |
+
|
| 146 |
+
for inputs, targets in testloader:
|
| 147 |
+
inputs = inputs.to(device)
|
| 148 |
+
targets = targets.to(device)
|
| 149 |
+
predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
|
| 150 |
+
predictions = reshape_predictions(predictions, prediction_reshaper=[model_args.parity_sequence_length, 2])
|
| 151 |
+
corrects_batch = calculate_corrects(predictions, targets)
|
| 152 |
+
corrects_at_most_certain_time_batch = get_corrects_per_element_at_most_certain_time(predictions, certainties, targets)
|
| 153 |
+
corrects.append(corrects_batch)
|
| 154 |
+
corrects_at_most_certain_times.append(corrects_at_most_certain_time_batch)
|
| 155 |
+
attentions.append(attention)
|
| 156 |
+
|
| 157 |
+
test_handcrafted_examples(model, model_args, run_model_spefic_save_dir, device)
|
| 158 |
+
|
| 159 |
+
overall_mean_accuracy = np.mean(np.vstack(corrects_at_most_certain_times))
|
| 160 |
+
overall_std_accuracy = np.std(np.mean(np.vstack(corrects_at_most_certain_times), axis=1))
|
| 161 |
+
|
| 162 |
+
return overall_mean_accuracy, overall_std_accuracy, model_args.iterations
|
| 163 |
+
|
| 164 |
+
def analyze_training(run_model_spefic_save_dir, args, device):
|
| 165 |
+
checkpoint_files = get_checkpoint_files(args.log_dir)
|
| 166 |
+
all_accuracies = []
|
| 167 |
+
all_accuracies_at_most_certain_time = []
|
| 168 |
+
all_average_thinking_times = []
|
| 169 |
+
all_std_thinking_times = []
|
| 170 |
+
all_attentions = []
|
| 171 |
+
for checkpoint_path in checkpoint_files:
|
| 172 |
+
model, model_args = build_model_from_checkpoint_path(checkpoint_path, args.model_type, device=device)
|
| 173 |
+
model_args.log_dir = run_model_spefic_save_dir
|
| 174 |
+
test_data = ParityDataset(sequence_length=model_args.parity_sequence_length, length=1000)
|
| 175 |
+
testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=0, drop_last=False)
|
| 176 |
+
corrects = []
|
| 177 |
+
corrects_at_most_certain_times = []
|
| 178 |
+
thinking_times = []
|
| 179 |
+
attentions = []
|
| 180 |
+
|
| 181 |
+
for inputs, targets in testloader:
|
| 182 |
+
inputs = inputs.to(device)
|
| 183 |
+
targets = targets.to(device)
|
| 184 |
+
predictions, certainties, synchronisation, pre_activations, post_activations, attention = model(inputs, track=True)
|
| 185 |
+
predictions = reshape_predictions(predictions, prediction_reshaper=[model_args.parity_sequence_length, 2])
|
| 186 |
+
attention = reshape_attention_weights(attention)
|
| 187 |
+
|
| 188 |
+
corrects_batch = calculate_corrects(predictions, targets)
|
| 189 |
+
corrects_at_most_certain_time_batch = get_corrects_per_element_at_most_certain_time(predictions, certainties, targets)
|
| 190 |
+
entropy_per_element = compute_normalized_entropy(predictions.permute(0,3,1,2), reduction='none').detach().cpu().numpy()
|
| 191 |
+
thinking_times_batch = np.argmin(entropy_per_element, axis=1)
|
| 192 |
+
|
| 193 |
+
corrects.append(corrects_batch)
|
| 194 |
+
corrects_at_most_certain_times.append(corrects_at_most_certain_time_batch)
|
| 195 |
+
thinking_times.append(thinking_times_batch)
|
| 196 |
+
attentions.append(attention)
|
| 197 |
+
|
| 198 |
+
checkpoint_average_accuracies = np.mean(np.concatenate(corrects, axis=0), axis=0).transpose(1,0)
|
| 199 |
+
all_accuracies.append(checkpoint_average_accuracies)
|
| 200 |
+
|
| 201 |
+
stacked_corrects_at_most_certain_times = np.vstack(corrects_at_most_certain_times)
|
| 202 |
+
checkpoint_average_accuracy_at_most_certain_time = np.mean(stacked_corrects_at_most_certain_times, axis=0)
|
| 203 |
+
all_accuracies_at_most_certain_time.append(checkpoint_average_accuracy_at_most_certain_time)
|
| 204 |
+
|
| 205 |
+
checkpoint_thinking_times = np.concatenate(thinking_times, axis=0)
|
| 206 |
+
checkpoint_average_thinking_time = np.mean(checkpoint_thinking_times, axis=0)
|
| 207 |
+
checkpoint_std_thinking_time = np.std(checkpoint_thinking_times, axis=0)
|
| 208 |
+
all_average_thinking_times.append(checkpoint_average_thinking_time)
|
| 209 |
+
all_std_thinking_times.append(checkpoint_std_thinking_time)
|
| 210 |
+
|
| 211 |
+
checkpoint_average_attentions = np.mean(np.concatenate(attentions, axis=1), axis=1)
|
| 212 |
+
all_attentions.append(checkpoint_average_attentions)
|
| 213 |
+
|
| 214 |
+
plot_accuracy_training(all_accuracies_at_most_certain_time, args.scale_training_index_accuracy, run_model_spefic_save_dir, args=model_args)
|
| 215 |
+
create_attentions_heatmap_gif(all_attentions, args.scale_heatmap, run_model_spefic_save_dir, model_args)
|
| 216 |
+
create_accuracies_heatmap_gif(np.array(all_accuracies), all_average_thinking_times, all_std_thinking_times, args.scale_heatmap, run_model_spefic_save_dir, model_args)
|
| 217 |
+
create_stacked_gif(run_model_spefic_save_dir)
|
| 218 |
+
|
| 219 |
+
def get_accuracy_and_loss_from_checkpoint(checkpoint):
|
| 220 |
+
training_iteration = checkpoint.get('training_iteration', 0)
|
| 221 |
+
train_losses = checkpoint.get('train_losses', [])
|
| 222 |
+
test_losses = checkpoint.get('test_losses', [])
|
| 223 |
+
train_accuracies = checkpoint.get('train_accuracies_most_certain', [])
|
| 224 |
+
test_accuracies = checkpoint.get('test_accuracies_most_certain', [])
|
| 225 |
+
return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
|
| 228 |
+
|
| 229 |
+
args = parse_args()
|
| 230 |
+
|
| 231 |
+
device = f'cuda:{args.device[0]}' if args.device[0] != -1 else 'cpu'
|
| 232 |
+
|
| 233 |
+
set_seed(args.seed)
|
| 234 |
+
|
| 235 |
+
save_dir = "tasks/parity/analysis/outputs"
|
| 236 |
+
os.makedirs(save_dir, exist_ok=True)
|
| 237 |
+
|
| 238 |
+
accuracy_csv_file_path = os.path.join(save_dir, "accuracy.csv")
|
| 239 |
+
if os.path.exists(accuracy_csv_file_path):
|
| 240 |
+
os.remove(accuracy_csv_file_path)
|
| 241 |
+
|
| 242 |
+
all_runs_log_dirs = get_all_log_dirs(args.log_dir)
|
| 243 |
+
|
| 244 |
+
plot_training_curve_all_runs(all_runs_log_dirs, save_dir, args.scale_training_curve, device, x_max=200_000)
|
| 245 |
+
plot_lstm_last_and_certain_accuracy(all_folders=all_runs_log_dirs, save_path=f"{save_dir}/lstm_final_vs_certain_accuracy.png", scale=args.scale_training_curve)
|
| 246 |
+
|
| 247 |
+
progress_bar = tqdm(all_runs_log_dirs, desc="Analyzing Runs", dynamic_ncols=True)
|
| 248 |
+
for folder in progress_bar:
|
| 249 |
+
|
| 250 |
+
run, model_name = folder.strip("/").split("/")[-2:]
|
| 251 |
+
|
| 252 |
+
run_model_spefic_save_dir = f"{save_dir}/{model_name}/{run}"
|
| 253 |
+
os.makedirs(run_model_spefic_save_dir, exist_ok=True)
|
| 254 |
+
|
| 255 |
+
args.log_dir = folder
|
| 256 |
+
progress_bar.set_description(f"Analyzing Trained Model at {folder}")
|
| 257 |
+
|
| 258 |
+
accuracy_mean, accuracy_std, num_iterations = analyze_trained_model(run_model_spefic_save_dir, args, device)
|
| 259 |
+
|
| 260 |
+
with open(accuracy_csv_file_path, mode='a', newline='') as file:
|
| 261 |
+
writer = csv.writer(file)
|
| 262 |
+
if file.tell() == 0:
|
| 263 |
+
writer.writerow(["Run", "Overall Mean Accuracy", "Overall Std Accuracy", "Num Iterations"])
|
| 264 |
+
writer.writerow([folder, accuracy_mean, accuracy_std, num_iterations])
|
| 265 |
+
|
| 266 |
+
progress_bar.set_description(f"Analyzing Training at {folder}")
|
| 267 |
+
analyze_training(run_model_spefic_save_dir, args, device)
|
| 268 |
+
|
| 269 |
+
plot_accuracy_thinking_time(accuracy_csv_file_path, scale=args.scale_training_curve, output_dir=save_dir)
|
tasks/parity/plotting.py
ADDED
|
@@ -0,0 +1,896 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
import numpy as np
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from collections import defaultdict
|
| 6 |
+
from matplotlib.lines import Line2D
|
| 7 |
+
import matplotlib as mpl
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import matplotlib.patheffects as path_effects
|
| 10 |
+
from matplotlib.ticker import FuncFormatter
|
| 11 |
+
from scipy.special import softmax
|
| 12 |
+
import imageio.v2 as imageio
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import math
|
| 15 |
+
import re
|
| 16 |
+
sns.set_style('darkgrid')
|
| 17 |
+
mpl.use('Agg')
|
| 18 |
+
|
| 19 |
+
from tasks.parity.utils import get_where_most_certain, parse_folder_name
|
| 20 |
+
from models.utils import get_latest_checkpoint_file, load_checkpoint, get_model_args_from_checkpoint, get_accuracy_and_loss_from_checkpoint
|
| 21 |
+
from tasks.image_classification.plotting import save_frames_to_mp4
|
| 22 |
+
|
| 23 |
+
def make_parity_gif(predictions, certainties, targets, pre_activations, post_activations, attention_weights, inputs_to_model, filename):
|
| 24 |
+
|
| 25 |
+
# Config
|
| 26 |
+
batch_index = 0
|
| 27 |
+
n_neurons_to_visualise = 16
|
| 28 |
+
figscale = 0.28
|
| 29 |
+
n_steps = len(pre_activations)
|
| 30 |
+
frames = []
|
| 31 |
+
heatmap_cmap = sns.color_palette("viridis", as_cmap=True)
|
| 32 |
+
|
| 33 |
+
these_pre_acts = pre_activations[:, batch_index, :] # Shape: (T, H)
|
| 34 |
+
these_post_acts = post_activations[:, batch_index, :] # Shape: (T, H)
|
| 35 |
+
these_inputs = inputs_to_model[:, batch_index, :, :, :] # Shape: (T, C, H, W)
|
| 36 |
+
these_predictions = predictions[batch_index, :, :, :] # Shape: (d, C, T)
|
| 37 |
+
these_certainties = certainties[batch_index, :, :] # Shape: (C, T)
|
| 38 |
+
these_attention_weights = attention_weights[:, batch_index, :, :]
|
| 39 |
+
|
| 40 |
+
# Create mosaic layout
|
| 41 |
+
mosaic = [['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'target', 'target'] for _ in range(2)] + \
|
| 42 |
+
[['img_data', 'img_data', 'attention', 'attention', 'probs', 'probs', 'target', 'target'] for _ in range(2)] + \
|
| 43 |
+
[['certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty']] + \
|
| 44 |
+
[[f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}'] for ti in range(n_neurons_to_visualise)]
|
| 45 |
+
|
| 46 |
+
for stepi in range(n_steps):
|
| 47 |
+
fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))
|
| 48 |
+
|
| 49 |
+
# Plot predictions
|
| 50 |
+
d = these_predictions.shape[0]
|
| 51 |
+
grid_side = int(np.sqrt(d))
|
| 52 |
+
logits = these_predictions[:, :, stepi]
|
| 53 |
+
|
| 54 |
+
probs = softmax(logits, axis=1)
|
| 55 |
+
probs_grid = probs[:, 0].reshape(grid_side, grid_side)
|
| 56 |
+
axes_gif["probs"].imshow(probs_grid, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
|
| 57 |
+
axes_gif["probs"].axis('off')
|
| 58 |
+
axes_gif["probs"].set_title('Probabilties')
|
| 59 |
+
|
| 60 |
+
# Create and show attention heatmap
|
| 61 |
+
this_input_gate = these_attention_weights[stepi]
|
| 62 |
+
gate_min, gate_max = np.nanmin(this_input_gate), np.nanmax(this_input_gate)
|
| 63 |
+
if not np.isclose(gate_min, gate_max):
|
| 64 |
+
normalized_gate = (this_input_gate - gate_min) / (gate_max - gate_min + 1e-8)
|
| 65 |
+
else:
|
| 66 |
+
normalized_gate = np.zeros_like(this_input_gate)
|
| 67 |
+
attention_weights_heatmap = heatmap_cmap(normalized_gate)[:,:,:3]
|
| 68 |
+
|
| 69 |
+
# Show heatmaps
|
| 70 |
+
axes_gif['attention'].imshow(attention_weights_heatmap, vmin=0, vmax=1)
|
| 71 |
+
axes_gif['attention'].axis('off')
|
| 72 |
+
axes_gif['attention'].set_title('Attention')
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Plot target
|
| 76 |
+
target_grid = targets[batch_index].reshape(grid_side, grid_side)
|
| 77 |
+
axes_gif["target"].imshow(target_grid, cmap='viridis_r', interpolation='nearest', vmin=0, vmax=1)
|
| 78 |
+
axes_gif["target"].axis('off')
|
| 79 |
+
axes_gif["target"].set_title('Target')
|
| 80 |
+
|
| 81 |
+
# Add certainty plot
|
| 82 |
+
axes_gif['certainty'].plot(np.arange(n_steps), these_certainties[1], 'k-', linewidth=2)
|
| 83 |
+
axes_gif['certainty'].set_xlim([0, n_steps-1])
|
| 84 |
+
axes_gif['certainty'].axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
|
| 85 |
+
axes_gif['certainty'].set_xticklabels([])
|
| 86 |
+
axes_gif['certainty'].set_yticklabels([])
|
| 87 |
+
axes_gif['certainty'].grid(False)
|
| 88 |
+
|
| 89 |
+
# Plot neuron traces
|
| 90 |
+
for neuroni in range(n_neurons_to_visualise):
|
| 91 |
+
ax = axes_gif[f'trace_{neuroni}']
|
| 92 |
+
|
| 93 |
+
pre_activation = these_pre_acts[:, neuroni]
|
| 94 |
+
post_activation = these_post_acts[:, neuroni]
|
| 95 |
+
|
| 96 |
+
ax_pre = ax.twinx()
|
| 97 |
+
|
| 98 |
+
pre_min, pre_max = np.min(pre_activation), np.max(pre_activation)
|
| 99 |
+
post_min, post_max = np.min(post_activation), np.max(post_activation)
|
| 100 |
+
|
| 101 |
+
ax_pre.plot(np.arange(n_steps), pre_activation,
|
| 102 |
+
color='grey',
|
| 103 |
+
linestyle='--',
|
| 104 |
+
linewidth=1,
|
| 105 |
+
alpha=0.4,
|
| 106 |
+
label='Pre-activation')
|
| 107 |
+
|
| 108 |
+
color = 'blue' if neuroni % 2 else 'red'
|
| 109 |
+
ax.plot(np.arange(n_steps), post_activation,
|
| 110 |
+
color=color,
|
| 111 |
+
linestyle='-',
|
| 112 |
+
linewidth=2,
|
| 113 |
+
alpha=1.0,
|
| 114 |
+
label='Post-activation')
|
| 115 |
+
|
| 116 |
+
ax.set_xlim([0, n_steps-1])
|
| 117 |
+
ax_pre.set_xlim([0, n_steps-1])
|
| 118 |
+
|
| 119 |
+
if pre_min != pre_max:
|
| 120 |
+
ax_pre.set_ylim([pre_min, pre_max])
|
| 121 |
+
if post_min != post_max:
|
| 122 |
+
ax.set_ylim([post_min, post_max])
|
| 123 |
+
|
| 124 |
+
ax.axvline(x=stepi, color='black', linewidth=1, alpha=0.5)
|
| 125 |
+
|
| 126 |
+
ax.set_xticklabels([])
|
| 127 |
+
ax.set_yticklabels([])
|
| 128 |
+
ax.grid(False)
|
| 129 |
+
|
| 130 |
+
ax_pre.set_xticklabels([])
|
| 131 |
+
ax_pre.set_yticklabels([])
|
| 132 |
+
ax_pre.grid(False)
|
| 133 |
+
|
| 134 |
+
# Show input image
|
| 135 |
+
this_image = these_inputs[stepi].transpose(1, 2, 0)
|
| 136 |
+
axes_gif['img_data'].imshow(this_image, cmap='viridis', vmin=0, vmax=1)
|
| 137 |
+
axes_gif['img_data'].grid(False)
|
| 138 |
+
axes_gif['img_data'].set_xticks([])
|
| 139 |
+
axes_gif['img_data'].set_yticks([])
|
| 140 |
+
axes_gif['img_data'].set_title('Input')
|
| 141 |
+
|
| 142 |
+
# Save frames
|
| 143 |
+
fig_gif.tight_layout(pad=0.1)
|
| 144 |
+
if stepi == 0:
|
| 145 |
+
fig_gif.savefig(filename.split('.gif')[0]+'_frame0.png', dpi=100)
|
| 146 |
+
if stepi == 1:
|
| 147 |
+
fig_gif.savefig(filename.split('.gif')[0]+'_frame1.png', dpi=100)
|
| 148 |
+
if stepi == n_steps-1:
|
| 149 |
+
fig_gif.savefig(filename.split('.gif')[0]+'_frame-1.png', dpi=100)
|
| 150 |
+
|
| 151 |
+
# Convert to frame
|
| 152 |
+
canvas = fig_gif.canvas
|
| 153 |
+
canvas.draw()
|
| 154 |
+
image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8')
|
| 155 |
+
image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3]
|
| 156 |
+
frames.append(image_numpy)
|
| 157 |
+
plt.close(fig_gif)
|
| 158 |
+
|
| 159 |
+
imageio.mimsave(filename, frames, fps=15, loop=100)
|
| 160 |
+
|
| 161 |
+
pass
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def plot_attention_trajectory(attention, certainties, input_images, save_dir, filename, args):
|
| 165 |
+
where_most_certain = get_where_most_certain(certainties)
|
| 166 |
+
grid_size = int(math.sqrt(args.parity_sequence_length))
|
| 167 |
+
trajectory = [np.unravel_index(np.argmax(attention[t]), (grid_size, grid_size)) for t in range(args.iterations)]
|
| 168 |
+
x_coords, y_coords = zip(*trajectory)
|
| 169 |
+
|
| 170 |
+
plt.figure(figsize=(5, 5))
|
| 171 |
+
plt.imshow(input_images[0], cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
|
| 172 |
+
|
| 173 |
+
ax = plt.gca()
|
| 174 |
+
nrows, ncols = input_images[0].shape
|
| 175 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 176 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 177 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 178 |
+
ax.tick_params(which="minor", size=0)
|
| 179 |
+
ax.set_axisbelow(False)
|
| 180 |
+
plt.xticks([])
|
| 181 |
+
plt.yticks([])
|
| 182 |
+
|
| 183 |
+
cmap = plt.get_cmap("plasma")
|
| 184 |
+
norm_time = np.linspace(0, 1, len(trajectory))
|
| 185 |
+
|
| 186 |
+
for i in range(len(trajectory) - 1):
|
| 187 |
+
x1, y1 = x_coords[i], y_coords[i]
|
| 188 |
+
x2, y2 = x_coords[i + 1], y_coords[i + 1]
|
| 189 |
+
color = cmap(norm_time[i])
|
| 190 |
+
line, = plt.plot([y1, y2], [x1, x2], color=color, linewidth=6, alpha=0.5, zorder=4)
|
| 191 |
+
line.set_path_effects([
|
| 192 |
+
path_effects.Stroke(linewidth=8, foreground='white'),
|
| 193 |
+
path_effects.Normal()
|
| 194 |
+
])
|
| 195 |
+
|
| 196 |
+
for i, (x, y) in enumerate(trajectory):
|
| 197 |
+
plt.scatter(y, x, color=cmap(norm_time[i]), s=100, edgecolor='white', linewidth=1.5, zorder=5)
|
| 198 |
+
|
| 199 |
+
most_certain_point = trajectory[where_most_certain]
|
| 200 |
+
|
| 201 |
+
plt.plot(most_certain_point[1], most_certain_point[0],
|
| 202 |
+
marker='x', markersize=18, markeredgewidth=5,
|
| 203 |
+
color='white', linestyle='', zorder=6)
|
| 204 |
+
plt.plot(most_certain_point[1], most_certain_point[0],
|
| 205 |
+
marker='x', markersize=15, markeredgewidth=3,
|
| 206 |
+
color=cmap(norm_time[where_most_certain]), linestyle='', zorder=7)
|
| 207 |
+
|
| 208 |
+
plt.tight_layout()
|
| 209 |
+
plt.savefig(f"{save_dir}/{filename}_traj.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 210 |
+
plt.savefig(f"{save_dir}/{filename}_traj.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 211 |
+
plt.show()
|
| 212 |
+
plt.close()
|
| 213 |
+
|
| 214 |
+
def plot_input(input_images, save_dir, filename):
|
| 215 |
+
|
| 216 |
+
plt.figure(figsize=(5, 5))
|
| 217 |
+
plt.imshow(input_images[0], cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
|
| 218 |
+
|
| 219 |
+
ax = plt.gca()
|
| 220 |
+
nrows, ncols = input_images[0].shape
|
| 221 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 222 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 223 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 224 |
+
ax.tick_params(which="minor", size=0)
|
| 225 |
+
ax.set_axisbelow(False)
|
| 226 |
+
plt.xticks([])
|
| 227 |
+
plt.yticks([])
|
| 228 |
+
|
| 229 |
+
plt.tight_layout()
|
| 230 |
+
plt.savefig(f"{save_dir}/{filename}_input.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 231 |
+
plt.savefig(f"{save_dir}/{filename}_input.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 232 |
+
plt.show()
|
| 233 |
+
plt.close()
|
| 234 |
+
|
| 235 |
+
def plot_target(targets, save_dir, filename, args):
|
| 236 |
+
grid_size = int(math.sqrt(args.parity_sequence_length))
|
| 237 |
+
targets_grid = targets[0].reshape(grid_size, grid_size).detach().cpu().numpy()
|
| 238 |
+
plt.figure(figsize=(5, 5))
|
| 239 |
+
plt.imshow(targets_grid, cmap="gray_r", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
|
| 240 |
+
ax = plt.gca()
|
| 241 |
+
nrows, ncols = targets_grid.shape
|
| 242 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 243 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 244 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 245 |
+
ax.tick_params(which="minor", size=0)
|
| 246 |
+
ax.set_axisbelow(False)
|
| 247 |
+
plt.xticks([])
|
| 248 |
+
plt.yticks([])
|
| 249 |
+
plt.tight_layout()
|
| 250 |
+
plt.savefig(f"{save_dir}/{filename}_target.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 251 |
+
plt.savefig(f"{save_dir}/{filename}_target.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 252 |
+
plt.show()
|
| 253 |
+
plt.close()
|
| 254 |
+
|
| 255 |
+
def plot_probabilities(predictions, certainties, save_dir, filename, args):
|
| 256 |
+
grid_size = int(math.sqrt(args.parity_sequence_length))
|
| 257 |
+
where_most_certain = get_where_most_certain(certainties)
|
| 258 |
+
predictions_most_certain = predictions[0, :, :, where_most_certain].detach().cpu().numpy()
|
| 259 |
+
probs = softmax(predictions_most_certain, axis=1)
|
| 260 |
+
probs_grid = probs[:, 0].reshape(grid_size, grid_size)
|
| 261 |
+
plt.figure(figsize=(5, 5))
|
| 262 |
+
plt.imshow(probs_grid, cmap="gray", origin="upper", vmin=0.2, vmax=0.8, interpolation='nearest')
|
| 263 |
+
ax = plt.gca()
|
| 264 |
+
nrows, ncols = probs_grid.shape
|
| 265 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 266 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 267 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 268 |
+
ax.tick_params(which="minor", size=0)
|
| 269 |
+
ax.set_axisbelow(False)
|
| 270 |
+
plt.xticks([])
|
| 271 |
+
plt.yticks([])
|
| 272 |
+
plt.tight_layout()
|
| 273 |
+
plt.savefig(f"{save_dir}/{filename}_probs.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 274 |
+
plt.savefig(f"{save_dir}/{filename}_probs.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 275 |
+
plt.show()
|
| 276 |
+
plt.close()
|
| 277 |
+
|
| 278 |
+
def plot_prediction(predictions, certainties, save_dir, filename, args):
|
| 279 |
+
grid_size = int(math.sqrt(args.parity_sequence_length))
|
| 280 |
+
where_most_certain = get_where_most_certain(certainties)
|
| 281 |
+
predictions_most_certain = predictions[0, :, :, where_most_certain].detach().cpu().numpy()
|
| 282 |
+
class_grid = np.argmax(predictions_most_certain, axis=1).reshape(grid_size, grid_size)
|
| 283 |
+
|
| 284 |
+
plt.figure(figsize=(5, 5))
|
| 285 |
+
plt.imshow(class_grid, cmap="gray_r", origin="upper", vmin=0, vmax=1, interpolation='nearest')
|
| 286 |
+
|
| 287 |
+
ax = plt.gca()
|
| 288 |
+
nrows, ncols = class_grid.shape
|
| 289 |
+
ax.set_xticks(np.arange(-0.5, ncols, 1), minor=True)
|
| 290 |
+
ax.set_yticks(np.arange(-0.5, nrows, 1), minor=True)
|
| 291 |
+
ax.grid(which='minor', color='black', linestyle='-', linewidth=1)
|
| 292 |
+
ax.tick_params(which="minor", size=0)
|
| 293 |
+
ax.set_axisbelow(False)
|
| 294 |
+
plt.xticks([])
|
| 295 |
+
plt.yticks([])
|
| 296 |
+
|
| 297 |
+
plt.tight_layout()
|
| 298 |
+
plt.savefig(f"{save_dir}/{filename}_prediction.png", dpi=300, bbox_inches='tight', pad_inches=0)
|
| 299 |
+
plt.savefig(f"{save_dir}/{filename}_prediction.pdf", format='pdf', bbox_inches='tight', pad_inches=0)
|
| 300 |
+
plt.show()
|
| 301 |
+
plt.close()
|
| 302 |
+
|
| 303 |
+
def plot_accuracy_heatmap(overall_accuracies_avg, average_thinking_time, std_thinking_time, scale, save_path, args):
|
| 304 |
+
fig, ax = plt.subplots(figsize=(scale*10, scale*5))
|
| 305 |
+
im = ax.imshow(overall_accuracies_avg.T * 100, aspect='auto', cmap="viridis", origin='lower', extent=[0, args.iterations-1, 0, args.parity_sequence_length-1], vmin=50, vmax=100)
|
| 306 |
+
cbar = fig.colorbar(im, ax=ax, format="%.1f")
|
| 307 |
+
cbar.set_label("Accuracy (%)")
|
| 308 |
+
ax.errorbar(average_thinking_time, np.arange(args.parity_sequence_length), xerr=std_thinking_time, fmt='ko', markersize=2, capsize=2, elinewidth=1, label="Min. Entropy")
|
| 309 |
+
ax.set_xlabel("Time Step")
|
| 310 |
+
ax.set_ylabel("Sequence Index")
|
| 311 |
+
ax.set_xlim(0, args.iterations-1)
|
| 312 |
+
ax.set_ylim(0, args.parity_sequence_length-1)
|
| 313 |
+
ax.grid(False)
|
| 314 |
+
ax.legend(loc="upper left")
|
| 315 |
+
fig.tight_layout(pad=0.1)
|
| 316 |
+
fig.savefig(save_path, dpi=300, bbox_inches="tight")
|
| 317 |
+
fig.savefig(save_path.replace(".png", ".pdf"), format='pdf', bbox_inches="tight")
|
| 318 |
+
plt.close(fig)
|
| 319 |
+
|
| 320 |
+
def plot_attention_heatmap(overall_attentions_avg, scale, save_path, vmin=None, vmax=None):
|
| 321 |
+
overall_attentions_avg = overall_attentions_avg.reshape(overall_attentions_avg.shape[0], -1)
|
| 322 |
+
fig, ax = plt.subplots(figsize=(scale*10, scale*5))
|
| 323 |
+
im = ax.imshow(overall_attentions_avg.T, aspect='auto', cmap="viridis", origin='lower', extent=[0, overall_attentions_avg.shape[0]-1, 0, overall_attentions_avg.shape[1]-1], vmin=vmin, vmax=vmax)
|
| 324 |
+
cbar = fig.colorbar(im, ax=ax, format=FuncFormatter(lambda x, _: f"{x:05.2f}"))
|
| 325 |
+
cbar.set_label("Attention Weight")
|
| 326 |
+
ax.set_xlabel("Time Step")
|
| 327 |
+
ax.set_ylabel("Sequence Index")
|
| 328 |
+
ax.set_xlim(0, overall_attentions_avg.shape[0]-1)
|
| 329 |
+
ax.set_ylim(0, overall_attentions_avg.shape[1]-1)
|
| 330 |
+
ax.grid(False)
|
| 331 |
+
fig.tight_layout(pad=0.1)
|
| 332 |
+
fig.savefig(save_path, dpi=300, bbox_inches="tight")
|
| 333 |
+
fig.savefig(save_path.replace(".png", ".pdf"), format='pdf', bbox_inches="tight")
|
| 334 |
+
plt.close(fig)
|
| 335 |
+
|
| 336 |
+
def create_accuracies_heatmap_gif(all_accuracies, all_average_thinking_times, all_std_thinking_times, scale, save_dir, args):
|
| 337 |
+
heatmap_components_dir = os.path.join(save_dir, "accuracy_heatmaps")
|
| 338 |
+
os.makedirs(heatmap_components_dir, exist_ok=True)
|
| 339 |
+
|
| 340 |
+
image_paths = []
|
| 341 |
+
|
| 342 |
+
for i, (accuracies, avg_thinking_time, std_thinking_time) in enumerate(zip(all_accuracies, all_average_thinking_times, all_std_thinking_times)):
|
| 343 |
+
save_path = os.path.join(heatmap_components_dir, f"frame_{i:04d}.png")
|
| 344 |
+
plot_accuracy_heatmap(accuracies, avg_thinking_time, std_thinking_time, scale, save_path, args)
|
| 345 |
+
image_paths.append(save_path)
|
| 346 |
+
|
| 347 |
+
gif_path = os.path.join(save_dir, "accuracy_heatmap.gif")
|
| 348 |
+
with imageio.get_writer(gif_path, mode='I', duration=0.3) as writer:
|
| 349 |
+
for image_path in image_paths:
|
| 350 |
+
image = imageio.imread(image_path)
|
| 351 |
+
writer.append_data(image)
|
| 352 |
+
|
| 353 |
+
def create_attentions_heatmap_gif(all_attentions, scale, save_path, args):
|
| 354 |
+
heatmap_components_dir = os.path.join(args.log_dir, "attention_heatmaps")
|
| 355 |
+
os.makedirs(heatmap_components_dir, exist_ok=True)
|
| 356 |
+
|
| 357 |
+
global_min = min(attentions.min() for attentions in all_attentions)
|
| 358 |
+
global_max = max(attentions.max() for attentions in all_attentions)
|
| 359 |
+
|
| 360 |
+
image_paths = []
|
| 361 |
+
|
| 362 |
+
for i, attentions in enumerate(all_attentions):
|
| 363 |
+
save_path_component = os.path.join(heatmap_components_dir, f"frame_{i:04d}.png")
|
| 364 |
+
plot_attention_heatmap(attentions, scale, save_path_component, vmin=global_min, vmax=global_max)
|
| 365 |
+
image_paths.append(save_path_component)
|
| 366 |
+
|
| 367 |
+
gif_path = os.path.join(save_path, "attention_heatmap.gif")
|
| 368 |
+
with imageio.get_writer(gif_path, mode='I', duration=0.3) as writer:
|
| 369 |
+
for image_path in image_paths:
|
| 370 |
+
image = imageio.imread(image_path)
|
| 371 |
+
writer.append_data(image)
|
| 372 |
+
|
| 373 |
+
def create_stacked_gif(save_path, y_shift=200):
|
| 374 |
+
accuracy_gif_path = os.path.join(save_path, "accuracy_heatmap.gif")
|
| 375 |
+
attention_gif_path = os.path.join(save_path, "attention_heatmap.gif")
|
| 376 |
+
stacked_gif_path = os.path.join(save_path, "stacked_heatmap.gif")
|
| 377 |
+
|
| 378 |
+
accuracy_reader = imageio.get_reader(accuracy_gif_path)
|
| 379 |
+
attention_reader = imageio.get_reader(attention_gif_path)
|
| 380 |
+
|
| 381 |
+
accuracy_frames = [Image.fromarray(frame) for frame in accuracy_reader]
|
| 382 |
+
attention_frames = [Image.fromarray(frame) for frame in attention_reader]
|
| 383 |
+
|
| 384 |
+
assert len(accuracy_frames) == len(attention_frames), "Mismatch in frame counts between accuracy and attention GIFs"
|
| 385 |
+
|
| 386 |
+
stacked_frames = []
|
| 387 |
+
for acc_frame, att_frame in zip(accuracy_frames, attention_frames):
|
| 388 |
+
acc_width, acc_height = acc_frame.size
|
| 389 |
+
att_width, att_height = att_frame.size
|
| 390 |
+
|
| 391 |
+
# Create base canvas
|
| 392 |
+
stacked_height = acc_height + att_height - y_shift
|
| 393 |
+
stacked_width = max(acc_width, att_width)
|
| 394 |
+
|
| 395 |
+
stacked_frame = Image.new("RGB", (stacked_width, stacked_height), color=(255, 255, 255))
|
| 396 |
+
|
| 397 |
+
# Paste attention frame first, shifted up
|
| 398 |
+
stacked_frame.paste(att_frame, (0, 0)) # Paste at top
|
| 399 |
+
stacked_frame.paste(acc_frame, (0, att_height - y_shift)) # Shift accuracy up by overlap
|
| 400 |
+
|
| 401 |
+
stacked_frames.append(stacked_frame)
|
| 402 |
+
|
| 403 |
+
stacked_frames[0].save(
|
| 404 |
+
stacked_gif_path,
|
| 405 |
+
save_all=True,
|
| 406 |
+
append_images=stacked_frames[1:],
|
| 407 |
+
duration=300,
|
| 408 |
+
loop=0
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
save_frames_to_mp4(
|
| 412 |
+
[np.array(fm)[:, :, ::-1] for fm in stacked_frames],
|
| 413 |
+
f"{stacked_gif_path.replace('gif', 'mp4')}",
|
| 414 |
+
fps=15,
|
| 415 |
+
gop_size=1,
|
| 416 |
+
preset="slow"
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def plot_accuracy_training(all_accuracies, scale, run_model_spefic_save_dir, args):
|
| 421 |
+
scale=0.5
|
| 422 |
+
seq_indices = range(args.parity_sequence_length)
|
| 423 |
+
fig, ax = plt.subplots(figsize=(scale*10, scale*5))
|
| 424 |
+
cmap = plt.get_cmap("viridis")
|
| 425 |
+
|
| 426 |
+
for i, acc in enumerate(all_accuracies):
|
| 427 |
+
color = cmap(i / (len(all_accuracies) - 1))
|
| 428 |
+
ax.plot(seq_indices, acc*100, color=color, alpha=0.7, linewidth=1)
|
| 429 |
+
|
| 430 |
+
num_checkpoints = 5
|
| 431 |
+
checkpoint_percentages = np.linspace(0, 100, num_checkpoints)
|
| 432 |
+
|
| 433 |
+
sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=100))
|
| 434 |
+
sm.set_array([])
|
| 435 |
+
cbar = fig.colorbar(sm, ax=ax)
|
| 436 |
+
cbar.set_label("Training Progress (%)")
|
| 437 |
+
cbar.set_ticks(checkpoint_percentages)
|
| 438 |
+
cbar.set_ticklabels([f"{int(p)}%" for p in checkpoint_percentages])
|
| 439 |
+
|
| 440 |
+
ax.set_xlabel("Sequence Index")
|
| 441 |
+
ax.set_ylabel("Accuracy (%)")
|
| 442 |
+
ax.set_xticks([0, 16 ,32, 48, 63])
|
| 443 |
+
ax.grid(True, alpha=0.5)
|
| 444 |
+
ax.set_xlim(0, args.parity_sequence_length - 1)
|
| 445 |
+
|
| 446 |
+
fig.tight_layout(pad=0.1)
|
| 447 |
+
fig.savefig(f"{run_model_spefic_save_dir}/accuracy_vs_seq_element.png", dpi=300, bbox_inches="tight")
|
| 448 |
+
fig.savefig(f"{run_model_spefic_save_dir}/accuracy_vs_seq_element.pdf", format='pdf', bbox_inches="tight")
|
| 449 |
+
plt.close(fig)
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def plot_loss_all_runs(training_data, evaluate_every, save_path="train_loss_comparison_parity.png", step=1, scale=1.0, x_max=None):
|
| 453 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 454 |
+
|
| 455 |
+
grouped = defaultdict(list)
|
| 456 |
+
label_map = {}
|
| 457 |
+
linestyle_map = {}
|
| 458 |
+
iters_map = {}
|
| 459 |
+
model_map = {}
|
| 460 |
+
|
| 461 |
+
for folder, data in training_data.items():
|
| 462 |
+
label, model_type, iters = parse_folder_name(folder)
|
| 463 |
+
if iters is None:
|
| 464 |
+
continue
|
| 465 |
+
|
| 466 |
+
key = f"{model_type}_{iters}"
|
| 467 |
+
grouped[key].append(data["train_losses"])
|
| 468 |
+
label_map[key] = f"{model_type}, {iters} Iters."
|
| 469 |
+
linestyle_map[key] = "--" if model_type == "LSTM" else "-"
|
| 470 |
+
iters_map[key] = iters
|
| 471 |
+
model_map[key] = model_type
|
| 472 |
+
|
| 473 |
+
unique_iters = sorted(set(iters_map.values()))
|
| 474 |
+
base_colors = sns.color_palette("hls", n_colors=len(unique_iters))
|
| 475 |
+
color_lookup = {iters: base_colors[i] for i, iters in enumerate(unique_iters)}
|
| 476 |
+
|
| 477 |
+
legend_entries = []
|
| 478 |
+
global_max_x = 0
|
| 479 |
+
for key in sorted(grouped.keys(), key=lambda k: (iters_map[k], model_map[k])):
|
| 480 |
+
runs = grouped[key]
|
| 481 |
+
if not runs:
|
| 482 |
+
continue
|
| 483 |
+
|
| 484 |
+
iters = iters_map[key]
|
| 485 |
+
color = color_lookup[iters]
|
| 486 |
+
linestyle = linestyle_map[key]
|
| 487 |
+
|
| 488 |
+
min_len = min(len(r) for r in runs)
|
| 489 |
+
trimmed = np.array([r[:min_len] for r in runs])[:, ::step]
|
| 490 |
+
|
| 491 |
+
mean = np.mean(trimmed, axis=0)
|
| 492 |
+
std = np.std(trimmed, axis=0)
|
| 493 |
+
x = np.arange(len(mean)) * step * evaluate_every
|
| 494 |
+
group_max_x = len(mean) * step * evaluate_every
|
| 495 |
+
global_max_x = max(global_max_x, group_max_x)
|
| 496 |
+
|
| 497 |
+
line, = ax.plot(x, mean, color=color, linestyle=linestyle, label=label_map[key])
|
| 498 |
+
ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
|
| 499 |
+
|
| 500 |
+
legend_entries.append((line, label_map[key]))
|
| 501 |
+
|
| 502 |
+
ax.set_xlabel("Training Iterations")
|
| 503 |
+
ax.set_ylabel("Loss")
|
| 504 |
+
ax.grid(True, alpha=0.5)
|
| 505 |
+
|
| 506 |
+
style_legend = [
|
| 507 |
+
Line2D([0], [0], color='black', linestyle='-', label='CTM'),
|
| 508 |
+
Line2D([0], [0], color='black', linestyle='--', label='LSTM')
|
| 509 |
+
]
|
| 510 |
+
color_legend = [
|
| 511 |
+
Line2D([0], [0], color=color_lookup[it], linestyle='-', label=f"{it} Iters.")
|
| 512 |
+
for it in unique_iters
|
| 513 |
+
]
|
| 514 |
+
|
| 515 |
+
if not x_max:
|
| 516 |
+
x_max = global_max_x
|
| 517 |
+
|
| 518 |
+
ax.set_xlim([0, x_max])
|
| 519 |
+
ax.set_ylim(bottom=0)
|
| 520 |
+
ax.set_xticks(np.arange(0, x_max + 1, 50000))
|
| 521 |
+
ax.legend(handles=color_legend + style_legend, loc="upper left")
|
| 522 |
+
fig.tight_layout(pad=0.1)
|
| 523 |
+
fig.savefig(save_path, dpi=300)
|
| 524 |
+
fig.savefig(save_path.replace("png", "pdf"), format='pdf')
|
| 525 |
+
plt.close(fig)
|
| 526 |
+
|
| 527 |
+
def plot_accuracy_all_runs(training_data, evaluate_every, save_path="test_accuracy_comparison_parity.png", step=1, scale=1.0, smooth=False, x_max=None):
|
| 528 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 529 |
+
|
| 530 |
+
grouped = defaultdict(list)
|
| 531 |
+
label_map = {}
|
| 532 |
+
linestyle_map = {}
|
| 533 |
+
iters_map = {}
|
| 534 |
+
model_map = {}
|
| 535 |
+
|
| 536 |
+
for folder, data in training_data.items():
|
| 537 |
+
label, model_type, iters = parse_folder_name(folder)
|
| 538 |
+
if iters is None:
|
| 539 |
+
continue
|
| 540 |
+
|
| 541 |
+
key = f"{model_type}_{iters}"
|
| 542 |
+
grouped[key].append(data["test_accuracies"])
|
| 543 |
+
label_map[key] = f"{model_type}, {iters} Iters."
|
| 544 |
+
linestyle_map[key] = "--" if model_type == "LSTM" else "-"
|
| 545 |
+
iters_map[key] = iters
|
| 546 |
+
model_map[key] = model_type
|
| 547 |
+
|
| 548 |
+
unique_iters = sorted(set(iters_map.values()))
|
| 549 |
+
base_colors = sns.color_palette("hls", n_colors=len(unique_iters))
|
| 550 |
+
color_lookup = {iters: base_colors[i] for i, iters in enumerate(unique_iters)}
|
| 551 |
+
|
| 552 |
+
legend_entries = []
|
| 553 |
+
global_max_x = 0
|
| 554 |
+
|
| 555 |
+
for key in sorted(grouped.keys(), key=lambda k: (iters_map[k], model_map[k])):
|
| 556 |
+
runs = grouped[key]
|
| 557 |
+
if not runs:
|
| 558 |
+
continue
|
| 559 |
+
|
| 560 |
+
iters = iters_map[key]
|
| 561 |
+
model = model_map[key]
|
| 562 |
+
color = color_lookup[iters]
|
| 563 |
+
linestyle = linestyle_map[key]
|
| 564 |
+
|
| 565 |
+
min_len = min(len(r) for r in runs)
|
| 566 |
+
trimmed = np.array([r[:min_len] for r in runs])[:, ::step]
|
| 567 |
+
|
| 568 |
+
mean = np.mean(trimmed, axis=0) * 100
|
| 569 |
+
std = np.std(trimmed, axis=0) * 100
|
| 570 |
+
|
| 571 |
+
if smooth:
|
| 572 |
+
window_size = max(1, int(0.05 * len(mean)))
|
| 573 |
+
if window_size % 2 == 0:
|
| 574 |
+
window_size += 1
|
| 575 |
+
kernel = np.ones(window_size) / window_size
|
| 576 |
+
|
| 577 |
+
smoothed_mean = np.convolve(mean, kernel, mode='same')
|
| 578 |
+
smoothed_std = np.convolve(std, kernel, mode='same')
|
| 579 |
+
|
| 580 |
+
valid_start = window_size // 2
|
| 581 |
+
valid_end = len(mean) - window_size // 2
|
| 582 |
+
valid_length = valid_end - valid_start
|
| 583 |
+
|
| 584 |
+
mean = smoothed_mean[valid_start:valid_end]
|
| 585 |
+
std = smoothed_std[valid_start:valid_end]
|
| 586 |
+
x = np.arange(valid_length) * step * evaluate_every
|
| 587 |
+
group_max_x = valid_length * step * evaluate_every
|
| 588 |
+
else:
|
| 589 |
+
x = np.arange(len(mean)) * step * evaluate_every
|
| 590 |
+
group_max_x = len(mean) * step * evaluate_every
|
| 591 |
+
|
| 592 |
+
global_max_x = max(global_max_x, group_max_x)
|
| 593 |
+
|
| 594 |
+
line, = ax.plot(x, mean, color=color, linestyle=linestyle, label=label_map[key])
|
| 595 |
+
ax.fill_between(x, mean - std, mean + std, alpha=0.1, color=color)
|
| 596 |
+
legend_entries.append((line, label_map[key]))
|
| 597 |
+
|
| 598 |
+
if smooth or x_max is None:
|
| 599 |
+
x_max = global_max_x
|
| 600 |
+
|
| 601 |
+
ax.set_xlim([0, x_max])
|
| 602 |
+
ax.set_ylim(top=100)
|
| 603 |
+
ax.set_xticks(np.arange(0, x_max + 1, 50000))
|
| 604 |
+
ax.set_xlabel("Training Iterations")
|
| 605 |
+
ax.set_ylabel("Accuracy (%)")
|
| 606 |
+
ax.grid(True, alpha=0.5)
|
| 607 |
+
|
| 608 |
+
style_legend = [
|
| 609 |
+
Line2D([0], [0], color='black', linestyle='-', label='CTM'),
|
| 610 |
+
Line2D([0], [0], color='black', linestyle='--', label='LSTM')
|
| 611 |
+
]
|
| 612 |
+
color_legend = [
|
| 613 |
+
Line2D([0], [0], color=color_lookup[it], linestyle='-', label=f"{it} Iters.")
|
| 614 |
+
for it in unique_iters
|
| 615 |
+
]
|
| 616 |
+
ax.legend(handles=color_legend + style_legend, loc="upper left")
|
| 617 |
+
|
| 618 |
+
fig.tight_layout(pad=0.1)
|
| 619 |
+
fig.savefig(save_path, dpi=300)
|
| 620 |
+
fig.savefig(save_path.replace("png", "pdf"), format='pdf')
|
| 621 |
+
plt.close(fig)
|
| 622 |
+
|
| 623 |
+
def extract_run_name(folder, run_index=None):
|
| 624 |
+
# Try to extract from parent folder
|
| 625 |
+
parent = os.path.basename(os.path.dirname(folder))
|
| 626 |
+
match = re.search(r'run(\d+)', parent, re.IGNORECASE)
|
| 627 |
+
if match:
|
| 628 |
+
return f"Run {int(match.group(1))}"
|
| 629 |
+
# Try current folder name
|
| 630 |
+
basename = os.path.basename(folder)
|
| 631 |
+
match = re.search(r'run(\d+)', basename, re.IGNORECASE)
|
| 632 |
+
if match:
|
| 633 |
+
return f"Run {int(match.group(1))}"
|
| 634 |
+
# Fallback: use run index
|
| 635 |
+
if run_index is not None:
|
| 636 |
+
return f"Run {run_index + 1}"
|
| 637 |
+
raise ValueError(f"Could not extract run number from: {folder}")
|
| 638 |
+
|
| 639 |
+
def plot_loss_individual_runs(training_data, evaluate_every, save_dir, scale=1.0, x_max=None):
|
| 640 |
+
|
| 641 |
+
grouped = defaultdict(list)
|
| 642 |
+
label_map = {}
|
| 643 |
+
iters_map = {}
|
| 644 |
+
model_map = {}
|
| 645 |
+
|
| 646 |
+
base_colors = sns.color_palette("hls", n_colors=3)
|
| 647 |
+
color_lookup = {f"Run {i+1}": base_colors[i] for i in range(3)}
|
| 648 |
+
|
| 649 |
+
for i, (folder, data) in enumerate(training_data.items()):
|
| 650 |
+
checkpoint = load_checkpoint(get_latest_checkpoint_file(folder), device="cpu")
|
| 651 |
+
model_args = get_model_args_from_checkpoint(checkpoint)
|
| 652 |
+
label, model_type, iters = parse_folder_name(folder)
|
| 653 |
+
if iters is None:
|
| 654 |
+
continue
|
| 655 |
+
|
| 656 |
+
if model_type.lower() == "ctm":
|
| 657 |
+
memory_length = getattr(model_args, "memory_length", None)
|
| 658 |
+
if memory_length is None:
|
| 659 |
+
raise ValueError(f"CTM model missing memory_length in checkpoint args from: {folder}")
|
| 660 |
+
key = f"{model_type}_{iters}_{memory_length}".lower()
|
| 661 |
+
else:
|
| 662 |
+
key = f"{model_type}_{iters}".lower()
|
| 663 |
+
|
| 664 |
+
run_name = extract_run_name(folder, run_index=i)
|
| 665 |
+
grouped[key].append((run_name, data["train_losses"]))
|
| 666 |
+
label_map[key] = f"{model_type}, {iters} Iters."
|
| 667 |
+
iters_map[key] = iters
|
| 668 |
+
model_map[key] = model_type
|
| 669 |
+
|
| 670 |
+
for key, runs in grouped.items():
|
| 671 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 672 |
+
for run_name, losses in runs:
|
| 673 |
+
x = np.arange(len(losses)) * evaluate_every
|
| 674 |
+
color = color_lookup.get(run_name, 'gray')
|
| 675 |
+
ax.plot(x, losses, label=run_name, color=color, alpha=0.7)
|
| 676 |
+
|
| 677 |
+
ax.set_xlabel("Training Iterations")
|
| 678 |
+
ax.set_ylabel("Loss")
|
| 679 |
+
ax.set_ylim(bottom=-0.01)
|
| 680 |
+
ax.grid(True, alpha=0.5)
|
| 681 |
+
if x_max:
|
| 682 |
+
ax.set_xlim([0, x_max])
|
| 683 |
+
ax.set_xticks(np.arange(0, x_max + 1, 50000))
|
| 684 |
+
ax.legend()
|
| 685 |
+
fig.tight_layout(pad=0.1)
|
| 686 |
+
|
| 687 |
+
subdir = os.path.join(save_dir, key)
|
| 688 |
+
os.makedirs(subdir, exist_ok=True)
|
| 689 |
+
fname = os.path.join(subdir, f"individual_runs_loss_{key}.png")
|
| 690 |
+
fig.savefig(fname, dpi=300)
|
| 691 |
+
fig.savefig(fname.replace("png", "pdf"), format="pdf")
|
| 692 |
+
plt.close(fig)
|
| 693 |
+
|
| 694 |
+
def plot_accuracy_individual_runs(training_data, evaluate_every, save_dir, scale=1.0, smooth=False, x_max=None):
|
| 695 |
+
|
| 696 |
+
grouped = defaultdict(list)
|
| 697 |
+
label_map = {}
|
| 698 |
+
iters_map = {}
|
| 699 |
+
model_map = {}
|
| 700 |
+
|
| 701 |
+
base_colors = sns.color_palette("hls", n_colors=3)
|
| 702 |
+
color_lookup = {f"Run {i+1}": base_colors[i] for i in range(3)}
|
| 703 |
+
|
| 704 |
+
for i, (folder, data) in enumerate(training_data.items()):
|
| 705 |
+
checkpoint = load_checkpoint(get_latest_checkpoint_file(folder), device="cpu")
|
| 706 |
+
model_args = get_model_args_from_checkpoint(checkpoint)
|
| 707 |
+
label, model_type, iters = parse_folder_name(folder)
|
| 708 |
+
if iters is None:
|
| 709 |
+
continue
|
| 710 |
+
|
| 711 |
+
if model_type.lower() == "ctm":
|
| 712 |
+
memory_length = getattr(model_args, "memory_length", None)
|
| 713 |
+
if memory_length is None:
|
| 714 |
+
raise ValueError(f"CTM model missing memory_length in checkpoint args from: {folder}")
|
| 715 |
+
key = f"{model_type}_{iters}_{memory_length}".lower()
|
| 716 |
+
else:
|
| 717 |
+
key = f"{model_type}_{iters}".lower()
|
| 718 |
+
|
| 719 |
+
run_name = extract_run_name(folder, run_index=i)
|
| 720 |
+
grouped[key].append((run_name, data["test_accuracies"]))
|
| 721 |
+
label_map[key] = f"{model_type}, {iters} Iters."
|
| 722 |
+
iters_map[key] = iters
|
| 723 |
+
model_map[key] = model_type
|
| 724 |
+
|
| 725 |
+
for key, runs in grouped.items():
|
| 726 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 727 |
+
for run_name, acc in runs:
|
| 728 |
+
acc = np.array(acc) * 100
|
| 729 |
+
if smooth:
|
| 730 |
+
window_size = max(1, int(0.05 * len(acc)))
|
| 731 |
+
if window_size % 2 == 0:
|
| 732 |
+
window_size += 1
|
| 733 |
+
kernel = np.ones(window_size) / window_size
|
| 734 |
+
acc = np.convolve(acc, kernel, mode="same")
|
| 735 |
+
|
| 736 |
+
x = np.arange(len(acc)) * evaluate_every
|
| 737 |
+
color = color_lookup.get(run_name, 'gray')
|
| 738 |
+
ax.plot(x, acc, label=run_name, color=color, alpha=0.7)
|
| 739 |
+
|
| 740 |
+
ax.set_xlabel("Training Iterations")
|
| 741 |
+
ax.set_ylabel("Accuracy (%)")
|
| 742 |
+
ax.set_ylim([50, 101])
|
| 743 |
+
ax.grid(True, alpha=0.5)
|
| 744 |
+
if x_max:
|
| 745 |
+
ax.set_xlim([0, x_max])
|
| 746 |
+
ax.set_xticks(np.arange(0, x_max + 1, 50000))
|
| 747 |
+
ax.legend()
|
| 748 |
+
fig.tight_layout(pad=0.1)
|
| 749 |
+
|
| 750 |
+
subdir = os.path.join(save_dir, key)
|
| 751 |
+
os.makedirs(subdir, exist_ok=True)
|
| 752 |
+
fname = os.path.join(subdir, f"individual_runs_accuracy_{key}.png")
|
| 753 |
+
fig.savefig(fname, dpi=300)
|
| 754 |
+
fig.savefig(fname.replace("png", "pdf"), format="pdf")
|
| 755 |
+
plt.close(fig)
|
| 756 |
+
|
| 757 |
+
def plot_training_curve_all_runs(all_folders, save_dir, scale, device, smooth=False, x_max=None, plot_individual_runs=True):
|
| 758 |
+
|
| 759 |
+
all_folders = [folder for folder in all_folders if "certain" not in folder]
|
| 760 |
+
|
| 761 |
+
training_data = {}
|
| 762 |
+
evaluation_intervals = []
|
| 763 |
+
for folder in all_folders:
|
| 764 |
+
latest_checkpoint_path = get_latest_checkpoint_file(folder)
|
| 765 |
+
if latest_checkpoint_path:
|
| 766 |
+
checkpoint = load_checkpoint(latest_checkpoint_path, device=device)
|
| 767 |
+
model_args = get_model_args_from_checkpoint(checkpoint)
|
| 768 |
+
evaluation_intervals.append(model_args.track_every)
|
| 769 |
+
|
| 770 |
+
_, train_losses, test_losses, train_accuracies, test_accuracies = get_accuracy_and_loss_from_checkpoint(checkpoint, device=device)
|
| 771 |
+
training_data[folder] = {
|
| 772 |
+
"train_losses": train_losses,
|
| 773 |
+
"test_losses": test_losses,
|
| 774 |
+
"train_accuracies": train_accuracies,
|
| 775 |
+
"test_accuracies": test_accuracies
|
| 776 |
+
}
|
| 777 |
+
else:
|
| 778 |
+
print(f"No checkpoint found for {folder}")
|
| 779 |
+
|
| 780 |
+
assert len(evaluation_intervals) > 0, "No valid checkpoints found."
|
| 781 |
+
assert all(interval == evaluation_intervals[0] for interval in evaluation_intervals), "Evaluation intervals are not consistent across runs."
|
| 782 |
+
|
| 783 |
+
evaluate_every = evaluation_intervals[0]
|
| 784 |
+
|
| 785 |
+
if plot_individual_runs:
|
| 786 |
+
plot_loss_individual_runs(training_data, evaluate_every, save_dir=save_dir, scale=scale, x_max=x_max)
|
| 787 |
+
plot_accuracy_individual_runs(training_data, evaluate_every, save_dir=save_dir, scale=scale, smooth=smooth, x_max=x_max)
|
| 788 |
+
|
| 789 |
+
plot_loss_all_runs(training_data, evaluate_every, save_path=f"{save_dir}/loss_comparison.png", scale=scale, x_max=x_max)
|
| 790 |
+
plot_accuracy_all_runs(training_data, evaluate_every, save_path=f"{save_dir}/accuracy_comparison.png", scale=scale, smooth=smooth, x_max=x_max)
|
| 791 |
+
|
| 792 |
+
return training_data
|
| 793 |
+
|
| 794 |
+
def plot_accuracy_thinking_time(csv_path, scale, output_dir="analysis/cifar"):
|
| 795 |
+
if not os.path.exists(csv_path):
|
| 796 |
+
raise FileNotFoundError(f"CSV file not found: {csv_path}")
|
| 797 |
+
|
| 798 |
+
df = pd.read_csv(csv_path)
|
| 799 |
+
df["RunName"] = df["Run"].apply(lambda x: os.path.basename(os.path.dirname(x)))
|
| 800 |
+
df["Model"] = df["Run"].apply(lambda x: "CTM" if "ctm" in x.lower() else "LSTM")
|
| 801 |
+
|
| 802 |
+
grouped = df.groupby(["Model", "Num Iterations"])
|
| 803 |
+
summary = grouped.agg(
|
| 804 |
+
mean_accuracy=("Overall Mean Accuracy", "mean"),
|
| 805 |
+
std_accuracy=("Overall Std Accuracy", lambda x: np.sqrt(np.mean(x**2)))
|
| 806 |
+
).reset_index()
|
| 807 |
+
|
| 808 |
+
summary["mean_accuracy"] *= 100
|
| 809 |
+
summary["std_accuracy"] *= 100
|
| 810 |
+
|
| 811 |
+
fig, ax = plt.subplots(figsize=(scale*5, scale*5))
|
| 812 |
+
|
| 813 |
+
for model in ("CTM", "LSTM"):
|
| 814 |
+
subset = summary[summary["Model"] == model].sort_values(by="Num Iterations")
|
| 815 |
+
linestyle = "-" if model == "CTM" else "--"
|
| 816 |
+
ax.errorbar(
|
| 817 |
+
subset["Num Iterations"],
|
| 818 |
+
subset["mean_accuracy"],
|
| 819 |
+
yerr=subset["std_accuracy"],
|
| 820 |
+
linestyle=linestyle,
|
| 821 |
+
color="black",
|
| 822 |
+
marker='.',
|
| 823 |
+
label=model,
|
| 824 |
+
capsize=3,
|
| 825 |
+
elinewidth=1,
|
| 826 |
+
errorevery=1
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
ax.set_xlabel("Internal Ticks")
|
| 830 |
+
ax.set_ylabel("Accuracy (%)")
|
| 831 |
+
custom_lines = [
|
| 832 |
+
Line2D([0], [0], color='black', linestyle='-', label='CTM'),
|
| 833 |
+
Line2D([0], [0], color='black', linestyle='--', label='LSTM')
|
| 834 |
+
]
|
| 835 |
+
ax.legend(handles=custom_lines, loc="lower right")
|
| 836 |
+
ax.grid(True, alpha=0.5)
|
| 837 |
+
|
| 838 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 839 |
+
output_path_png = os.path.join(output_dir, "accuracy_vs_thinking_time.png")
|
| 840 |
+
fig.tight_layout(pad=0.1)
|
| 841 |
+
fig.savefig(output_path_png, dpi=300)
|
| 842 |
+
fig.savefig(output_path_png.replace("png", "pdf"), format='pdf')
|
| 843 |
+
plt.close(fig)
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
def plot_lstm_last_and_certain_accuracy(all_folders, save_path="lstm_last_and_certain_accuracy.png", scale=1.0, step=1, x_max=None):
|
| 847 |
+
|
| 848 |
+
tags = ["lstm_10", "lstm_10_certain", "lstm_25", "lstm_25_certain"]
|
| 849 |
+
folders = [f for f in all_folders if any(tag in f.lower() for tag in tags)]
|
| 850 |
+
|
| 851 |
+
training_data, eval_intervals = {}, []
|
| 852 |
+
for f in folders:
|
| 853 |
+
cp = get_latest_checkpoint_file(f)
|
| 854 |
+
if not cp:
|
| 855 |
+
print(f"⚠️ No checkpoint in {f}")
|
| 856 |
+
continue
|
| 857 |
+
ckpt = load_checkpoint(cp, device="cpu")
|
| 858 |
+
args = get_model_args_from_checkpoint(ckpt)
|
| 859 |
+
eval_intervals.append(args.track_every)
|
| 860 |
+
_, _, _, _, acc = get_accuracy_and_loss_from_checkpoint(ckpt)
|
| 861 |
+
iters = "25" if "25" in f.lower() else "10"
|
| 862 |
+
label = "Certain" if "certain" in f.lower() else "Final"
|
| 863 |
+
training_data.setdefault((iters, label), []).append(acc)
|
| 864 |
+
|
| 865 |
+
assert training_data and all(i == eval_intervals[0] for i in eval_intervals), "Missing or inconsistent eval intervals."
|
| 866 |
+
evaluate_every = eval_intervals[0]
|
| 867 |
+
|
| 868 |
+
keys = sorted(training_data.keys())
|
| 869 |
+
colors = sns.color_palette("hls", n_colors=len(keys))
|
| 870 |
+
style_map = {key: ("--" if key[1] == "Certain" else "-") for key in keys}
|
| 871 |
+
color_map = {key: colors[i] for i, key in enumerate(keys)}
|
| 872 |
+
|
| 873 |
+
fig, ax = plt.subplots(figsize=(scale * 10, scale * 5))
|
| 874 |
+
max_x = 0
|
| 875 |
+
|
| 876 |
+
for key in keys:
|
| 877 |
+
runs = training_data[key]
|
| 878 |
+
min_len = min(len(r) for r in runs)
|
| 879 |
+
trimmed = np.stack([r[:min_len] for r in runs], axis=0)[:, ::step]
|
| 880 |
+
mean, std = np.mean(trimmed, 0) * 100, np.std(trimmed, 0) * 100
|
| 881 |
+
x = np.arange(len(mean)) * step * evaluate_every
|
| 882 |
+
ax.plot(x, mean, color=color_map[key], linestyle=style_map[key],
|
| 883 |
+
label=f"{key[0]} Iters, {key[1]}", linewidth=2, alpha=0.7)
|
| 884 |
+
ax.fill_between(x, mean - std, mean + std, color=color_map[key], alpha=0.1)
|
| 885 |
+
max_x = max(max_x, x[-1])
|
| 886 |
+
|
| 887 |
+
ax.set_xlim([0, x_max or max_x])
|
| 888 |
+
ax.set_xticks(np.arange(0, (x_max or max_x) + 1, 50000))
|
| 889 |
+
ax.set_xlabel("Training Iterations")
|
| 890 |
+
ax.set_ylabel("Accuracy (%)")
|
| 891 |
+
ax.grid(True, alpha=0.5)
|
| 892 |
+
ax.legend(loc="lower right")
|
| 893 |
+
fig.tight_layout(pad=0.1)
|
| 894 |
+
fig.savefig(save_path, dpi=300)
|
| 895 |
+
fig.savefig(save_path.replace("png", "pdf"), format="pdf")
|
| 896 |
+
plt.close(fig)
|
tasks/parity/scripts/train_ctm_100_50.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=100
|
| 4 |
+
MEMORY_LENGTH=50
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--memory_length $MEMORY_LENGTH \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 1024 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--n_synch_out 32 \
|
| 18 |
+
--n_synch_action 32 \
|
| 19 |
+
--synapse_depth 1 \
|
| 20 |
+
--heads 8 \
|
| 21 |
+
--memory_hidden_dims 16 \
|
| 22 |
+
--dropout 0.0 \
|
| 23 |
+
--deep_memory \
|
| 24 |
+
--no-do_normalisation \
|
| 25 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 26 |
+
--backbone_type="parity_backbone" \
|
| 27 |
+
--no-full_eval \
|
| 28 |
+
--weight_decay 0.0 \
|
| 29 |
+
--gradient_clipping 0.9 \
|
| 30 |
+
--use_scheduler \
|
| 31 |
+
--scheduler_type "cosine" \
|
| 32 |
+
--milestones 0 0 0 \
|
| 33 |
+
--gamma 0 \
|
| 34 |
+
--dataset "parity" \
|
| 35 |
+
--batch_size 64 \
|
| 36 |
+
--batch_size_test 256 \
|
| 37 |
+
--lr=0.0001 \
|
| 38 |
+
--training_iterations 200001 \
|
| 39 |
+
--warmup_steps 500 \
|
| 40 |
+
--track_every 1000 \
|
| 41 |
+
--save_every 10000 \
|
| 42 |
+
--no-reload \
|
| 43 |
+
--no-reload_model_only \
|
| 44 |
+
--device 0 \
|
| 45 |
+
--no-use_amp \
|
| 46 |
+
--neuron_select_type "random"
|
tasks/parity/scripts/train_ctm_10_5.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=10
|
| 4 |
+
MEMORY_LENGTH=5
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--memory_length $MEMORY_LENGTH \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 1024 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--n_synch_out 32 \
|
| 18 |
+
--n_synch_action 32 \
|
| 19 |
+
--synapse_depth 1 \
|
| 20 |
+
--heads 8 \
|
| 21 |
+
--memory_hidden_dims 16 \
|
| 22 |
+
--dropout 0.0 \
|
| 23 |
+
--deep_memory \
|
| 24 |
+
--no-do_normalisation \
|
| 25 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 26 |
+
--backbone_type="parity_backbone" \
|
| 27 |
+
--no-full_eval \
|
| 28 |
+
--weight_decay 0.0 \
|
| 29 |
+
--gradient_clipping 0.9 \
|
| 30 |
+
--use_scheduler \
|
| 31 |
+
--scheduler_type "cosine" \
|
| 32 |
+
--milestones 0 0 0 \
|
| 33 |
+
--gamma 0 \
|
| 34 |
+
--dataset "parity" \
|
| 35 |
+
--batch_size 64 \
|
| 36 |
+
--batch_size_test 256 \
|
| 37 |
+
--lr=0.0001 \
|
| 38 |
+
--training_iterations 200001 \
|
| 39 |
+
--warmup_steps 500 \
|
| 40 |
+
--track_every 1000 \
|
| 41 |
+
--save_every 10000 \
|
| 42 |
+
--no-reload \
|
| 43 |
+
--no-reload_model_only \
|
| 44 |
+
--device 0 \
|
| 45 |
+
--no-use_amp \
|
| 46 |
+
--neuron_select_type "random"
|
tasks/parity/scripts/train_ctm_1_1.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=1
|
| 4 |
+
MEMORY_LENGTH=1
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--memory_length $MEMORY_LENGTH \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 1024 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--n_synch_out 32 \
|
| 18 |
+
--n_synch_action 32 \
|
| 19 |
+
--synapse_depth 1 \
|
| 20 |
+
--heads 8 \
|
| 21 |
+
--memory_hidden_dims 16 \
|
| 22 |
+
--dropout 0.0 \
|
| 23 |
+
--deep_memory \
|
| 24 |
+
--no-do_normalisation \
|
| 25 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 26 |
+
--backbone_type="parity_backbone" \
|
| 27 |
+
--no-full_eval \
|
| 28 |
+
--weight_decay 0.0 \
|
| 29 |
+
--gradient_clipping 0.9 \
|
| 30 |
+
--use_scheduler \
|
| 31 |
+
--scheduler_type "cosine" \
|
| 32 |
+
--milestones 0 0 0 \
|
| 33 |
+
--gamma 0 \
|
| 34 |
+
--dataset "parity" \
|
| 35 |
+
--batch_size 64 \
|
| 36 |
+
--batch_size_test 256 \
|
| 37 |
+
--lr=0.0001 \
|
| 38 |
+
--training_iterations 200001 \
|
| 39 |
+
--warmup_steps 500 \
|
| 40 |
+
--track_every 1000 \
|
| 41 |
+
--save_every 10000 \
|
| 42 |
+
--no-reload \
|
| 43 |
+
--no-reload_model_only \
|
| 44 |
+
--device 0 \
|
| 45 |
+
--no-use_amp \
|
| 46 |
+
--neuron_select_type "random"
|
tasks/parity/scripts/train_ctm_25_10.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=25
|
| 4 |
+
MEMORY_LENGTH=10
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--memory_length $MEMORY_LENGTH \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 1024 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--n_synch_out 32 \
|
| 18 |
+
--n_synch_action 32 \
|
| 19 |
+
--synapse_depth 1 \
|
| 20 |
+
--heads 8 \
|
| 21 |
+
--memory_hidden_dims 16 \
|
| 22 |
+
--dropout 0.0 \
|
| 23 |
+
--deep_memory \
|
| 24 |
+
--no-do_normalisation \
|
| 25 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 26 |
+
--backbone_type="parity_backbone" \
|
| 27 |
+
--no-full_eval \
|
| 28 |
+
--weight_decay 0.0 \
|
| 29 |
+
--gradient_clipping 0.9 \
|
| 30 |
+
--use_scheduler \
|
| 31 |
+
--scheduler_type "cosine" \
|
| 32 |
+
--milestones 0 0 0 \
|
| 33 |
+
--gamma 0 \
|
| 34 |
+
--dataset "parity" \
|
| 35 |
+
--batch_size 64 \
|
| 36 |
+
--batch_size_test 256 \
|
| 37 |
+
--lr=0.0001 \
|
| 38 |
+
--training_iterations 200001 \
|
| 39 |
+
--warmup_steps 500 \
|
| 40 |
+
--track_every 1000 \
|
| 41 |
+
--save_every 10000 \
|
| 42 |
+
--no-reload \
|
| 43 |
+
--no-reload_model_only \
|
| 44 |
+
--device 0 \
|
| 45 |
+
--no-use_amp \
|
| 46 |
+
--neuron_select_type "random"
|
tasks/parity/scripts/train_ctm_50_25.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=50
|
| 4 |
+
MEMORY_LENGTH=25
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--memory_length $MEMORY_LENGTH \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 1024 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--n_synch_out 32 \
|
| 18 |
+
--n_synch_action 32 \
|
| 19 |
+
--synapse_depth 1 \
|
| 20 |
+
--heads 8 \
|
| 21 |
+
--memory_hidden_dims 16 \
|
| 22 |
+
--dropout 0.0 \
|
| 23 |
+
--deep_memory \
|
| 24 |
+
--no-do_normalisation \
|
| 25 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 26 |
+
--backbone_type="parity_backbone" \
|
| 27 |
+
--no-full_eval \
|
| 28 |
+
--weight_decay 0.0 \
|
| 29 |
+
--gradient_clipping 0.9 \
|
| 30 |
+
--use_scheduler \
|
| 31 |
+
--scheduler_type "cosine" \
|
| 32 |
+
--milestones 0 0 0 \
|
| 33 |
+
--gamma 0 \
|
| 34 |
+
--dataset "parity" \
|
| 35 |
+
--batch_size 64 \
|
| 36 |
+
--batch_size_test 256 \
|
| 37 |
+
--lr=0.0001 \
|
| 38 |
+
--training_iterations 200001 \
|
| 39 |
+
--warmup_steps 500 \
|
| 40 |
+
--track_every 1000 \
|
| 41 |
+
--save_every 10000 \
|
| 42 |
+
--no-reload \
|
| 43 |
+
--no-reload_model_only \
|
| 44 |
+
--device 0 \
|
| 45 |
+
--no-use_amp \
|
| 46 |
+
--neuron_select_type "random"
|
tasks/parity/scripts/train_ctm_75_25.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=75
|
| 4 |
+
MEMORY_LENGTH=25
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--memory_length $MEMORY_LENGTH \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 1024 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--n_synch_out 32 \
|
| 18 |
+
--n_synch_action 32 \
|
| 19 |
+
--synapse_depth 1 \
|
| 20 |
+
--heads 8 \
|
| 21 |
+
--memory_hidden_dims 16 \
|
| 22 |
+
--dropout 0.0 \
|
| 23 |
+
--deep_memory \
|
| 24 |
+
--no-do_normalisation \
|
| 25 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 26 |
+
--backbone_type="parity_backbone" \
|
| 27 |
+
--no-full_eval \
|
| 28 |
+
--weight_decay 0.0 \
|
| 29 |
+
--gradient_clipping 0.9 \
|
| 30 |
+
--use_scheduler \
|
| 31 |
+
--scheduler_type "cosine" \
|
| 32 |
+
--milestones 0 0 0 \
|
| 33 |
+
--gamma 0 \
|
| 34 |
+
--dataset "parity" \
|
| 35 |
+
--batch_size 64 \
|
| 36 |
+
--batch_size_test 256 \
|
| 37 |
+
--lr=0.0001 \
|
| 38 |
+
--training_iterations 200001 \
|
| 39 |
+
--warmup_steps 500 \
|
| 40 |
+
--track_every 1000 \
|
| 41 |
+
--save_every 10000 \
|
| 42 |
+
--no-reload \
|
| 43 |
+
--no-reload_model_only \
|
| 44 |
+
--device 0 \
|
| 45 |
+
--no-use_amp \
|
| 46 |
+
--neuron_select_type "random"
|
tasks/parity/scripts/train_lstm_1.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=1
|
| 4 |
+
MODEL_TYPE="lstm"
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--model_type $MODEL_TYPE \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 669 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--heads 8 \
|
| 18 |
+
--dropout 0.0 \
|
| 19 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 20 |
+
--backbone_type="parity_backbone" \
|
| 21 |
+
--no-full_eval \
|
| 22 |
+
--weight_decay 0.0 \
|
| 23 |
+
--gradient_clipping -1 \
|
| 24 |
+
--use_scheduler \
|
| 25 |
+
--scheduler_type "cosine" \
|
| 26 |
+
--milestones 0 0 0 \
|
| 27 |
+
--gamma 0 \
|
| 28 |
+
--dataset "parity" \
|
| 29 |
+
--batch_size 64 \
|
| 30 |
+
--batch_size_test 256 \
|
| 31 |
+
--lr=0.0001 \
|
| 32 |
+
--training_iterations 200001 \
|
| 33 |
+
--warmup_steps 500 \
|
| 34 |
+
--track_every 1000 \
|
| 35 |
+
--save_every 10000 \
|
| 36 |
+
--no-reload \
|
| 37 |
+
--no-reload_model_only \
|
| 38 |
+
--device 0 \
|
| 39 |
+
--no-use_amp \
|
tasks/parity/scripts/train_lstm_10.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=10
|
| 4 |
+
MODEL_TYPE="lstm"
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--model_type $MODEL_TYPE \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 686 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--heads 8 \
|
| 18 |
+
--dropout 0.0 \
|
| 19 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 20 |
+
--backbone_type="parity_backbone" \
|
| 21 |
+
--no-full_eval \
|
| 22 |
+
--weight_decay 0.0 \
|
| 23 |
+
--gradient_clipping -1 \
|
| 24 |
+
--use_scheduler \
|
| 25 |
+
--scheduler_type "cosine" \
|
| 26 |
+
--milestones 0 0 0 \
|
| 27 |
+
--gamma 0 \
|
| 28 |
+
--dataset "parity" \
|
| 29 |
+
--batch_size 64 \
|
| 30 |
+
--batch_size_test 256 \
|
| 31 |
+
--lr=0.0001 \
|
| 32 |
+
--training_iterations 200001 \
|
| 33 |
+
--warmup_steps 500 \
|
| 34 |
+
--track_every 1000 \
|
| 35 |
+
--save_every 10000 \
|
| 36 |
+
--no-reload \
|
| 37 |
+
--no-reload_model_only \
|
| 38 |
+
--device 0 \
|
| 39 |
+
--no-use_amp \
|
tasks/parity/scripts/train_lstm_100.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=100
|
| 4 |
+
MODEL_TYPE="lstm"
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--model_type $MODEL_TYPE \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 857 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--heads 8 \
|
| 18 |
+
--dropout 0.0 \
|
| 19 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 20 |
+
--backbone_type="parity_backbone" \
|
| 21 |
+
--no-full_eval \
|
| 22 |
+
--weight_decay 0.0 \
|
| 23 |
+
--gradient_clipping -1 \
|
| 24 |
+
--use_scheduler \
|
| 25 |
+
--scheduler_type "cosine" \
|
| 26 |
+
--milestones 0 0 0 \
|
| 27 |
+
--gamma 0 \
|
| 28 |
+
--dataset "parity" \
|
| 29 |
+
--batch_size 64 \
|
| 30 |
+
--batch_size_test 256 \
|
| 31 |
+
--lr=0.0001 \
|
| 32 |
+
--training_iterations 200001 \
|
| 33 |
+
--warmup_steps 500 \
|
| 34 |
+
--track_every 1000 \
|
| 35 |
+
--save_every 10000 \
|
| 36 |
+
--no-reload \
|
| 37 |
+
--no-reload_model_only \
|
| 38 |
+
--device 0 \
|
| 39 |
+
--no-use_amp \
|
tasks/parity/scripts/train_lstm_10_certain.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=3
|
| 3 |
+
ITERATIONS=10
|
| 4 |
+
MODEL_TYPE="lstm"
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}_certain"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--model_type $MODEL_TYPE \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 686 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--heads 8 \
|
| 18 |
+
--dropout 0.0 \
|
| 19 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 20 |
+
--backbone_type="parity_backbone" \
|
| 21 |
+
--no-full_eval \
|
| 22 |
+
--weight_decay 0.0 \
|
| 23 |
+
--gradient_clipping -1 \
|
| 24 |
+
--use_scheduler \
|
| 25 |
+
--scheduler_type "cosine" \
|
| 26 |
+
--milestones 0 0 0 \
|
| 27 |
+
--gamma 0 \
|
| 28 |
+
--dataset "parity" \
|
| 29 |
+
--batch_size 64 \
|
| 30 |
+
--batch_size_test 256 \
|
| 31 |
+
--lr=0.0001 \
|
| 32 |
+
--training_iterations 200001 \
|
| 33 |
+
--warmup_steps 500 \
|
| 34 |
+
--track_every 1000 \
|
| 35 |
+
--save_every 10000 \
|
| 36 |
+
--no-reload \
|
| 37 |
+
--no-reload_model_only \
|
| 38 |
+
--device 0 \
|
| 39 |
+
--no-use_amp \
|
| 40 |
+
--use_most_certain_with_lstm \
|
tasks/parity/scripts/train_lstm_25.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=1
|
| 3 |
+
ITERATIONS=25
|
| 4 |
+
MODEL_TYPE="lstm"
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--model_type $MODEL_TYPE \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 706 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--heads 8 \
|
| 18 |
+
--dropout 0.0 \
|
| 19 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 20 |
+
--backbone_type="parity_backbone" \
|
| 21 |
+
--no-full_eval \
|
| 22 |
+
--weight_decay 0.0 \
|
| 23 |
+
--gradient_clipping -1 \
|
| 24 |
+
--use_scheduler \
|
| 25 |
+
--scheduler_type "cosine" \
|
| 26 |
+
--milestones 0 0 0 \
|
| 27 |
+
--gamma 0 \
|
| 28 |
+
--dataset "parity" \
|
| 29 |
+
--batch_size 64 \
|
| 30 |
+
--batch_size_test 256 \
|
| 31 |
+
--lr=0.0001 \
|
| 32 |
+
--training_iterations 200001 \
|
| 33 |
+
--warmup_steps 500 \
|
| 34 |
+
--track_every 1000 \
|
| 35 |
+
--save_every 10000 \
|
| 36 |
+
--no-reload \
|
| 37 |
+
--no-reload_model_only \
|
| 38 |
+
--device 0 \
|
| 39 |
+
--no-use_amp \
|
tasks/parity/scripts/train_lstm_25_certain.sh
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
RUN=3
|
| 3 |
+
ITERATIONS=25
|
| 4 |
+
MODEL_TYPE="lstm"
|
| 5 |
+
LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}_certain"
|
| 6 |
+
SEED=$((RUN - 1))
|
| 7 |
+
|
| 8 |
+
python -m tasks.parity.train \
|
| 9 |
+
--log_dir $LOG_DIR \
|
| 10 |
+
--seed $SEED \
|
| 11 |
+
--iterations $ITERATIONS \
|
| 12 |
+
--model_type $MODEL_TYPE \
|
| 13 |
+
--parity_sequence_length 64 \
|
| 14 |
+
--n_test_batches 20 \
|
| 15 |
+
--d_model 706 \
|
| 16 |
+
--d_input 512 \
|
| 17 |
+
--heads 8 \
|
| 18 |
+
--dropout 0.0 \
|
| 19 |
+
--positional_embedding_type="custom-rotational-1d" \
|
| 20 |
+
--backbone_type="parity_backbone" \
|
| 21 |
+
--no-full_eval \
|
| 22 |
+
--weight_decay 0.0 \
|
| 23 |
+
--gradient_clipping -1 \
|
| 24 |
+
--use_scheduler \
|
| 25 |
+
--scheduler_type "cosine" \
|
| 26 |
+
--milestones 0 0 0 \
|
| 27 |
+
--gamma 0 \
|
| 28 |
+
--dataset "parity" \
|
| 29 |
+
--batch_size 64 \
|
| 30 |
+
--batch_size_test 256 \
|
| 31 |
+
--lr=0.0001 \
|
| 32 |
+
--training_iterations 200001 \
|
| 33 |
+
--warmup_steps 500 \
|
| 34 |
+
--track_every 1000 \
|
| 35 |
+
--save_every 10000 \
|
| 36 |
+
--no-reload \
|
| 37 |
+
--no-reload_model_only \
|
| 38 |
+
--device 0 \
|
| 39 |
+
--no-use_amp \
|
| 40 |
+
--use_most_certain_with_lstm \
|