Ciaran Regan commited on
Commit
f3f0e2d
·
1 Parent(s): 7ae2b6b

Add Hugging Face Support (#23)

Browse files

* Make CTM hf compatible

* Add notebook to demo HF usage

* Add changes to all notebooks for easier local runs.

.gitignore CHANGED
@@ -19,4 +19,7 @@ examples/*
19
  !examples/02_inference.ipynb
20
  !examples/03_mazes.ipynb
21
  !examples/04_parity.ipynb
 
 
22
  checkpoints
 
 
19
  !examples/02_inference.ipynb
20
  !examples/03_mazes.ipynb
21
  !examples/04_parity.ipynb
22
+ !examples/05_huggingface.ipynb
23
+ !examples/goldfish.jpg
24
  checkpoints
25
+ utils/hugging_face/
examples/01_mnist.ipynb CHANGED
@@ -749,7 +749,7 @@
749
  },
750
  {
751
  "cell_type": "code",
752
- "execution_count": 42,
753
  "id": "b3fbae96",
754
  "metadata": {},
755
  "outputs": [],
@@ -789,7 +789,7 @@
789
  " [['certainty'] * 8] + \\\n",
790
  " [[f'trace_{ti}'] * 8 for ti in range(n_neurons_to_visualise)]\n",
791
  "\n",
792
- " for stepi in range(n_steps):\n",
793
  " fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))\n",
794
  " probs = softmax(these_predictions[:, stepi])\n",
795
  " colors = [('g' if i == this_target else 'b') for i in range(len(probs))]\n",
@@ -940,7 +940,7 @@
940
  ],
941
  "metadata": {
942
  "kernelspec": {
943
- "display_name": "base",
944
  "language": "python",
945
  "name": "python3"
946
  },
@@ -954,7 +954,7 @@
954
  "name": "python",
955
  "nbconvert_exporter": "python",
956
  "pygments_lexer": "ipython3",
957
- "version": "3.12.9"
958
  }
959
  },
960
  "nbformat": 4,
 
749
  },
750
  {
751
  "cell_type": "code",
752
+ "execution_count": null,
753
  "id": "b3fbae96",
754
  "metadata": {},
755
  "outputs": [],
 
789
  " [['certainty'] * 8] + \\\n",
790
  " [[f'trace_{ti}'] * 8 for ti in range(n_neurons_to_visualise)]\n",
791
  "\n",
792
+ " for stepi in tqdm(range(n_steps), desc=\"Processing steps\", unit=\"step\"):\n",
793
  " fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))\n",
794
  " probs = softmax(these_predictions[:, stepi])\n",
795
  " colors = [('g' if i == this_target else 'b') for i in range(len(probs))]\n",
 
940
  ],
941
  "metadata": {
942
  "kernelspec": {
943
+ "display_name": "Python 3",
944
  "language": "python",
945
  "name": "python3"
946
  },
 
954
  "name": "python",
955
  "nbconvert_exporter": "python",
956
  "pygments_lexer": "ipython3",
957
+ "version": "3.12.10"
958
  }
959
  },
960
  "nbformat": 4,
examples/02_inference.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
examples/03_mazes.ipynb CHANGED
@@ -67,6 +67,31 @@
67
  "In addition to installing some dependencies, we also clone the CTM repo (assuming this tutorial is being ran in Colab), so that we can access the base CTM model."
68
  ]
69
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  {
71
  "cell_type": "code",
72
  "execution_count": null,
@@ -108,7 +133,7 @@
108
  },
109
  {
110
  "cell_type": "code",
111
- "execution_count": 1,
112
  "id": "24ffe416",
113
  "metadata": {},
114
  "outputs": [
@@ -122,9 +147,6 @@
122
  }
123
  ],
124
  "source": [
125
- "import sys\n",
126
- "sys.path.append(\"./continuous-thought-machines\")\n",
127
- "\n",
128
  "import os\n",
129
  "import torch\n",
130
  "import torch.nn as nn\n",
 
67
  "In addition to installing some dependencies, we also clone the CTM repo (assuming this tutorial is being ran in Colab), so that we can access the base CTM model."
68
  ]
69
  },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": null,
73
+ "id": "537dd917",
74
+ "metadata": {},
75
+ "outputs": [],
76
+ "source": [
77
+ "USE_COLAB = False"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "import sys\n",
87
+ "\n",
88
+ "if USE_COLAB:\n",
89
+ " !git clone https://github.com/SakanaAI/continuous-thought-machines.git\n",
90
+ " sys.path.append(\"./continuous-thought-machines\")\n",
91
+ "else:\n",
92
+ " sys.path.append(\"..\")"
93
+ ]
94
+ },
95
  {
96
  "cell_type": "code",
97
  "execution_count": null,
 
133
  },
134
  {
135
  "cell_type": "code",
136
+ "execution_count": null,
137
  "id": "24ffe416",
138
  "metadata": {},
139
  "outputs": [
 
147
  }
148
  ],
149
  "source": [
 
 
 
150
  "import os\n",
151
  "import torch\n",
152
  "import torch.nn as nn\n",
examples/04_parity.ipynb CHANGED
@@ -52,6 +52,32 @@
52
  "In addition to installing some dependencies, we also clone the CTM repo (assuming this tutorial is being run in Colab), so that we can access the base CTM model."
53
  ]
54
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  {
56
  "cell_type": "code",
57
  "execution_count": null,
@@ -60,8 +86,7 @@
60
  "outputs": [],
61
  "source": [
62
  "!pip install gdown\n",
63
- "!pip install mediapy\n",
64
- "!git clone https://github.com/SakanaAI/continuous-thought-machines.git\n"
65
  ]
66
  },
67
  {
@@ -74,14 +99,11 @@
74
  },
75
  {
76
  "cell_type": "code",
77
- "execution_count": 17,
78
  "id": "24ffe416",
79
  "metadata": {},
80
  "outputs": [],
81
  "source": [
82
- "import sys\n",
83
- "sys.path.append(\"./continuous-thought-machines\")\n",
84
- "\n",
85
  "import os\n",
86
  "import torch\n",
87
  "import torch.nn as nn\n",
 
52
  "In addition to installing some dependencies, we also clone the CTM repo (assuming this tutorial is being run in Colab), so that we can access the base CTM model."
53
  ]
54
  },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "id": "5c06d1e5",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "USE_COLAB = False"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "id": "30ab5f0d",
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "import sys\n",
73
+ "\n",
74
+ "if USE_COLAB:\n",
75
+ " !git clone https://github.com/SakanaAI/continuous-thought-machines.git\n",
76
+ " sys.path.append(\"./continuous-thought-machines\")\n",
77
+ "else:\n",
78
+ " sys.path.append(\"..\")"
79
+ ]
80
+ },
81
  {
82
  "cell_type": "code",
83
  "execution_count": null,
 
86
  "outputs": [],
87
  "source": [
88
  "!pip install gdown\n",
89
+ "!pip install mediapy"
 
90
  ]
91
  },
92
  {
 
99
  },
100
  {
101
  "cell_type": "code",
102
+ "execution_count": null,
103
  "id": "24ffe416",
104
  "metadata": {},
105
  "outputs": [],
106
  "source": [
 
 
 
107
  "import os\n",
108
  "import torch\n",
109
  "import torch.nn as nn\n",
examples/05_huggingface.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models/ctm.py CHANGED
@@ -2,6 +2,7 @@ import torch.nn as nn
2
  import torch
3
  import numpy as np
4
  import math
 
5
 
6
  from models.modules import ParityBackbone, SynapseUNET, Squeeze, SuperLinear, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
7
  from models.resnet import prepare_resnet_backbone
@@ -13,7 +14,7 @@ from models.constants import (
13
  VALID_POSITIONAL_EMBEDDING_TYPES
14
  )
15
 
16
- class ContinuousThoughtMachine(nn.Module):
17
  """
18
  Continuous Thought Machine (CTM).
19
 
@@ -149,6 +150,53 @@ class ContinuousThoughtMachine(nn.Module):
149
  # --- Output Procesing ---
150
  self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  # --- Core CTM Methods ---
153
 
154
  def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, r, synch_type):
@@ -553,3 +601,4 @@ class ContinuousThoughtMachine(nn.Module):
553
  if track:
554
  return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
555
  return predictions, certainties, synchronisation_out
 
 
2
  import torch
3
  import numpy as np
4
  import math
5
+ from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
6
 
7
  from models.modules import ParityBackbone, SynapseUNET, Squeeze, SuperLinear, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
8
  from models.resnet import prepare_resnet_backbone
 
14
  VALID_POSITIONAL_EMBEDDING_TYPES
15
  )
16
 
17
+ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
18
  """
19
  Continuous Thought Machine (CTM).
20
 
 
150
  # --- Output Procesing ---
151
  self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
152
 
153
+ @classmethod
154
+ def _from_pretrained(
155
+ cls,
156
+ *,
157
+ model_id: str,
158
+ revision=None,
159
+ cache_dir=None,
160
+ force_download=False,
161
+ proxies=None,
162
+ resume_download=None,
163
+ local_files_only=False,
164
+ token=None,
165
+ map_location="cpu",
166
+ strict=False,
167
+ **model_kwargs,
168
+ ):
169
+ """Override to handle lazy weights initialization."""
170
+ model = cls(**model_kwargs).to(map_location)
171
+
172
+ # The CTM contains Lazy modules, so we must run a dummy forward pass to initialize them
173
+ if "imagenet" in model_id:
174
+ dummy_input = torch.randn(1, 3, 224, 224, device=map_location)
175
+ elif "maze-large" in model_id:
176
+ dummy_input = torch.randn(1, 3, 99, 99, device=map_location)
177
+ else:
178
+ raise NotImplementedError
179
+
180
+ with torch.no_grad():
181
+ _ = model(dummy_input)
182
+
183
+ model_file = hf_hub_download(
184
+ repo_id=model_id,
185
+ filename="model.safetensors",
186
+ revision=revision,
187
+ cache_dir=cache_dir,
188
+ force_download=force_download,
189
+ proxies=proxies,
190
+ resume_download=resume_download,
191
+ token=token,
192
+ local_files_only=local_files_only,
193
+ )
194
+ from safetensors.torch import load_model as load_model_as_safetensor
195
+ load_model_as_safetensor(model, model_file, strict=strict, device=map_location)
196
+
197
+ model.eval()
198
+ return model
199
+
200
  # --- Core CTM Methods ---
201
 
202
  def compute_synchronisation(self, activated_state, decay_alpha, decay_beta, r, synch_type):
 
601
  if track:
602
  return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
603
  return predictions, certainties, synchronisation_out
604
+
requirements.txt CHANGED
@@ -12,4 +12,6 @@ python-dotenv
12
  gymnasium
13
  minigrid
14
  datasets
15
- autoclip
 
 
 
12
  gymnasium
13
  minigrid
14
  datasets
15
+ autoclip
16
+ huggingface_hub
17
+ safetensors
tasks/parity/plotting.py CHANGED
@@ -13,6 +13,7 @@ import imageio.v2 as imageio
13
  from PIL import Image
14
  import math
15
  import re
 
16
  sns.set_style('darkgrid')
17
  mpl.use('Agg')
18
 
@@ -43,7 +44,7 @@ def make_parity_gif(predictions, certainties, targets, pre_activations, post_act
43
  [['certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty']] + \
44
  [[f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}'] for ti in range(n_neurons_to_visualise)]
45
 
46
- for stepi in range(n_steps):
47
  fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))
48
 
49
  # Plot predictions
 
13
  from PIL import Image
14
  import math
15
  import re
16
+ from tqdm import tqdm
17
  sns.set_style('darkgrid')
18
  mpl.use('Agg')
19
 
 
44
  [['certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty']] + \
45
  [[f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}'] for ti in range(n_neurons_to_visualise)]
46
 
47
+ for stepi in tqdm(range(n_steps), desc="Processing steps", unit="step"):
48
  fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale))
49
 
50
  # Plot predictions