Uday committed on
Commit 06bb098 · 1 Parent(s): c8c8629

Update README with HF metadata and Energy Halting info

Files changed (1)
  1. README.md +194 -98
README.md CHANGED
@@ -1,141 +1,237 @@
  # 🕰️ The Continuous Thought Machine

  📚 [PAPER: Technical Report](https://arxiv.org/abs/2505.05522) | 📝 [Blog](https://sakana.ai/ctm/) | 🕹️ [Interactive Website](https://pub.sakana.ai/ctm) | ✏️ [Tutorial](examples/01_mnist.ipynb)

- ![Activations](assets/activations.gif)

- We present the Continuous Thought Machine (CTM), a model designed to unfold and then leverage neural activity as the underlying mechanism for observation and action. Our contributions are:

- 1. An internal temporal axis, decoupled from any input data, that enables neuron activity to unfold.

- 2. Neuron-level temporal processing, where each neuron uses unique weight parameters to process a history of incoming signals, enabling fine-grained temporal dynamics.

- 3. Neural synchronisation, employed as a direct latent representation for modulating data and producing outputs, thus directly encoding information in the timing of neural activity.

- We demonstrate the CTM's strong performance and versatility across a range of challenging tasks, including ImageNet classification, solving 2D mazes, sorting, parity computation, question-answering, and RL tasks.

- We provide all necessary code to reproduce our results and invite others to build upon and use CTMs in their own work.

- ## [Interactive Website](https://pub.sakana.ai/ctm)
- Please see our [Interactive Website](https://pub.sakana.ai/ctm) for a maze-solving demo, many demonstrative videos of the method, results, and other findings.

- ## Repo structure
- ```
- ├── tasks
- │   ├── image_classification
- │   │   ├── train.py # Training code for image classification (cifar, imagenet)
- │   │   ├── imagenet_classes.py # Helper for imagenet class names
- │   │   ├── plotting.py # Plotting utils specific to this task
- │   │   └── analysis
- │   │       ├── run_imagenet_analysis.py # ImageNet eval and visualisation code
- │   │       └── outputs/ # Folder for outputs of analysis
- │   ├── mazes
- │   │   ├── train.py # Training code for solving 2D mazes (by way of a route; see paper)
- │   │   ├── plotting.py # Plotting utils specific to this task
- │   │   └── analysis
- │   │       ├── run.py # Maze analysis code
- │   │       └── outputs/ # Folder for outputs of analysis
- │   ├── sort
- │   │   ├── train.py # Training code for sorting
- │   │   └── utils.py # Sort specific utils (e.g., CTC decode)
- │   ├── parity
- │   │   ├── train.py # Training code for parity task
- │   │   ├── utils.py # Parity-specific helper functions
- │   │   ├── plotting.py # Plotting utils specific to this task
- │   │   ├── scripts/
- │   │   │   └── *.sh # Training scripts for different experimental setups
- │   │   └── analysis/
- │   │       └── run.py # Entry point for parity analysis
- │   ├── qamnist
- │   │   ├── train.py # Training code for QAMNIST task (quantized MNIST)
- │   │   ├── utils.py # QAMNIST-specific helper functions
- │   │   ├── plotting.py # Plotting utils specific to this task
- │   │   ├── scripts/
- │   │   │   └── *.sh # Training scripts for different experimental setups
- │   │   └── analysis/
- │   │       └── run.py # Entry point for QAMNIST analysis
- │   └── rl
- │       ├── train.py # Training code for RL environments
- │       ├── utils.py # RL-specific helper functions
- │       ├── plotting.py # Plotting utils specific to this task
- │       ├── envs.py # Custom RL environment wrappers
- │       ├── scripts/
- │       │   ├── 4rooms/
- │       │   │   └── *.sh # Training scripts for MiniGrid-FourRooms-v0 environment
- │       │   ├── acrobot/
- │       │   │   └── *.sh # Training scripts for Acrobot-v1 environment
- │       │   └── cartpole/
- │       │       └── *.sh # Training scripts for CartPole-v1 environment
- │       └── analysis/
- │           └── run.py # Entry point for RL analysis
- ├── data # This is where data will be saved and downloaded to
- │   └── custom_datasets.py # Custom datasets (e.g., Mazes), sort
- ├── models
- │   ├── ctm.py # Main model code, used for: image classification, solving mazes, sort
- │   ├── ctm_*.py # Other model code, standalone adjustments for other tasks
- │   ├── ff.py # Feed-forward (simple) baseline code (e.g., for image classification)
- │   ├── lstm.py # LSTM baseline code (e.g., for image classification)
- │   ├── lstm_*.py # Other baseline code, standalone adjustments for other tasks
- │   ├── modules.py # Helper modules, including Neuron-level models and the Synapse UNET
- │   ├── utils.py # Helper functions (e.g., synch decay)
- │   └── resnet.py # Wrapper for ResNet featuriser
- ├── utils
- │   ├── housekeeping.py # Helper functions for keeping things neat
- │   ├── losses.py # Loss functions for various tasks (mostly with reshaping stuff)
- │   └── schedulers.py # Helper wrappers for learning rate schedulers
- └── checkpoints
-     └── imagenet, mazes, ... # Checkpoint directories (see google drive link for files)
  ```

- ## Setup
- To set up the environment using conda:

  ```
  conda create --name=ctm python=3.12
  conda activate ctm
  pip install -r requirements.txt
  ```

- If there are issues with PyTorch versions, the following can be run:
- ```
  pip uninstall torch
  pip install torch --index-url https://download.pytorch.org/whl/cu121
  ```

- ## Model training
- Each task has its own (set of) training code. See for instance [tasks/image_classification/train.py](tasks/image_classification/train.py). We have set it up like this to ensure ease-of-use as opposed to clinical efficiency. This code is for researchers, and we hope to share it in a way that fosters collaboration and learning.

- While we have provided reasonable defaults in the argparsers of each training setup, scripts to replicate the setups in the paper will typically be found in the accompanying script folders. If you simply want to dive in, run the following as a module (set up like this to make it easy to run many high-level training scripts from the top directory):

  ```
- python -m tasks.image_classification.train
  ```
- For debugging in VSCode, this configuration example might be helpful to you:
  ```
  {
-     "name": "Debug: train image classifier",
-     "type": "debugpy",
-     "request": "launch",
-     "module": "tasks.image_classification.train",
-     "console": "integratedTerminal",
-     "justMyCode": false
  }
  ```

- ## Running analyses

- We also provide analysis and plotting code to replicate many of the plots in our paper. See `tasks/.../analysis/*` for more details on that. We also provide some data (e.g., the mazes we generated for training) and checkpoints (see [here](#checkpoints-and-data)). Note that ffmpeg is required for generating mp4 files from the analysis scripts. It can be installed with:
- ```
  conda install -c conda-forge ffmpeg
  ```

- ## Checkpoints and data
- You can download the data and checkpoints from here:
- - checkpoints: https://drive.google.com/drive/folders/1vSg8T7FqP-guMDk1LU7_jZaQtXFP9sZg
- - maze data: https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk

- Checkpoints go in the `checkpoints` folder. For instance, when properly populated, the checkpoints folder will have the maze checkpoint in `checkpoints/mazes/...`
+ ---
+ title: Continuous Thought Machine
+ emoji: 🕰️
+ colorFrom: blue
+ colorTo: indigo
+ sdk: docker
+ sdk_version: "20.10.21"
+ app_file: app.py
+ pinned: false
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
  # 🕰️ The Continuous Thought Machine

  📚 [PAPER: Technical Report](https://arxiv.org/abs/2505.05522) | 📝 [Blog](https://sakana.ai/ctm/) | 🕹️ [Interactive Website](https://pub.sakana.ai/ctm) | ✏️ [Tutorial](examples/01_mnist.ipynb)

+ ## Overview
+
+ The **Continuous Thought Machine (CTM)** is a novel neural architecture designed to unfold and leverage neural activity as the underlying mechanism for observation and action. By introducing an internal temporal axis decoupled from input data, the CTM enables neurons to process information over time with fine-grained temporal dynamics.
+
+ ### Key Contributions
+
+ 1. **Internal Temporal Axis**: Decoupled from input data, allowing neuron activity to unfold independently
+ 2. **Neuron-Level Temporal Processing**: Each neuron uses unique weight parameters to process a history of incoming signals
+ 3. **Neural Synchronisation**: Direct latent representation for modulating data and producing outputs, encoding information in the timing of neural activity (see the sketch below)
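+
+ The synchronisation readout in point 3 can be pictured as timing-based inner products over neuron histories. Below is a minimal sketch under assumed tensor shapes, not the exact code in `models/ctm.py`:
+
+ ```python
+ import torch
+
+ def synchronisation_representation(z_history: torch.Tensor, pairs: torch.Tensor) -> torch.Tensor:
+     """Sketch: a synchronisation latent from neuron activity over internal ticks.
+
+     z_history: (B, T, D) post-activations across T internal ticks for D neurons.
+     pairs:     (P, 2) indices of neuron pairs to read out.
+     Returns:   (B, P) inner products of paired neuron traces over time.
+     """
+     i, j = pairs[:, 0], pairs[:, 1]
+     # A dot product over the time axis measures how closely two neurons co-activate.
+     sync = (z_history[:, :, i] * z_history[:, :, j]).sum(dim=1)
+     return sync / z_history.shape[1] ** 0.5  # scale by sqrt(T) for stability
+ ```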

+ The CTM demonstrates strong performance across diverse tasks including ImageNet classification, 2D maze solving, sorting, parity computation, question-answering, and reinforcement learning.
+
+ ---
+
+ ## 🔬 Energy-Based Halting Experiment
+
+ This repository includes an implementation of **Energy-Based Halting**, a mechanism that frames "thinking" as an optimization process in which the model dynamically adjusts the duration of its internal thought process based on sample difficulty.
+
+ ### Concept
+
+ Instead of using heuristic certainty thresholds, we train a learned energy scalar that:
+
+ - **Minimizes energy** for correct predictions (pushing the system to a low-energy equilibrium)
+ - **Maximizes energy** for incorrect predictions (pushing away from stable states)
+ - **Enables adaptive halting** based on energy thresholds or convergence (a loss sketch follows this list)
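+
+ As a concrete picture, here is a minimal sketch of such a contrastive energy objective (an assumed form with per-sample scalar energies; the actual `EnergyContrastiveLoss` lives in `utils/losses.py` and may differ):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ class EnergyContrastiveLossSketch(torch.nn.Module):
+     """Illustrative stand-in for the repo's EnergyContrastiveLoss."""
+     def __init__(self, margin: float = 1.0):
+         super().__init__()
+         self.margin = margin
+
+     def forward(self, energies, logits, targets):
+         # energies: (B,) scalar energy per sample; logits: (B, C); targets: (B,)
+         correct = (logits.argmax(dim=-1) == targets).float()
+         # Correct samples: hinge pushing energy down to (at most) zero.
+         pos_term = correct * F.relu(energies)
+         # Incorrect samples: hinge pushing energy up to (at least) the margin.
+         neg_term = (1.0 - correct) * F.relu(self.margin - energies)
+         return (pos_term + neg_term).mean()
+ ```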
+
+ ### Implementation
+
+ **Modified Components:**
+
+ - `models/ctm.py`: Added an energy projection head that maps synchronization states to scalar energy values
+ - `utils/losses.py`: Implemented `EnergyContrastiveLoss` for training the energy function
+ - `tasks/image_classification/train_energy.py`: Training script with energy halting
+ - `inference_energy.py`: Adaptive inference that halts when energy drops below a threshold or stabilizes (sketched after this list)
+ - `configs/energy_experiment.yaml`: Configuration for energy experiments
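+
+ The halting rule itself is simple to picture. Below is a minimal sketch assuming a hypothetical per-tick `model.step(x, state)` API that returns logits, a scalar energy, and the recurrent state; the actual logic lives in `inference_energy.py`:
+
+ ```python
+ import torch
+
+ @torch.no_grad()
+ def adaptive_inference(model, x, max_ticks=50, energy_threshold=0.1, patience=3):
+     """Sketch of energy-based halting at inference (assumes batch size 1)."""
+     state, energies, logits = None, [], None
+     for tick in range(max_ticks):
+         logits, energy, state = model.step(x, state)  # hypothetical per-tick API
+         energies.append(energy.item())
+         # Halt when energy drops below the threshold (a confident, stable prediction)...
+         if energies[-1] < energy_threshold:
+             break
+         # ...or when energy has stabilized over the last few ticks.
+         if len(energies) >= patience and max(energies[-patience:]) - min(energies[-patience:]) < 1e-3:
+             break
+     return logits, tick + 1  # prediction and number of "thinking" ticks used
+ ```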
+
+ **Training:**
+
+ ```bash
+ # Local training with accelerate
+ pixi run accelerate launch tasks/image_classification/train_energy.py \
+     --energy_head_enabled \
+     --loss_type energy_contrastive \
+     --dataset cifar10
+
+ # Or with plain Python
+ pixi run python tasks/image_classification/train_energy.py \
+     --energy_head_enabled \
+     --loss_type energy_contrastive
  ```

+ **Deployment to Hugging Face:**
+ See [GUIDE_HF.md](GUIDE_HF.md) for instructions on deploying the training job to Hugging Face Spaces with GPU support.
+
+ ---
+
+ ## 🚀 Quick Start
+
+ ### Setup with Pixi (Recommended)
+
+ We use [Pixi](https://pixi.sh) for dependency management, which handles both Python packages and system dependencies like `ffmpeg`.
+
+ ```bash
+ # Install dependencies
+ pixi install
+
+ # Run training
+ pixi run python tasks/image_classification/train.py
  ```
+
+ ### Alternative: Conda Setup
+
+ ```bash
  conda create --name=ctm python=3.12
  conda activate ctm
  pip install -r requirements.txt
+ conda install -c conda-forge ffmpeg
  ```

+ If there are PyTorch version issues:
+
+ ```bash
  pip uninstall torch
  pip install torch --index-url https://download.pytorch.org/whl/cu121
  ```

+ ---

+ ## 📁 Repository Structure

  ```
+ ├── tasks/
+ │   ├── image_classification/
+ │   │   ├── train.py # Standard training
+ │   │   ├── train_energy.py # Energy halting training
+ │   │   ├── analysis/run_imagenet_analysis.py
+ │   │   └── plotting.py
+ │   ├── mazes/
+ │   │   ├── train.py
+ │   │   └── analysis/
+ │   ├── sort/
+ │   ├── parity/
+ │   ├── qamnist/
+ │   └── rl/
+ ├── models/
+ │   ├── ctm.py # Main CTM model (with energy head support)
+ │   ├── modules.py # Neuron-level models, Synapse UNET
+ │   ├── ff.py # Feed-forward baseline
+ │   └── lstm.py # LSTM baseline
+ ├── utils/
+ │   ├── losses.py # Loss functions (includes EnergyContrastiveLoss)
+ │   ├── schedulers.py
+ │   └── housekeeping.py
+ ├── data/
+ │   └── custom_datasets.py
+ ├── configs/
+ │   └── energy_experiment.yaml # Energy halting hyperparameters
+ ├── inference_energy.py # Adaptive energy-based inference
+ ├── Dockerfile # For HF Spaces deployment
+ ├── GUIDE_HF.md # Hugging Face deployment guide
+ └── checkpoints/ # Model checkpoints
  ```
+
+ ---
+
+ ## 🎯 Model Training
+
+ Each task has dedicated training code designed for ease of use and collaboration. Training scripts include reasonable defaults, with paper-replicating configurations in the accompanying script folders.
+
+ ### Image Classification Example
+
+ ```bash
+ # Standard CTM training
+ python -m tasks.image_classification.train
+
+ # Energy halting training
+ python -m tasks.image_classification.train_energy \
+     --energy_head_enabled \
+     --loss_type energy_contrastive
  ```
+
+ ### VSCode Debug Configuration
+
+ ```json
  {
+     "name": "Debug: train image classifier",
+     "type": "debugpy",
+     "request": "launch",
+     "module": "tasks.image_classification.train",
+     "console": "integratedTerminal",
+     "justMyCode": false
  }
  ```

+ ---

+ ## 🔍 Analysis & Visualization

+ Analysis and plotting code to replicate paper figures is provided in `tasks/.../analysis/*`.
+
+ **Note:** `ffmpeg` is required for generating videos:
+
+ ```bash
  conda install -c conda-forge ffmpeg
+ # or with pixi (ffmpeg is already included)
+ pixi install
+ ```
+
+ ---
+
+ ## 📦 Checkpoints and Data
+
+ Download pre-trained checkpoints and datasets:
+
+ - **Checkpoints**: [Google Drive](https://drive.google.com/drive/folders/1vSg8T7FqP-guMDk1LU7_jZaQtXFP9sZg)
+ - **Maze Data**: [Google Drive](https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk)
+
+ Place checkpoints in the `checkpoints/` folder, following the structure `checkpoints/{task}/...`
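+
+ For example, a downloaded maze checkpoint can then be loaded with standard PyTorch (the file name below is a placeholder for whichever `.pt` file you downloaded):
+
+ ```python
+ import torch
+
+ # Placeholder path inside checkpoints/mazes/.
+ state = torch.load("checkpoints/mazes/checkpoint.pt", map_location="cpu")
+ ```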
+
+ ---
+
+ ## 🤗 Hugging Face Integration
+
+ This repository includes full support for training on Hugging Face infrastructure:
+
+ - **Accelerate**: Multi-GPU and mixed precision training
+ - **Hub Integration**: Automatic checkpoint uploading (illustrated below)
+ - **Spaces Deployment**: Run training jobs on GPU Spaces
+
+ See [GUIDE_HF.md](GUIDE_HF.md) for detailed instructions.
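+
+ As an illustration of the Hub side (a sketch with a placeholder repo id and folder; GUIDE_HF.md describes the actual workflow), checkpoints can be pushed with `huggingface_hub`:
+
+ ```python
+ from huggingface_hub import HfApi
+
+ # Placeholder repo id and folder; adjust to your own account and task.
+ api = HfApi()
+ api.upload_folder(
+     folder_path="checkpoints/imagenet",
+     repo_id="your-username/ctm-checkpoints",
+     repo_type="model",
+ )
+ ```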
+
+ ---
+
+ ## 📖 Interactive Resources
+
+ - **[Interactive Website](https://pub.sakana.ai/ctm)**: Maze-solving demo, videos, and visualizations
+ - **[Paper](https://arxiv.org/abs/2505.05522)**: Technical details and experiments
+ - **[Blog](https://sakana.ai/ctm/)**: High-level overview and insights
+ - **[Tutorial Notebook](examples/01_mnist.ipynb)**: Hands-on introduction
+
+ ---
+
+ ## 🙏 Citation
+
+ If you use this code or build upon CTM in your work, please cite:
+
+ ```bibtex
+ @article{ctm2025,
+   title={The Continuous Thought Machine},
+   author={...},
+   journal={arXiv preprint arXiv:2505.05522},
+   year={2025}
+ }
  ```

+ ---

+ ## 📝 License

+ This project is released under the MIT License. See the LICENSE file for details.