Update README with HF metadata and Energy Halting info
Browse files
README.md
CHANGED
|
@@ -1,141 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# π°οΈ The Continuous Thought Machine
|
| 2 |
|
| 3 |
π [PAPER: Technical Report](https://arxiv.org/abs/2505.05522) | π [Blog](https://sakana.ai/ctm/) | πΉοΈ [Interactive Website](https://pub.sakana.ai/ctm) | βοΈ [Tutorial](examples/01_mnist.ipynb)
|
| 4 |
|
| 5 |
-
|
| 6 |
|
| 7 |
-
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
|
| 15 |
-
|
| 16 |
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
Please see our [Interactive Website](https://pub.sakana.ai/ctm) for a maze-solving demo, many demonstrative videos of the method, results, and other findings.
|
| 21 |
|
|
|
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
βΒ Β βΒ Β βββ train.py # Training code for sorting
|
| 41 |
-
βΒ Β βΒ Β βββ utils.py # Sort specific utils (e.g., CTC decode)
|
| 42 |
-
βΒ Β βββ parity
|
| 43 |
-
βΒ Β βΒ Β βββ train.py # Training code for parity task
|
| 44 |
-
βΒ Β βΒ Β βββ utils.py # Parity-specific helper functions
|
| 45 |
-
βΒ Β βΒ Β βββ plotting.py # Plotting utils specific to this task
|
| 46 |
-
βΒ Β βΒ Β βββ scripts/
|
| 47 |
-
βΒ Β βΒ Β βΒ Β βββ *.sh # Training scripts for different experimental setups
|
| 48 |
-
βΒ Β βΒ Β βββ analysis/
|
| 49 |
-
βΒ Β βΒ Β βββ run.py # Entry point for parity analysis
|
| 50 |
-
βΒ Β βββ qamnist
|
| 51 |
-
βΒ Β βΒ Β βββ train.py # Training code for QAMNIST task (quantized MNIST)
|
| 52 |
-
βΒ Β βΒ Β βββ utils.py # QAMNIST-specific helper functions
|
| 53 |
-
βΒ Β βΒ Β βββ plotting.py # Plotting utils specific to this task
|
| 54 |
-
βΒ Β βΒ Β βββ scripts/
|
| 55 |
-
βΒ Β βΒ Β βΒ Β βββ *.sh # Training scripts for different experimental setups
|
| 56 |
-
βΒ Β βΒ Β βββ analysis/
|
| 57 |
-
βΒ Β βΒ Β βββ run.py # Entry point for QAMNIST analysis
|
| 58 |
-
βΒ Β βββ rl
|
| 59 |
-
βΒ Β Β Β βββ train.py # Training code for RL environments
|
| 60 |
-
βΒ Β Β Β βββ utils.py # RL-specific helper functions
|
| 61 |
-
βΒ Β Β Β βββ plotting.py # Plotting utils specific to this task
|
| 62 |
-
βΒ Β Β Β βββ envs.py # Custom RL environment wrappers
|
| 63 |
-
βΒ Β Β Β βββ scripts/
|
| 64 |
-
βΒ Β Β Β βΒ Β βββ 4rooms/
|
| 65 |
-
βΒ Β Β Β βΒ Β βΒ Β βββ *.sh # Training scripts for MiniGrid-FourRooms-v0 environment
|
| 66 |
-
βΒ Β Β Β βΒ Β βββ acrobot/
|
| 67 |
-
βΒ Β Β Β βΒ Β βΒ Β βββ *.sh # Training scripts for Acrobot-v1 environment
|
| 68 |
-
βΒ Β Β Β βΒ Β βββ cartpole/
|
| 69 |
-
βΒ Β Β Β βΒ Β βββ *.sh # Training scripts for CartPole-v1 environment
|
| 70 |
-
βΒ Β Β Β βββ analysis/
|
| 71 |
-
βΒ Β Β Β βββ run.py # Entry point for RL analysis
|
| 72 |
-
βββ data # This is where data will be saved and downloaded to
|
| 73 |
-
βΒ Β βββ custom_datasets.py # Custom datasets (e.g., Mazes), sort
|
| 74 |
-
βββ models
|
| 75 |
-
βΒ Β βββ ctm.py # Main model code, used for: image classification, solving mazes, sort
|
| 76 |
-
βΒ Β βββ ctm_*.py # Other model code, standalone adjustments for other tasks
|
| 77 |
-
βΒ Β βββ ff.py # feed-forward (simple) baseline code (e.g., for image classification)
|
| 78 |
-
βΒ Β βββ lstm.py # LSTM baseline code (e.g., for image classification)
|
| 79 |
-
βΒ Β βββ lstm_*.py # Other baseline code, standalone adjustments for other tasks
|
| 80 |
-
βΒ Β βββ modules.py # Helper modules, including Neuron-level models and the Synapse UNET
|
| 81 |
-
βΒ Β βββ utils.py # Helper functions (e.g., synch decay)
|
| 82 |
-
βΒ Β βββ resnet.py # Wrapper for ResNet featuriser
|
| 83 |
-
βββ utils
|
| 84 |
-
βΒ Β βββ housekeeping.py # Helper functions for keeping things neat
|
| 85 |
-
βΒ Β βββ losses.py # Loss functions for various tasks (mostly with reshaping stuff)
|
| 86 |
-
βΒ Β βββ schedulers.py # Helper wrappers for learning rate schedulers
|
| 87 |
-
βββ checkpoints
|
| 88 |
-
Β Β βββ imagenet, mazes, ... # Checkpoint directories (see google drive link for files)
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
```
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
|
|
|
|
|
|
| 95 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
conda create --name=ctm python=3.12
|
| 97 |
conda activate ctm
|
| 98 |
pip install -r requirements.txt
|
|
|
|
| 99 |
```
|
| 100 |
|
| 101 |
-
If there are
|
| 102 |
-
|
|
|
|
| 103 |
pip uninstall torch
|
| 104 |
pip install torch --index-url https://download.pytorch.org/whl/cu121
|
| 105 |
```
|
| 106 |
|
| 107 |
-
|
| 108 |
-
Each task has its own (set of) training code. See for instance [tasks/image_classification/train.py](tasks/image_classification/train.py). We have set it up like this to ensure ease-of-use as opposed to clinical efficiency. This code is for researchers and we hope to have it shared in a way that fosters collaboration and learning.
|
| 109 |
|
| 110 |
-
|
| 111 |
|
| 112 |
```
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
```
|
| 115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
{
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
}
|
| 125 |
```
|
| 126 |
|
|
|
|
| 127 |
|
| 128 |
-
##
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
| 132 |
conda install -c conda-forge ffmpeg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
```
|
| 134 |
|
|
|
|
| 135 |
|
| 136 |
-
##
|
| 137 |
-
You can download the data and checkpoints from here:
|
| 138 |
-
- checkpoints: https://drive.google.com/drive/folders/1vSg8T7FqP-guMDk1LU7_jZaQtXFP9sZg
|
| 139 |
-
- maze data: https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk
|
| 140 |
|
| 141 |
-
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Continuous Thought Machine
|
| 3 |
+
emoji: π°οΈ
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
sdk_version: "20.10.21"
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 13 |
+
|
| 14 |
# π°οΈ The Continuous Thought Machine
|
| 15 |
|
| 16 |
π [PAPER: Technical Report](https://arxiv.org/abs/2505.05522) | π [Blog](https://sakana.ai/ctm/) | πΉοΈ [Interactive Website](https://pub.sakana.ai/ctm) | βοΈ [Tutorial](examples/01_mnist.ipynb)
|
| 17 |
|
| 18 |
+
## Overview
|
| 19 |
|
| 20 |
+
The **Continuous Thought Machine (CTM)** is a novel neural architecture designed to unfold and leverage neural activity as the underlying mechanism for observation and action. By introducing an internal temporal axis decoupled from input data, CTM enables neurons to process information over time with fine-grained temporal dynamics.
|
| 21 |
|
| 22 |
+
### Key Contributions
|
| 23 |
|
| 24 |
+
1. **Internal Temporal Axis**: Decoupled from input data, allowing neuron activity to unfold independently
|
| 25 |
+
2. **Neuron-Level Temporal Processing**: Each neuron uses unique weight parameters to process a history of incoming signals
|
| 26 |
+
3. **Neural Synchronisation**: Direct latent representation for modulating data and producing outputs, encoding information in the timing of neural activity
|
| 27 |
|
| 28 |
+
The CTM demonstrates strong performance across diverse tasks including ImageNet classification, 2D maze solving, sorting, parity computation, question-answering, and reinforcement learning.
|
| 29 |
|
| 30 |
+
---
|
| 31 |
|
| 32 |
+
## π¬ Energy-Based Halting Experiment
|
| 33 |
|
| 34 |
+
This repository includes an implementation of **Energy-Based Halting**, a mechanism that frames "thinking" as an optimization process where the model dynamically adjusts its internal thought process duration based on sample difficulty.
|
|
|
|
| 35 |
|
| 36 |
+
### Concept
|
| 37 |
|
| 38 |
+
Instead of using heuristic certainty thresholds, we train a learned energy scalar that:
|
| 39 |
+
|
| 40 |
+
- **Minimizes energy** for correct predictions (pushing the system to low-energy equilibrium)
|
| 41 |
+
- **Maximizes energy** for incorrect predictions (pushing away from stable states)
|
| 42 |
+
- **Enables adaptive halting** based on energy thresholds or convergence
|
| 43 |
+
|
| 44 |
+
### Implementation
|
| 45 |
+
|
| 46 |
+
**Modified Components:**
|
| 47 |
+
|
| 48 |
+
- `models/ctm.py`: Added energy projection head that maps synchronization states to scalar energy values
|
| 49 |
+
- `utils/losses.py`: Implemented `EnergyContrastiveLoss` for training the energy function
|
| 50 |
+
- `tasks/image_classification/train_energy.py`: Training script with energy halting
|
| 51 |
+
- `inference_energy.py`: Adaptive inference that halts when energy drops below threshold or stabilizes
|
| 52 |
+
- `configs/energy_experiment.yaml`: Configuration for energy experiments
|
| 53 |
+
|
| 54 |
+
**Training:**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
+
```bash
|
| 57 |
+
# Local training
|
| 58 |
+
pixi run accelerate launch tasks/image_classification/train_energy.py \
|
| 59 |
+
--energy_head_enabled \
|
| 60 |
+
--loss_type energy_contrastive \
|
| 61 |
+
--dataset cifar10
|
| 62 |
+
|
| 63 |
+
# Or with traditional python
|
| 64 |
+
pixi run python tasks/image_classification/train_energy.py \
|
| 65 |
+
--energy_head_enabled \
|
| 66 |
+
--loss_type energy_contrastive
|
| 67 |
```
|
| 68 |
|
| 69 |
+
**Deployment to Hugging Face:**
|
| 70 |
+
See [GUIDE_HF.md](GUIDE_HF.md) for instructions on deploying the training job to Hugging Face Spaces with GPU support.
|
| 71 |
+
|
| 72 |
+
---
|
| 73 |
+
|
| 74 |
+
## π Quick Start
|
| 75 |
+
|
| 76 |
+
### Setup with Pixi (Recommended)
|
| 77 |
+
|
| 78 |
+
We use [Pixi](https://pixi.sh) for dependency management, which handles both Python packages and system dependencies like `ffmpeg`.
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
# Install dependencies
|
| 82 |
+
pixi install
|
| 83 |
|
| 84 |
+
# Run training
|
| 85 |
+
pixi run python tasks/image_classification/train.py
|
| 86 |
```
|
| 87 |
+
|
| 88 |
+
### Alternative: Conda Setup
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
conda create --name=ctm python=3.12
|
| 92 |
conda activate ctm
|
| 93 |
pip install -r requirements.txt
|
| 94 |
+
conda install -c conda-forge ffmpeg
|
| 95 |
```
|
| 96 |
|
| 97 |
+
If there are PyTorch version issues:
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
pip uninstall torch
|
| 101 |
pip install torch --index-url https://download.pytorch.org/whl/cu121
|
| 102 |
```
|
| 103 |
|
| 104 |
+
---
|
|
|
|
| 105 |
|
| 106 |
+
## π Repository Structure
|
| 107 |
|
| 108 |
```
|
| 109 |
+
βββ tasks/
|
| 110 |
+
β βββ image_classification/
|
| 111 |
+
β β βββ train.py # Standard training
|
| 112 |
+
β β βββ train_energy.py # Energy halting training
|
| 113 |
+
β β βββ analysis/run_imagenet_analysis.py
|
| 114 |
+
β β βββ plotting.py
|
| 115 |
+
β βββ mazes/
|
| 116 |
+
β β βββ train.py
|
| 117 |
+
β β βββ analysis/
|
| 118 |
+
β βββ sort/
|
| 119 |
+
β βββ parity/
|
| 120 |
+
β βββ qamnist/
|
| 121 |
+
β βββ rl/
|
| 122 |
+
βββ models/
|
| 123 |
+
β βββ ctm.py # Main CTM model (with energy head support)
|
| 124 |
+
β βββ modules.py # Neuron-level models, Synapse UNET
|
| 125 |
+
β βββ ff.py # Feed-forward baseline
|
| 126 |
+
β βββ lstm.py # LSTM baseline
|
| 127 |
+
βββ utils/
|
| 128 |
+
β βββ losses.py # Loss functions (includes EnergyContrastiveLoss)
|
| 129 |
+
β βββ schedulers.py
|
| 130 |
+
β βββ housekeeping.py
|
| 131 |
+
βββ data/
|
| 132 |
+
β βββ custom_datasets.py
|
| 133 |
+
βββ configs/
|
| 134 |
+
β βββ energy_experiment.yaml # Energy halting hyperparameters
|
| 135 |
+
βββ inference_energy.py # Adaptive energy-based inference
|
| 136 |
+
βββ Dockerfile # For HF Spaces deployment
|
| 137 |
+
βββ GUIDE_HF.md # Hugging Face deployment guide
|
| 138 |
+
βββ checkpoints/ # Model checkpoints
|
| 139 |
```
|
| 140 |
+
|
| 141 |
+
---
|
| 142 |
+
|
| 143 |
+
## π― Model Training
|
| 144 |
+
|
| 145 |
+
Each task has dedicated training code designed for ease-of-use and collaboration. Training scripts include reasonable defaults, with paper-replicating configurations in accompanying script folders.
|
| 146 |
+
|
| 147 |
+
### Image Classification Example
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
# Standard CTM training
|
| 151 |
+
python -m tasks.image_classification.train
|
| 152 |
+
|
| 153 |
+
# Energy halting training
|
| 154 |
+
python -m tasks.image_classification.train_energy \
|
| 155 |
+
--energy_head_enabled \
|
| 156 |
+
--loss_type energy_contrastive
|
| 157 |
```
|
| 158 |
+
|
| 159 |
+
### VSCode Debug Configuration
|
| 160 |
+
|
| 161 |
+
```json
|
| 162 |
{
|
| 163 |
+
"name": "Debug: train image classifier",
|
| 164 |
+
"type": "debugpy",
|
| 165 |
+
"request": "launch",
|
| 166 |
+
"module": "tasks.image_classification.train",
|
| 167 |
+
"console": "integratedTerminal",
|
| 168 |
+
"justMyCode": false
|
| 169 |
}
|
| 170 |
```
|
| 171 |
|
| 172 |
+
---
|
| 173 |
|
| 174 |
+
## π Analysis & Visualization
|
| 175 |
|
| 176 |
+
Analysis and plotting code to replicate paper figures is provided in `tasks/.../analysis/*`.
|
| 177 |
+
|
| 178 |
+
**Note:** `ffmpeg` is required for generating videos:
|
| 179 |
+
|
| 180 |
+
```bash
|
| 181 |
conda install -c conda-forge ffmpeg
|
| 182 |
+
# or with pixi (already included)
|
| 183 |
+
pixi install
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
---
|
| 187 |
+
|
| 188 |
+
## π¦ Checkpoints and Data
|
| 189 |
+
|
| 190 |
+
Download pre-trained checkpoints and datasets:
|
| 191 |
+
|
| 192 |
+
- **Checkpoints**: [Google Drive](https://drive.google.com/drive/folders/1vSg8T7FqP-guMDk1LU7_jZaQtXFP9sZg)
|
| 193 |
+
- **Maze Data**: [Google Drive](https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk)
|
| 194 |
+
|
| 195 |
+
Place checkpoints in the `checkpoints/` folder following the structure `checkpoints/{task}/...`
|
| 196 |
+
|
| 197 |
+
---
|
| 198 |
+
|
| 199 |
+
## π€ Hugging Face Integration
|
| 200 |
+
|
| 201 |
+
This repository includes full support for training on Hugging Face infrastructure:
|
| 202 |
+
|
| 203 |
+
- **Accelerate**: Multi-GPU and mixed precision training
|
| 204 |
+
- **Hub Integration**: Automatic checkpoint uploading
|
| 205 |
+
- **Spaces Deployment**: Run training jobs on GPU Spaces
|
| 206 |
+
|
| 207 |
+
See [GUIDE_HF.md](GUIDE_HF.md) for detailed instructions.
|
| 208 |
+
|
| 209 |
+
---
|
| 210 |
+
|
| 211 |
+
## π Interactive Resources
|
| 212 |
+
|
| 213 |
+
- **[Interactive Website](https://pub.sakana.ai/ctm)**: Maze-solving demo, videos, and visualizations
|
| 214 |
+
- **[Paper](https://arxiv.org/abs/2505.05522)**: Technical details and experiments
|
| 215 |
+
- **[Blog](https://sakana.ai/ctm/)**: High-level overview and insights
|
| 216 |
+
- **[Tutorial Notebook](examples/01_mnist.ipynb)**: Hands-on introduction
|
| 217 |
+
|
| 218 |
+
---
|
| 219 |
+
|
| 220 |
+
## π Citation
|
| 221 |
+
|
| 222 |
+
If you use this code or build upon CTM in your work, please cite:
|
| 223 |
+
|
| 224 |
+
```bibtex
|
| 225 |
+
@article{ctm2025,
|
| 226 |
+
title={The Continuous Thought Machine},
|
| 227 |
+
author={...},
|
| 228 |
+
journal={arXiv preprint arXiv:2505.05522},
|
| 229 |
+
year={2025}
|
| 230 |
+
}
|
| 231 |
```
|
| 232 |
|
| 233 |
+
---
|
| 234 |
|
| 235 |
+
## π License
|
|
|
|
|
|
|
|
|
|
| 236 |
|
| 237 |
+
This project is released under the MIT License. See LICENSE file for details.
|