Uday committed on
Commit
80dd9c4
·
1 Parent(s): 9f8cba2

Thought Depth via Energy Minimization: halting with a learned Energy scalar.

Browse files
.gitignore CHANGED
@@ -26,3 +26,4 @@ utils/hugging_face/
26
  # pixi environments
27
  .pixi/*
28
  !.pixi/config.toml
 
 
26
  # pixi environments
27
  .pixi/*
28
  !.pixi/config.toml
29
+ changes.md
configs/energy_experiment.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Energy Halting Experiment Config
2
+
3
+ # Model Architecture
4
+ model: ctm
5
+ d_model: 512
6
+ d_input: 128
7
+ heads: 4
8
+ iterations: 50
9
+ dropout: 0.0
10
+ backbone_type: resnet18-4
11
+ positional_embedding_type: none
12
+
13
+ # CTM Specifics
14
+ synapse_depth: 4
15
+ n_synch_out: 512
16
+ n_synch_action: 512
17
+ neuron_select_type: random-pairing
18
+ n_random_pairing_self: 0
19
+ memory_length: 25
20
+ deep_memory: true
21
+ memory_hidden_dims: 4
22
+ do_normalisation: false
23
+
24
+ # Energy Head
25
+ energy_head:
26
+ enabled: true
27
+ d_hidden: 64
28
+
29
+ # Training
30
+ batch_size: 32
31
+ batch_size_test: 32
32
+ lr: 1.0e-3
33
+ training_iterations: 100001
34
+ warmup_steps: 5000
35
+ use_scheduler: true
36
+ scheduler_type: cosine
37
+ weight_decay: 0.0
38
+ gradient_clipping: -1
39
+ do_compile: false
40
+ num_workers_train: 4
41
+
42
+ # Loss
43
+ loss:
44
+ type: energy_contrastive
45
+ margin: 5.0
46
+ energy_scale: 0.5
47
+
48
+ # Inference
49
+ inference:
50
+ energy_threshold: 0.5
51
+ delta_threshold: 0.01
52
+
53
+ # Housekeeping
54
+ dataset: cifar10
55
+ data_root: data/
56
+ save_every: 1000
57
+ track_every: 1000
58
+ seed: 412
59
+ log_dir: logs/energy_experiment
60
+ device: [-1] # Auto-detect
inference_energy.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import yaml
3
+ import argparse
4
+ from models.ctm import ContinuousThoughtMachine
5
+
6
class EnergyInference:
    """Load a trained CTM with an energy head and run energy-based adaptive halting.

    The model is rebuilt from a flat YAML config and its weights are restored
    from a checkpoint holding a ``'model_state_dict'`` entry.
    """

    # Fallback class counts used when the config has no explicit ``out_dims``.
    # cifar10/cifar100 follow the dataset definitions; the imagenet value
    # assumes the standard ImageNet-2012 label set -- TODO confirm.
    _DATASET_NUM_CLASSES = {'cifar10': 10, 'cifar100': 100, 'imagenet': 1000}

    def __init__(self, model_path, config_path, device='cpu'):
        """Build the model described by ``config_path`` and load ``model_path``.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load`` that
                contains a ``'model_state_dict'`` key.
            config_path: Path to the YAML experiment config (flat model keys).
            device: Device the model and all inputs are moved to.

        Raises:
            KeyError: If a required config key is missing, or ``out_dims`` is
                absent and cannot be derived from the ``dataset`` key.
        """
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        self.device = device

        # The config is flat: model hyper-parameters live at the top level.
        model_config = self.config

        # ``energy_head:`` may be missing or null in YAML; treat both as empty.
        energy_cfg = model_config.get('energy_head') or {}

        # The experiment config does not carry ``out_dims`` (the training
        # script derives it from the dataset), so fall back to the dataset.
        out_dims = model_config.get('out_dims')
        if out_dims is None:
            dataset = model_config.get('dataset')
            if dataset not in self._DATASET_NUM_CLASSES:
                raise KeyError(
                    "Config has no 'out_dims' and it cannot be inferred from "
                    f"dataset={dataset!r}"
                )
            out_dims = self._DATASET_NUM_CLASSES[dataset]

        self.model = ContinuousThoughtMachine(
            iterations=model_config['iterations'],
            d_model=model_config['d_model'],
            d_input=model_config['d_input'],
            heads=model_config['heads'],
            n_synch_out=model_config['n_synch_out'],
            n_synch_action=model_config['n_synch_action'],
            synapse_depth=model_config['synapse_depth'],
            memory_length=model_config['memory_length'],
            deep_nlms=model_config['deep_memory'],
            memory_hidden_dims=model_config['memory_hidden_dims'],
            do_layernorm_nlm=model_config['do_normalisation'],
            backbone_type=model_config['backbone_type'],
            positional_embedding_type=model_config['positional_embedding_type'],
            out_dims=out_dims,
            prediction_reshaper=model_config.get('prediction_reshaper', [-1]),
            dropout=model_config.get('dropout', 0.0),
            neuron_select_type=model_config.get('neuron_select_type', 'random-pairing'),
            n_random_pairing_self=model_config.get('n_random_pairing_self', 0),
            energy_head_enabled=energy_cfg.get('enabled', False),
            energy_hidden_dim=energy_cfg.get('d_hidden', 64),
        ).to(self.device)

        # NOTE(review): torch.load unpickles arbitrary objects unless
        # weights_only=True is used; only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

    def run_adaptive(self, inputs, energy_threshold=None, delta_threshold=None):
        """Run the CTM and report where each sample would have halted.

        A sample halts at the first internal tick ``t`` where either its
        energy is below ``energy_threshold`` or (for ``t > 0``) the absolute
        energy change from the previous tick is below ``delta_threshold``.
        Samples that never satisfy either condition use the last tick.

        The model's forward pass always runs every tick; halting is determined
        afterwards from the recorded energies, so this does not save compute.

        Args:
            inputs: Batch of model inputs; moved to ``self.device``.
            energy_threshold: Absolute energy threshold. ``None`` uses
                ``inference.energy_threshold`` from the config, else 1.0.
            delta_threshold: Convergence threshold on the per-tick energy
                change. ``None`` uses ``inference.delta_threshold`` from the
                config, else 0.01.

        Returns:
            Tuple ``(final_predictions, final_steps)`` of int64 tensors of
            shape ``[B]``: the argmax class at the halting tick, and the
            1-based number of ticks used.

        Raises:
            RuntimeError: If the model was built without an energy head, in
                which case its third output is not an energy trace.
        """
        if not getattr(self.model, 'energy_head_enabled', False):
            raise RuntimeError(
                "run_adaptive requires a model built with energy_head_enabled=True"
            )

        inference_cfg = (getattr(self, 'config', None) or {}).get('inference') or {}
        if energy_threshold is None:
            energy_threshold = inference_cfg.get('energy_threshold', 1.0)
        if delta_threshold is None:
            delta_threshold = inference_cfg.get('delta_threshold', 0.01)

        inputs = inputs.to(self.device)

        with torch.no_grad():
            # Full forward pass; predictions is [B, C, T], energies is [B, 1, T].
            predictions, certainties, energies = self.model(inputs)
            energies = energies.squeeze(1)  # [B, T]

            num_steps = energies.shape[1]

            # 1. Absolute energy threshold, per sample and tick.
            is_low_energy = energies < energy_threshold

            # 2. Convergence: only defined from the second tick onwards.
            is_converged = torch.zeros_like(is_low_energy)
            is_converged[:, 1:] = (energies[:, 1:] - energies[:, :-1]).abs() < delta_threshold

            halts = is_low_energy | is_converged  # [B, T]

            # Index of the first halting tick; samples that never halt fall
            # back to the last tick (index T - 1, i.e. T steps used).
            tick_index = torch.arange(num_steps, device=energies.device)
            last_tick = torch.full_like(tick_index, num_steps - 1)
            halt_index = torch.where(halts, tick_index, last_tick).min(dim=1).values

            # Class prediction at every tick, then pick the halting tick.
            per_tick_class = predictions.argmax(dim=1)  # [B, T]
            final_predictions = per_tick_class.gather(1, halt_index.unsqueeze(1)).squeeze(1)
            final_steps = halt_index + 1

        return final_predictions, final_steps
99
+
100
if __name__ == "__main__":
    # Command-line entry point: both paths are mandatory; the parsed values
    # are not used further here, the script only prints a usage hint.
    cli = argparse.ArgumentParser()
    for required_flag in ('--model_path', '--config_path'):
        cli.add_argument(required_flag, type=str, required=True)
    args = cli.parse_args()

    # Example usage (requires data)
    print("Inference script created. Use EnergyInference class to run adaptive inference.")
models/ctm.py CHANGED
@@ -98,6 +98,8 @@ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
98
  dropout_nlm=None,
99
  neuron_select_type='random-pairing',
100
  n_random_pairing_self=0,
 
 
101
  ):
102
  super(ContinuousThoughtMachine, self).__init__()
103
 
@@ -115,6 +117,8 @@ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
115
  self.neuron_select_type = neuron_select_type
116
  self.memory_length = memory_length
117
  dropout_nlm = dropout if dropout_nlm is None else dropout_nlm
 
 
118
 
119
  # --- Assertions ---
120
  self.verify_args()
@@ -149,6 +153,14 @@ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
149
 
150
  # --- Output Procesing ---
151
  self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
 
 
 
 
 
 
 
 
152
 
153
  @classmethod
154
  def _from_pretrained(
@@ -460,13 +472,23 @@ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
460
  neuron_indices_left = neuron_indices_right = torch.arange(d_model-n_synch, d_model)
461
 
462
  elif self.neuron_select_type=='random':
463
- neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
464
- neuron_indices_right = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
465
 
466
  elif self.neuron_select_type=='random-pairing':
467
  assert n_synch > n_random_pairing_self, f"Need at least {n_random_pairing_self} pairs for {self.neuron_select_type}"
468
- neuron_indices_left = torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch))
469
- neuron_indices_right = torch.concatenate((neuron_indices_left[:n_random_pairing_self], torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch-n_random_pairing_self))))
 
 
 
 
 
 
 
 
 
 
470
 
471
  device = self.start_activated_state.device
472
  return neuron_indices_left.to(device), neuron_indices_right.to(device)
@@ -533,7 +555,9 @@ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
533
  post_activations_tracking = []
534
  synch_out_tracking = []
535
  synch_action_tracking = []
 
536
  attention_tracking = []
 
537
 
538
  # --- Featurise Input Data ---
539
  kv = self.compute_features(x)
@@ -544,7 +568,9 @@ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
544
 
545
  # --- Prepare Storage for Outputs per Iteration ---
546
  predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=torch.float32)
 
547
  certainties = torch.empty(B, 2, self.iterations, device=device, dtype=torch.float32)
 
548
 
549
  # --- Initialise Recurrent Synch Values ---
550
  decay_alpha_action, decay_beta_action = None, None
@@ -586,8 +612,13 @@ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
586
  current_prediction = self.output_projector(synchronisation_out)
587
  current_certainty = self.compute_certainty(current_prediction)
588
 
 
589
  predictions[..., stepi] = current_prediction
590
  certainties[..., stepi] = current_certainty
 
 
 
 
591
 
592
  # --- Tracking ---
593
  if track:
@@ -600,5 +631,11 @@ class ContinuousThoughtMachine(nn.Module, PyTorchModelHubMixin):
600
  # --- Return Values ---
601
  if track:
602
  return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
 
 
 
 
 
 
603
  return predictions, certainties, synchronisation_out
604
 
 
98
  dropout_nlm=None,
99
  neuron_select_type='random-pairing',
100
  n_random_pairing_self=0,
101
+ energy_head_enabled=False,
102
+ energy_hidden_dim=64,
103
  ):
104
  super(ContinuousThoughtMachine, self).__init__()
105
 
 
117
  self.neuron_select_type = neuron_select_type
118
  self.memory_length = memory_length
119
  dropout_nlm = dropout if dropout_nlm is None else dropout_nlm
120
+ self.energy_head_enabled = energy_head_enabled
121
+ self.energy_hidden_dim = energy_hidden_dim
122
 
123
  # --- Assertions ---
124
  self.verify_args()
 
153
 
154
  # --- Output Procesing ---
155
  self.output_projector = nn.Sequential(nn.LazyLinear(self.out_dims))
156
+
157
+ # --- Energy Projector ---
158
+ if self.energy_head_enabled:
159
+ self.energy_proj = nn.Sequential(
160
+ nn.LazyLinear(self.energy_hidden_dim),
161
+ nn.SiLU(),
162
+ nn.Linear(self.energy_hidden_dim, 1)
163
+ )
164
 
165
  @classmethod
166
  def _from_pretrained(
 
472
  neuron_indices_left = neuron_indices_right = torch.arange(d_model-n_synch, d_model)
473
 
474
  elif self.neuron_select_type=='random':
475
+ neuron_indices_left = torch.randperm(d_model)[:n_synch]
476
+ neuron_indices_right = torch.randperm(d_model)[:n_synch]
477
 
478
  elif self.neuron_select_type=='random-pairing':
479
  assert n_synch > n_random_pairing_self, f"Need at least {n_random_pairing_self} pairs for {self.neuron_select_type}"
480
+ neuron_indices_left = torch.randperm(d_model)[:n_synch]
481
+ # For right, we need to concatenate self-pairs and random pairs
482
+ # This logic mimics the original numpy logic but using torch
483
+ # Original: neuron_indices_right = torch.concatenate((neuron_indices_left[:n_random_pairing_self], torch.from_numpy(np.random.choice(np.arange(d_model), size=n_synch-n_random_pairing_self))))
484
+
485
+ # Note: the original np.random.choice(np.arange(d_model), size=...) call sampled WITH replacement,
486
+ # because NumPy's default is replace=True, so duplicate indices were possible.
487
+ # torch.randperm samples WITHOUT replacement: indices within each draw are unique,
488
+ # and at most d_model indices can be drawn, so n_synch must not exceed d_model.
489
+
490
+ random_part = torch.randperm(d_model)[:n_synch-n_random_pairing_self]
491
+ neuron_indices_right = torch.cat((neuron_indices_left[:n_random_pairing_self], random_part))
492
 
493
  device = self.start_activated_state.device
494
  return neuron_indices_left.to(device), neuron_indices_right.to(device)
 
555
  post_activations_tracking = []
556
  synch_out_tracking = []
557
  synch_action_tracking = []
558
+ synch_action_tracking = []
559
  attention_tracking = []
560
+ energy_tracking = []
561
 
562
  # --- Featurise Input Data ---
563
  kv = self.compute_features(x)
 
568
 
569
  # --- Prepare Storage for Outputs per Iteration ---
570
  predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=torch.float32)
571
+ predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=torch.float32)
572
  certainties = torch.empty(B, 2, self.iterations, device=device, dtype=torch.float32)
573
+ energies = torch.empty(B, 1, self.iterations, device=device, dtype=torch.float32) if self.energy_head_enabled else None
574
 
575
  # --- Initialise Recurrent Synch Values ---
576
  decay_alpha_action, decay_beta_action = None, None
 
612
  current_prediction = self.output_projector(synchronisation_out)
613
  current_certainty = self.compute_certainty(current_prediction)
614
 
615
+ predictions[..., stepi] = current_prediction
616
  predictions[..., stepi] = current_prediction
617
  certainties[..., stepi] = current_certainty
618
+
619
+ if self.energy_head_enabled:
620
+ current_energy = self.energy_proj(synchronisation_out)
621
+ energies[..., stepi] = current_energy
622
 
623
  # --- Tracking ---
624
  if track:
 
631
  # --- Return Values ---
632
  if track:
633
  return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
634
+ if track:
635
+ return predictions, certainties, (np.array(synch_out_tracking), np.array(synch_action_tracking)), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
636
+
637
+ if self.energy_head_enabled:
638
+ return predictions, certainties, energies
639
+
640
  return predictions, certainties, synchronisation_out
641
 
pixi.lock CHANGED
@@ -203,7 +203,7 @@ environments:
203
  - conda: https://conda.anaconda.org/conda-forge/noarch/networkx-3.5-pyhe01879c_0.conda
204
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/nlohmann_json-3.12.0-h248ca61_1.conda
205
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numba-0.62.1-py312hd24c766_0.conda
206
- - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.3.5-py312h85ea64e_0.conda
207
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/opencv-4.12.0-qt6_py312h5b798a3_607.conda
208
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openexr-3.4.4-h3c4c831_0.conda
209
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openh264-2.6.0-hb5b2745_0.conda
@@ -2957,17 +2957,16 @@ packages:
2957
  - pkg:pypi/numba?source=hash-mapping
2958
  size: 5691441
2959
  timestamp: 1759165626923
2960
- - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.3.5-py312h85ea64e_0.conda
2961
- sha256: 095dc7f15d2f8d9970a6a4e9d4a1980989a4209cd34c2b756fbd40e71f6990cc
2962
- md5: ee4c185ae9c1edeb8e8cd26273c90a9a
2963
  depends:
2964
- - python
2965
- - __osx >=11.0
2966
- - python 3.12.* *_cpython
2967
- - libcxx >=19
2968
- - libcblas >=3.9.0,<4.0a0
2969
  - libblas >=3.9.0,<4.0a0
 
 
2970
  - liblapack >=3.9.0,<4.0a0
 
 
2971
  - python_abi 3.12.* *_cp312
2972
  constrains:
2973
  - numpy-base <0a0
@@ -2975,8 +2974,8 @@ packages:
2975
  license_family: BSD
2976
  purls:
2977
  - pkg:pypi/numpy?source=hash-mapping
2978
- size: 6704341
2979
- timestamp: 1763350985482
2980
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/opencv-4.12.0-qt6_py312h5b798a3_607.conda
2981
  sha256: 71b1ce5a0073c59d766a94ec80c6f248bba880cb4ea7763e203595a8fdab4fb5
2982
  md5: 6ab56fafd591c51e28c0d1ed3887f8a7
 
203
  - conda: https://conda.anaconda.org/conda-forge/noarch/networkx-3.5-pyhe01879c_0.conda
204
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/nlohmann_json-3.12.0-h248ca61_1.conda
205
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numba-0.62.1-py312hd24c766_0.conda
206
+ - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-1.26.4-py312h8442bc7_0.conda
207
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/opencv-4.12.0-qt6_py312h5b798a3_607.conda
208
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openexr-3.4.4-h3c4c831_0.conda
209
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openh264-2.6.0-hb5b2745_0.conda
 
2957
  - pkg:pypi/numba?source=hash-mapping
2958
  size: 5691441
2959
  timestamp: 1759165626923
2960
+ - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-1.26.4-py312h8442bc7_0.conda
2961
+ sha256: c8841d6d6f61fd70ca80682efbab6bdb8606dc77c68d8acabfbd7c222054f518
2962
+ md5: d83fc83d589e2625a3451c9a7e21047c
2963
  depends:
 
 
 
 
 
2964
  - libblas >=3.9.0,<4.0a0
2965
+ - libcblas >=3.9.0,<4.0a0
2966
+ - libcxx >=16
2967
  - liblapack >=3.9.0,<4.0a0
2968
+ - python >=3.12,<3.13.0a0
2969
+ - python >=3.12,<3.13.0a0 *_cpython
2970
  - python_abi 3.12.* *_cp312
2971
  constrains:
2972
  - numpy-base <0a0
 
2974
  license_family: BSD
2975
  purls:
2976
  - pkg:pypi/numpy?source=hash-mapping
2977
+ size: 6073136
2978
+ timestamp: 1707226249608
2979
  - conda: https://conda.anaconda.org/conda-forge/osx-arm64/opencv-4.12.0-qt6_py312h5b798a3_607.conda
2980
  sha256: 71b1ce5a0073c59d766a94ec80c6f248bba880cb4ea7763e203595a8fdab4fb5
2981
  md5: 6ab56fafd591c51e28c0d1ed3887f8a7
pixi.toml CHANGED
@@ -9,7 +9,7 @@ version = "0.1.0"
9
  python = "3.12.*"
10
  pytorch = "*"
11
  torchvision = "*"
12
- numpy = "*"
13
  matplotlib = "*"
14
  seaborn = "*"
15
  tqdm = "*"
 
9
  python = "3.12.*"
10
  pytorch = "*"
11
  torchvision = "*"
12
+ numpy = "<2.0"
13
  matplotlib = "*"
14
  seaborn = "*"
15
  tqdm = "*"
tasks/image_classification/train_energy.py ADDED
@@ -0,0 +1,725 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import seaborn as sns
8
+ sns.set_style('darkgrid')
9
+ import torch
10
+ if torch.cuda.is_available():
11
+ # Use faster, reduced-precision float32 matmul kernels (e.g. TF32) when running on CUDA
12
+ torch.set_float32_matmul_precision('high')
13
+ import torch.nn as nn
14
+ from tqdm.auto import tqdm
15
+
16
+ import sys
17
+ from pathlib import Path
18
+
19
+ # Add project root to sys.path to allow imports from top-level packages
20
+ project_root = str(Path(__file__).resolve().parents[2])
21
+ if project_root not in sys.path:
22
+ sys.path.append(project_root)
23
+
24
+ from data.custom_datasets import ImageNet
25
+ from torchvision import datasets
26
+ from torchvision import transforms
27
+ from tasks.image_classification.imagenet_classes import IMAGENET2012_CLASSES
28
+ from models.ctm import ContinuousThoughtMachine
29
+ from models.lstm import LSTMBaseline
30
+ from models.ff import FFBaseline
31
+ from tasks.image_classification.plotting import plot_neural_dynamics, make_classification_gif
32
+ from utils.housekeeping import set_seed, zip_python_code
33
+ from utils.losses import image_classification_loss, EnergyContrastiveLoss # Used by CTM, LSTM
34
+ from utils.schedulers import WarmupCosineAnnealingLR, WarmupMultiStepLR, warmup
35
+
36
+ from autoclip.torch import QuantileClip
37
+
38
+ import gc
39
+ import torchvision
40
+ torchvision.disable_beta_transforms_warning()
41
+
42
+
43
+ import warnings
44
+ warnings.filterwarnings("ignore", message="using precomputed metric; inverse_transform will be unavailable")
45
+ warnings.filterwarnings('ignore', message='divide by zero encountered in power', category=RuntimeWarning)
46
+ warnings.filterwarnings(
47
+ "ignore",
48
+ "Corrupt EXIF data",
49
+ UserWarning,
50
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
51
+ )
52
+ warnings.filterwarnings(
53
+ "ignore",
54
+ "UserWarning: Metadata Warning",
55
+ UserWarning,
56
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
57
+ )
58
+ warnings.filterwarnings(
59
+ "ignore",
60
+ "UserWarning: Truncated File Read",
61
+ UserWarning,
62
+ r"^PIL\.TiffImagePlugin$" # Using a regular expression to match the module.
63
+ )
64
+
65
+
66
def parse_args(argv=None):
    """Parse the command-line options of the energy-halting training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) reads
            ``sys.argv[1:]``, matching the previous zero-argument behaviour;
            passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # Paired --flag / --no-flag boolean switches.
    on_off = argparse.BooleanOptionalAction

    # Model Selection
    parser.add_argument('--model', type=str, default='ctm', choices=['ctm', 'lstm', 'ff'], help='Model type to train.')

    # Model Architecture
    # Common
    parser.add_argument('--d_model', type=int, default=512, help='Dimension of the model.')
    parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
    parser.add_argument('--backbone_type', type=str, default='resnet18-4', help='Type of backbone featureiser.')
    # CTM / LSTM specific
    parser.add_argument('--d_input', type=int, default=128, help='Dimension of the input (CTM, LSTM).')
    parser.add_argument('--heads', type=int, default=4, help='Number of attention heads (CTM, LSTM).')
    parser.add_argument('--iterations', type=int, default=75, help='Number of internal ticks (CTM, LSTM).')
    parser.add_argument('--positional_embedding_type', type=str, default='none', help='Type of positional embedding (CTM, LSTM).',
                        choices=['none',
                                 'learnable-fourier',
                                 'multi-learnable-fourier',
                                 'custom-rotational'])
    # CTM specific
    parser.add_argument('--synapse_depth', type=int, default=4, help='Depth of U-NET model for synapse. 1=linear, no unet (CTM only).')
    parser.add_argument('--n_synch_out', type=int, default=512, help='Number of neurons to use for output synch (CTM only).')
    parser.add_argument('--n_synch_action', type=int, default=512, help='Number of neurons to use for observation/action synch (CTM only).')
    parser.add_argument('--neuron_select_type', type=str, default='random-pairing', help='Protocol for selecting neuron subset (CTM only).')
    parser.add_argument('--n_random_pairing_self', type=int, default=0, help='Number of neurons paired self-to-self for synch (CTM only).')
    parser.add_argument('--memory_length', type=int, default=25, help='Length of the pre-activation history for NLMS (CTM only).')
    parser.add_argument('--deep_memory', action=on_off, default=True, help='Use deep memory (CTM only).')
    parser.add_argument('--memory_hidden_dims', type=int, default=4, help='Hidden dimensions of the memory if using deep memory (CTM only).')
    parser.add_argument('--dropout_nlm', type=float, default=None, help='Dropout rate for NLMs specifically. Unset to match dropout on the rest of the model (CTM only).')
    parser.add_argument('--do_normalisation', action=on_off, default=False, help='Apply normalization in NLMs (CTM only).')

    # Energy Head
    parser.add_argument('--energy_head_enabled', action=on_off, default=False, help='Enable energy head.')
    parser.add_argument('--energy_hidden_dim', type=int, default=64, help='Hidden dim for energy head.')
    parser.add_argument('--loss_type', type=str, default='standard', choices=['standard', 'energy_contrastive'], help='Loss type.')
    parser.add_argument('--energy_margin', type=float, default=10.0, help='Margin for energy loss.')
    parser.add_argument('--energy_scale', type=float, default=0.1, help='Scale for energy loss.')

    # LSTM specific
    parser.add_argument('--num_layers', type=int, default=2, help='Number of LSTM stacked layers (LSTM only).')

    # Training
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training.')
    parser.add_argument('--batch_size_test', type=int, default=32, help='Batch size for testing.')
    parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate for the model.')
    parser.add_argument('--training_iterations', type=int, default=100001, help='Number of training iterations.')
    parser.add_argument('--warmup_steps', type=int, default=5000, help='Number of warmup steps.')
    parser.add_argument('--use_scheduler', action=on_off, default=True, help='Use a learning rate scheduler.')
    parser.add_argument('--scheduler_type', type=str, default='cosine', choices=['multistep', 'cosine'], help='Type of learning rate scheduler.')
    parser.add_argument('--milestones', type=int, default=[8000, 15000, 20000], nargs='+', help='Learning rate scheduler milestones.')
    parser.add_argument('--gamma', type=float, default=0.1, help='Learning rate scheduler gamma for multistep.')
    parser.add_argument('--weight_decay', type=float, default=0.0, help='Weight decay factor.')
    parser.add_argument('--weight_decay_exclusion_list', type=str, nargs='+', default=[], help='List to exclude from weight decay. Typically good: bn, ln, bias, start')
    parser.add_argument('--gradient_clipping', type=float, default=-1, help='Gradient quantile clipping value (-1 to disable).')
    parser.add_argument('--do_compile', action=on_off, default=False, help='Try to compile model components (backbone, synapses if CTM).')
    parser.add_argument('--num_workers_train', type=int, default=1, help='Num workers training.')

    # Housekeeping
    parser.add_argument('--log_dir', type=str, default='logs/scratch', help='Directory for logging.')
    parser.add_argument('--dataset', type=str, default='cifar10', help='Dataset to use.')
    parser.add_argument('--data_root', type=str, default='data/', help='Where to save dataset.')
    parser.add_argument('--save_every', type=int, default=1000, help='Save checkpoints every this many iterations.')
    parser.add_argument('--seed', type=int, default=412, help='Random seed.')
    parser.add_argument('--reload', action=on_off, default=False, help='Reload from disk?')
    parser.add_argument('--reload_model_only', action=on_off, default=False, help='Reload only the model from disk?')
    parser.add_argument('--strict_reload', action=on_off, default=True, help='Should use strict reload for model weights.')
    parser.add_argument('--track_every', type=int, default=1000, help='Track metrics every this many iterations.')
    parser.add_argument('--n_test_batches', type=int, default=20, help='How many minibatches to approx metrics. Set to -1 for full eval')
    # The script selects MPS when available and CPU otherwise for -1, so the
    # help text states that rather than "CPU" only.
    parser.add_argument('--device', type=int, nargs='+', default=[-1], help='List of GPU(s) to use. Set to -1 to use MPS if available, otherwise CPU.')
    parser.add_argument('--use_amp', action=on_off, default=False, help='AMP autocast.')

    return parser.parse_args(argv)
141
+
142
+
143
def get_dataset(dataset, root):
    """Build the train/test datasets and metadata for an image-classification task.

    Args:
        dataset: One of ``'imagenet'``, ``'cifar10'`` or ``'cifar100'``.
        root: Directory the CIFAR datasets are downloaded to / read from.
            Not used for ImageNet, whose loader takes no root here.

    Returns:
        Tuple ``(train_data, test_data, class_labels, dataset_mean, dataset_std)``
        where the mean/std are the per-channel values used for normalisation.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    supported = ('imagenet', 'cifar10', 'cifar100')
    # Validate first so an unknown name fails with a useful message before
    # any transform or dataset object is built.
    if dataset not in supported:
        raise NotImplementedError(
            f"Unknown dataset '{dataset}'; expected one of {supported}"
        )

    if dataset == 'imagenet':
        dataset_mean = [0.485, 0.456, 0.406]
        dataset_std = [0.229, 0.224, 0.225]

        normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize])

        class_labels = list(IMAGENET2012_CLASSES.values())

        train_data = ImageNet(which_split='train', transform=train_transform)
        test_data = ImageNet(which_split='validation', transform=test_transform)
        return train_data, test_data, class_labels, dataset_mean, dataset_std

    # Both CIFAR variants share the same pipeline and differ only in their
    # normalisation statistics, dataset class and label list.
    if dataset == 'cifar10':
        dataset_mean = [0.49139968, 0.48215827, 0.44653124]
        dataset_std = [0.24703233, 0.24348505, 0.26158768]
        dataset_cls = datasets.CIFAR10
    else:
        dataset_mean = [0.5070751592371341, 0.48654887331495067, 0.4409178433670344]
        dataset_std = [0.2673342858792403, 0.2564384629170882, 0.27615047132568393]
        dataset_cls = datasets.CIFAR100

    normalize = transforms.Normalize(mean=dataset_mean, std=dataset_std)
    # The CIFAR10 AutoAugment policy is applied to CIFAR-100 as well.
    train_transform = transforms.Compose(
        [transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
         transforms.ToTensor(),
         normalize,
         ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         normalize,
         ])

    train_data = dataset_cls(root, train=True, transform=train_transform, download=True)
    test_data = dataset_cls(root, train=False, transform=test_transform, download=True)

    if dataset == 'cifar10':
        # Short display names for the ten CIFAR-10 classes.
        class_labels = ['air', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    else:
        # Order class names by their integer index.
        idx_order = np.argsort(np.array(list(train_data.class_to_idx.values())))
        class_labels = list(np.array(list(train_data.class_to_idx.keys()))[idx_order])

    return train_data, test_data, class_labels, dataset_mean, dataset_std
203
+
204
+
205
+
206
+ if __name__=='__main__':
207
+
208
+ # Housekeeping
209
+ args = parse_args()
210
+
211
+ set_seed(args.seed, False)
212
+ if not os.path.exists(args.log_dir): os.makedirs(args.log_dir)
213
+
214
+ assert args.dataset in ['cifar10', 'cifar100', 'imagenet']
215
+
216
+ # Data
217
+ train_data, test_data, class_labels, dataset_mean, dataset_std = get_dataset(args.dataset, args.data_root)
218
+
219
+ num_workers_test = 1 # Defaulting to 1, change if needed
220
+ trainloader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers_train)
221
+ testloader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test, drop_last=False)
222
+
223
+ prediction_reshaper = [-1] # Problem specific
224
+ args.out_dims = len(class_labels)
225
+
226
+ # For total reproducibility
227
+ zip_python_code(f'{args.log_dir}/repo_state.zip')
228
+ with open(f'{args.log_dir}/args.txt', 'w') as f:
229
+ print(args, file=f)
230
+
231
+ # Configure device string (support MPS on macOS)
232
+ if args.device[0] != -1:
233
+ device = f'cuda:{args.device[0]}'
234
+ elif torch.backends.mps.is_available():
235
+ device = 'mps'
236
+ else:
237
+ device = 'cpu'
238
+ print(f'Running model {args.model} on {device}')
239
+
240
+ # Build model conditionally
241
+ model = None
242
+ if args.model == 'ctm':
243
+ model = ContinuousThoughtMachine(
244
+ iterations=args.iterations,
245
+ d_model=args.d_model,
246
+ d_input=args.d_input,
247
+ heads=args.heads,
248
+ n_synch_out=args.n_synch_out,
249
+ n_synch_action=args.n_synch_action,
250
+ synapse_depth=args.synapse_depth,
251
+ memory_length=args.memory_length,
252
+ deep_nlms=args.deep_memory,
253
+ memory_hidden_dims=args.memory_hidden_dims,
254
+ do_layernorm_nlm=args.do_normalisation,
255
+ backbone_type=args.backbone_type,
256
+ positional_embedding_type=args.positional_embedding_type,
257
+ out_dims=args.out_dims,
258
+ prediction_reshaper=prediction_reshaper,
259
+ dropout=args.dropout,
260
+ dropout_nlm=args.dropout_nlm,
261
+ neuron_select_type=args.neuron_select_type,
262
+ n_random_pairing_self=args.n_random_pairing_self,
263
+ energy_head_enabled=args.energy_head_enabled,
264
+ energy_hidden_dim=args.energy_hidden_dim,
265
+ ).to(device)
266
+ elif args.model == 'lstm':
267
+ model = LSTMBaseline(
268
+ num_layers=args.num_layers,
269
+ iterations=args.iterations,
270
+ d_model=args.d_model,
271
+ d_input=args.d_input,
272
+ heads=args.heads,
273
+ backbone_type=args.backbone_type,
274
+ positional_embedding_type=args.positional_embedding_type,
275
+ out_dims=args.out_dims,
276
+ prediction_reshaper=prediction_reshaper,
277
+ dropout=args.dropout,
278
+ ).to(device)
279
+ elif args.model == 'ff':
280
+ model = FFBaseline(
281
+ d_model=args.d_model,
282
+ backbone_type=args.backbone_type,
283
+ out_dims=args.out_dims,
284
+ dropout=args.dropout,
285
+ ).to(device)
286
+ else:
287
+ raise ValueError(f"Unknown model type: {args.model}")
288
+
289
+
290
+ # For lazy modules so that we can get param count
291
+ pseudo_inputs = train_data.__getitem__(0)[0].unsqueeze(0).to(device)
292
+ model(pseudo_inputs)
293
+
294
+ model.train()
295
+
296
+
297
+ print(f'Total params: {sum(p.numel() for p in model.parameters())}')
298
+ decay_params = []
299
+ no_decay_params = []
300
+ no_decay_names = []
301
+ for name, param in model.named_parameters():
302
+ if not param.requires_grad:
303
+ continue # Skip parameters that don't require gradients
304
+ if any(exclusion_str in name for exclusion_str in args.weight_decay_exclusion_list):
305
+ no_decay_params.append(param)
306
+ no_decay_names.append(name)
307
+ else:
308
+ decay_params.append(param)
309
+ if len(no_decay_names):
310
+ print(f'WARNING, excluding: {no_decay_names}')
311
+
312
+ # Optimizer and scheduler (Common setup)
313
+ if len(no_decay_names) and args.weight_decay!=0:
314
+ optimizer = torch.optim.AdamW([{'params': decay_params, 'weight_decay':args.weight_decay},
315
+ {'params': no_decay_params, 'weight_decay':0}],
316
+ lr=args.lr,
317
+ eps=1e-8 if not args.use_amp else 1e-6)
318
+ else:
319
+ optimizer = torch.optim.AdamW(model.parameters(),
320
+ lr=args.lr,
321
+ eps=1e-8 if not args.use_amp else 1e-6,
322
+ weight_decay=args.weight_decay)
323
+
324
+
325
+ warmup_schedule = warmup(args.warmup_steps)
326
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_schedule.step)
327
+ if args.use_scheduler:
328
+ if args.scheduler_type == 'multistep':
329
+ scheduler = WarmupMultiStepLR(optimizer, warmup_steps=args.warmup_steps, milestones=args.milestones, gamma=args.gamma)
330
+ elif args.scheduler_type == 'cosine':
331
+ scheduler = WarmupCosineAnnealingLR(optimizer, args.warmup_steps, args.training_iterations, warmup_start_lr=1e-20, eta_min=1e-7)
332
+ else:
333
+ raise NotImplementedError
334
+
335
+
336
+ # Metrics tracking
337
+ start_iter = 0
338
+ train_losses = []
339
+ test_losses = []
340
+ train_accuracies = []
341
+ test_accuracies = []
342
+ iters = []
343
+ # Conditional metrics for CTM/LSTM
344
+ train_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
345
+ test_accuracies_most_certain = [] if args.model in ['ctm', 'lstm'] else None
346
+
347
+ # scaler = torch.amp.GradScaler("cuda" if "cuda" in device else "cpu", enabled=args.use_amp)
348
+ # Fallback for older torch versions or specific builds
349
+ scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
350
+
351
+ # Reloading logic
352
+ if args.reload:
353
+ checkpoint_path = f'{args.log_dir}/checkpoint.pt'
354
+ if os.path.isfile(checkpoint_path):
355
+ print(f'Reloading from: {checkpoint_path}')
356
+ checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
357
+ if not args.strict_reload: print('WARNING: not using strict reload for model weights!')
358
+ load_result = model.load_state_dict(checkpoint['model_state_dict'], strict=args.strict_reload)
359
+ print(f" Loaded state_dict. Missing: {load_result.missing_keys}, Unexpected: {load_result.unexpected_keys}")
360
+
361
+ if not args.reload_model_only:
362
+ print('Reloading optimizer etc.')
363
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
364
+ scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
365
+ scaler.load_state_dict(checkpoint['scaler_state_dict'])
366
+ start_iter = checkpoint['iteration']
367
+ # Load common metrics
368
+ train_losses = checkpoint['train_losses']
369
+ test_losses = checkpoint['test_losses']
370
+ train_accuracies = checkpoint['train_accuracies']
371
+ test_accuracies = checkpoint['test_accuracies']
372
+ iters = checkpoint['iters']
373
+
374
+ # Load conditional metrics if they exist in checkpoint and are expected for current model
375
+ if args.model in ['ctm', 'lstm']:
376
+ train_accuracies_most_certain = checkpoint['train_accuracies_most_certain']
377
+ test_accuracies_most_certain = checkpoint['test_accuracies_most_certain']
378
+
379
+ else:
380
+ print('Only reloading model!')
381
+
382
+ if 'torch_rng_state' in checkpoint:
383
+ # Reset seeds
384
+ torch.set_rng_state(checkpoint['torch_rng_state'].cpu().byte())
385
+ np.random.set_state(checkpoint['numpy_rng_state'])
386
+ random.setstate(checkpoint['random_rng_state'])
387
+
388
+ del checkpoint
389
+ gc.collect()
390
+ if torch.cuda.is_available():
391
+ torch.cuda.empty_cache()
392
+
393
+ # Conditional Compilation
394
+ if args.do_compile:
395
+ print('Compiling...')
396
+ if hasattr(model, 'backbone'):
397
+ model.backbone = torch.compile(model.backbone, mode='reduce-overhead', fullgraph=True)
398
+
399
+ # Compile synapses only for CTM
400
+ if args.model == 'ctm':
401
+ model.synapses = torch.compile(model.synapses, mode='reduce-overhead', fullgraph=True)
402
+
403
+ # Training
404
+ iterator = iter(trainloader)
405
+
406
+
407
+ with tqdm(total=args.training_iterations, initial=start_iter, leave=False, position=0, dynamic_ncols=True) as pbar:
408
+ for bi in range(start_iter, args.training_iterations):
409
+ current_lr = optimizer.param_groups[-1]['lr']
410
+
411
+ try:
412
+ inputs, targets = next(iterator)
413
+ except StopIteration:
414
+ iterator = iter(trainloader)
415
+ inputs, targets = next(iterator)
416
+
417
+ inputs = inputs.to(device)
418
+ targets = targets.to(device)
419
+
420
+ loss = None
421
+ accuracy = None
422
+ # Model-specific forward and loss calculation
423
+ with torch.autocast(device_type="cuda" if "cuda" in device else "cpu", dtype=torch.float16, enabled=args.use_amp):
424
+ if args.do_compile: # CUDAGraph marking for clean compile
425
+ torch.compiler.cudagraph_mark_step_begin()
426
+
427
+ if args.model == 'ctm':
428
+ if args.energy_head_enabled:
429
+ predictions, certainties, energies = model(inputs)
430
+ if args.loss_type == 'energy_contrastive':
431
+ criterion = EnergyContrastiveLoss(margin=args.energy_margin, energy_scale=args.energy_scale)
432
+ loss, stats = criterion(predictions, energies, targets)
433
+ # Use standard accuracy metric for now
434
+ where_most_certain = certainties[:,1].argmax(-1)
435
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
436
+ pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Avg Energy={stats["avg_energy"]:0.3f}'
437
+ else:
438
+ # Fallback to standard loss even if energy head is enabled (but unused)
439
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
440
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
441
+ pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
442
+ else:
443
+ predictions, certainties, synchronisation = model(inputs)
444
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
445
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
446
+ pbar_desc = f'CTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
447
+
448
+ elif args.model == 'lstm':
449
+ predictions, certainties, synchronisation = model(inputs)
450
+ loss, where_most_certain = image_classification_loss(predictions, certainties, targets, use_most_certain=True)
451
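+ # NOTE(review): the call above passes use_most_certain=True, but this comment states that LSTM where_most_certain will just be -1 because use_most_certain is False owing to stability issues with LSTM training; confirm which is intended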
452
+ accuracy = (predictions.argmax(1)[torch.arange(predictions.size(0), device=predictions.device),where_most_certain] == targets).float().mean().item()
453
+ pbar_desc = f'LSTM Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}. Where_certain={where_most_certain.float().mean().item():0.2f}+-{where_most_certain.float().std().item():0.2f} ({where_most_certain.min().item():d}<->{where_most_certain.max().item():d})'
454
+
455
+ elif args.model == 'ff':
456
+ predictions = model(inputs)
457
+ loss = nn.CrossEntropyLoss()(predictions, targets)
458
+ accuracy = (predictions.argmax(1) == targets).float().mean().item()
459
+ pbar_desc = f'FF Loss={loss.item():0.3f}. Acc={accuracy:0.3f}. LR={current_lr:0.6f}'
460
+
461
+ scaler.scale(loss).backward()
462
+
463
+ if args.gradient_clipping!=-1:
464
+ scaler.unscale_(optimizer)
465
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradient_clipping)
466
+
467
+ scaler.step(optimizer)
468
+ scaler.update()
469
+ optimizer.zero_grad(set_to_none=True)
470
+ scheduler.step()
471
+
472
+ pbar.set_description(f'Dataset={args.dataset}. Model={args.model}. {pbar_desc}')
473
+
474
+
475
+ # Metrics tracking and plotting (conditional logic needed)
476
+ if (bi % args.track_every == 0 or bi == args.warmup_steps) and (bi != 0 or args.reload_model_only):
477
+
478
+ iters.append(bi)
479
+ current_train_losses = []
480
+ current_test_losses = []
481
+ current_train_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
482
+ current_test_accuracies = [] # Holds list of accuracies per tick for CTM/LSTM, single value for FF
483
+ current_train_accuracies_most_certain = [] # Only for CTM/LSTM
484
+ current_test_accuracies_most_certain = [] # Only for CTM/LSTM
485
+
486
+
487
+ # Reset BN stats using train mode
488
+ pbar.set_description('Resetting BN')
489
+ model.train()
490
+ for module in model.modules():
491
+ if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
492
+ module.reset_running_stats()
493
+
494
+ pbar.set_description('Tracking: Computing TRAIN metrics')
495
+ with torch.no_grad(): # Should use inference_mode? CTM/LSTM scripts used no_grad
496
+ loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
497
+ all_targets_list = []
498
+ all_predictions_list = [] # List to store raw predictions (B, C, T) or (B, C)
499
+ all_predictions_most_certain_list = [] # Only for CTM/LSTM
500
+ all_losses = []
501
+
502
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
503
+ for inferi, (inputs, targets) in enumerate(loader):
504
+ inputs = inputs.to(device)
505
+ targets = targets.to(device)
506
+ all_targets_list.append(targets.detach().cpu().numpy())
507
+
508
+ # Model-specific forward and loss for evaluation
509
+ if args.model == 'ctm':
510
+ these_predictions, certainties, _ = model(inputs)
511
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
512
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
513
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
514
+
515
+ elif args.model == 'lstm':
516
+ these_predictions, certainties, _ = model(inputs)
517
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
518
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B, T)
519
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy()) # Shape (B,)
520
+
521
+ elif args.model == 'ff':
522
+ these_predictions = model(inputs)
523
+ loss = nn.CrossEntropyLoss()(these_predictions, targets)
524
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy()) # Shape (B,)
525
+
526
+ all_losses.append(loss.item())
527
+
528
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1 : break # Check condition >= N-1
529
+ pbar_inner.set_description(f'Computing metrics for train (Batch {inferi+1})')
530
+ pbar_inner.update(1)
531
+
532
+ all_targets = np.concatenate(all_targets_list)
533
+ all_predictions = np.concatenate(all_predictions_list) # Shape (N, T) or (N,)
534
+ train_losses.append(np.mean(all_losses))
535
+
536
+ if args.model in ['ctm', 'lstm']:
537
+ # Accuracies per tick for CTM/LSTM
538
+ current_train_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0) # Mean over batch dim -> Shape (T,)
539
+ train_accuracies.append(current_train_accuracies)
540
+ # Most certain accuracy
541
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
542
+ current_train_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
543
+ train_accuracies_most_certain.append(current_train_accuracies_most_certain)
544
+ else: # FF
545
+ current_train_accuracies = (all_targets == all_predictions).mean() # Shape scalar
546
+ train_accuracies.append(current_train_accuracies)
547
+
548
+ del these_predictions
549
+
550
+
551
+ # Switch to eval mode for test metrics (fixed BN stats)
552
+ model.eval()
553
+ pbar.set_description('Tracking: Computing TEST metrics')
554
+ with torch.inference_mode(): # Use inference_mode for test eval
555
+ loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size_test, shuffle=True, num_workers=num_workers_test)
556
+ all_targets_list = []
557
+ all_predictions_list = []
558
+ all_predictions_most_certain_list = [] # Only for CTM/LSTM
559
+ all_losses = []
560
+
561
+ with tqdm(total=len(loader), initial=0, leave=False, position=1, dynamic_ncols=True) as pbar_inner:
562
+ for inferi, (inputs, targets) in enumerate(loader):
563
+ inputs = inputs.to(device)
564
+ targets = targets.to(device)
565
+ all_targets_list.append(targets.detach().cpu().numpy())
566
+
567
+ # Model-specific forward and loss for evaluation
568
+ if args.model == 'ctm':
569
+ these_predictions, certainties, _ = model(inputs)
570
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
571
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
572
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
573
+
574
+ elif args.model == 'lstm':
575
+ these_predictions, certainties, _ = model(inputs)
576
+ loss, where_most_certain = image_classification_loss(these_predictions, certainties, targets, use_most_certain=True)
577
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
578
+ all_predictions_most_certain_list.append(these_predictions.argmax(1)[torch.arange(these_predictions.size(0), device=these_predictions.device), where_most_certain].detach().cpu().numpy())
579
+
580
+ elif args.model == 'ff':
581
+ these_predictions = model(inputs)
582
+ loss = nn.CrossEntropyLoss()(these_predictions, targets)
583
+ all_predictions_list.append(these_predictions.argmax(1).detach().cpu().numpy())
584
+
585
+ all_losses.append(loss.item())
586
+
587
+ if args.n_test_batches != -1 and inferi >= args.n_test_batches -1: break
588
+ pbar_inner.set_description(f'Computing metrics for test (Batch {inferi+1})')
589
+ pbar_inner.update(1)
590
+
591
+ all_targets = np.concatenate(all_targets_list)
592
+ all_predictions = np.concatenate(all_predictions_list)
593
+ test_losses.append(np.mean(all_losses))
594
+
595
+ if args.model in ['ctm', 'lstm']:
596
+ current_test_accuracies = np.mean(all_predictions == all_targets[...,np.newaxis], axis=0)
597
+ test_accuracies.append(current_test_accuracies)
598
+ all_predictions_most_certain = np.concatenate(all_predictions_most_certain_list)
599
+ current_test_accuracies_most_certain = (all_targets == all_predictions_most_certain).mean()
600
+ test_accuracies_most_certain.append(current_test_accuracies_most_certain)
601
+ else: # FF
602
+ current_test_accuracies = (all_targets == all_predictions).mean()
603
+ test_accuracies.append(current_test_accuracies)
604
+
605
+ # Plotting (conditional)
606
+ figacc = plt.figure(figsize=(10, 10))
607
+ axacc_train = figacc.add_subplot(211)
608
+ axacc_test = figacc.add_subplot(212)
609
+ cm = sns.color_palette("viridis", as_cmap=True)
610
+
611
+ if args.model in ['ctm', 'lstm']:
612
+ # Plot per-tick accuracy for CTM/LSTM
613
+ train_acc_arr = np.array(train_accuracies) # Shape (N_iters, T)
614
+ test_acc_arr = np.array(test_accuracies) # Shape (N_iters, T)
615
+ num_ticks = train_acc_arr.shape[1]
616
+ for ti in range(num_ticks):
617
+ axacc_train.plot(iters, train_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
618
+ axacc_test.plot(iters, test_acc_arr[:, ti], color=cm(ti / num_ticks), alpha=0.3)
619
+ # Plot most certain accuracy
620
+ axacc_train.plot(iters, train_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
621
+ axacc_test.plot(iters, test_accuracies_most_certain, 'k--', alpha=0.7, label='Most certain')
622
+ else: # FF
623
+ axacc_train.plot(iters, train_accuracies, 'k-', alpha=0.7, label='Accuracy') # Simple line
624
+ axacc_test.plot(iters, test_accuracies, 'k-', alpha=0.7, label='Accuracy')
625
+
626
+ axacc_train.set_title('Train Accuracy')
627
+ axacc_test.set_title('Test Accuracy')
628
+ axacc_train.legend(loc='lower right')
629
+ axacc_test.legend(loc='lower right')
630
+ axacc_train.set_xlim([0, args.training_iterations])
631
+ axacc_test.set_xlim([0, args.training_iterations])
632
+ if args.dataset=='cifar10':
633
+ axacc_train.set_ylim([0.75, 1])
634
+ axacc_test.set_ylim([0.75, 1])
635
+
636
+
637
+
638
+ figacc.tight_layout()
639
+ figacc.savefig(f'{args.log_dir}/accuracies.png', dpi=150)
640
+ plt.close(figacc)
641
+
642
+ figloss = plt.figure(figsize=(10, 5))
643
+ axloss = figloss.add_subplot(111)
644
+ axloss.plot(iters, train_losses, 'b-', linewidth=1, alpha=0.8, label=f'Train: {train_losses[-1]:.4f}')
645
+ axloss.plot(iters, test_losses, 'r-', linewidth=1, alpha=0.8, label=f'Test: {test_losses[-1]:.4f}')
646
+ axloss.legend(loc='upper right')
647
+ axloss.set_xlim([0, args.training_iterations])
648
+ axloss.set_ylim(bottom=0)
649
+
650
+ figloss.tight_layout()
651
+ figloss.savefig(f'{args.log_dir}/losses.png', dpi=150)
652
+ plt.close(figloss)
653
+
654
+ # Conditional Visualization (Only for CTM/LSTM)
655
+ if args.model in ['ctm', 'lstm']:
656
+ try: # For safety
657
+ inputs_viz, targets_viz = next(iter(testloader)) # Get a fresh batch
658
+ inputs_viz = inputs_viz.to(device)
659
+ targets_viz = targets_viz.to(device)
660
+
661
+ pbar.set_description('Tracking: Processing test data for viz')
662
+ predictions_viz, certainties_viz, _, pre_activations_viz, post_activations_viz, attention_tracking_viz = model(inputs_viz, track=True)
663
+
664
+ att_shape = (model.kv_features.shape[2], model.kv_features.shape[3])
665
+ attention_tracking_viz = attention_tracking_viz.reshape(
666
+ attention_tracking_viz.shape[0],
667
+ attention_tracking_viz.shape[1], -1, att_shape[0], att_shape[1])
668
+
669
+ pbar.set_description('Tracking: Neural dynamics plot')
670
+ plot_neural_dynamics(post_activations_viz, 100, args.log_dir, axis_snap=True)
671
+
672
+ imgi = 0 # Visualize the first image in the batch
673
+ img_to_gif = np.moveaxis(np.clip(inputs_viz[imgi].detach().cpu().numpy()*np.array(dataset_std).reshape(len(dataset_std), 1, 1) + np.array(dataset_mean).reshape(len(dataset_mean), 1, 1), 0, 1), 0, -1)
674
+
675
+ pbar.set_description('Tracking: Producing attention gif')
676
+ make_classification_gif(img_to_gif,
677
+ targets_viz[imgi].item(),
678
+ predictions_viz[imgi].detach().cpu().numpy(),
679
+ certainties_viz[imgi].detach().cpu().numpy(),
680
+ post_activations_viz[:,imgi],
681
+ attention_tracking_viz[:,imgi],
682
+ class_labels,
683
+ f'{args.log_dir}/{imgi}_attention.gif',
684
+ )
685
+ del predictions_viz, certainties_viz, pre_activations_viz, post_activations_viz, attention_tracking_viz
686
+ except Exception as e:
687
+ print(f"Visualization failed for model {args.model}: {e}")
688
+
689
+
690
+
691
+ gc.collect()
692
+ if torch.cuda.is_available():
693
+ torch.cuda.empty_cache()
694
+ model.train() # Switch back to train mode
695
+
696
+
697
+ # Save model checkpoint (conditional metrics)
698
+ if (bi % args.save_every == 0 or bi == args.training_iterations - 1) and bi != start_iter:
699
+ pbar.set_description('Saving model checkpoint...')
700
+ checkpoint_data = {
701
+ 'model_state_dict': model.state_dict(),
702
+ 'optimizer_state_dict': optimizer.state_dict(),
703
+ 'scheduler_state_dict': scheduler.state_dict(),
704
+ 'scaler_state_dict': scaler.state_dict(),
705
+ 'iteration': bi,
706
+ # Always save these
707
+ 'train_losses': train_losses,
708
+ 'test_losses': test_losses,
709
+ 'train_accuracies': train_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
710
+ 'test_accuracies': test_accuracies, # This is list of scalars for FF, list of arrays for CTM/LSTM
711
+ 'iters': iters,
712
+ 'args': args, # Save args used for this run
713
+ # RNG states
714
+ 'torch_rng_state': torch.get_rng_state(),
715
+ 'numpy_rng_state': np.random.get_state(),
716
+ 'random_rng_state': random.getstate(),
717
+ }
718
+ # Conditionally add metrics specific to CTM/LSTM
719
+ if args.model in ['ctm', 'lstm']:
720
+ checkpoint_data['train_accuracies_most_certain'] = train_accuracies_most_certain
721
+ checkpoint_data['test_accuracies_most_certain'] = test_accuracies_most_certain
722
+
723
+ torch.save(checkpoint_data, f'{args.log_dir}/checkpoint.pt')
724
+
725
+ pbar.update(1)
utils/losses.py CHANGED
@@ -169,6 +169,57 @@ def parity_loss(predictions, certainties, targets, use_most_certain=True):
169
  return loss, loss_index_2
170
 
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  def qamnist_loss(predictions, certainties, targets, use_most_certain=True):
173
  """
174
  Computes the qamnist loss over the last num_answer_steps steps.
 
169
  return loss, loss_index_2
170
 
171
 
172
class EnergyContrastiveLoss(nn.Module):
    """Per-tick cross-entropy plus a margin-based energy objective.

    Every (sample, tick) pair contributes a cross-entropy term. In addition,
    the scalar energy at that tick is pulled towards 0 when the tick's argmax
    prediction equals the target, and pushed up to at least ``margin`` when
    it does not:

        correct tick:    E ** 2
        incorrect tick:  max(0, margin - E) ** 2

    The returned loss is ``mean(CE) + energy_scale * mean(energy objective)``.

    Args:
        margin: Energy level that incorrect ticks are pushed above.
        energy_scale: Weight of the energy objective relative to the
            cross-entropy term.
    """

    def __init__(self, margin=10.0, energy_scale=0.1):
        super().__init__()
        self.margin = margin
        self.energy_scale = energy_scale
        # Unreduced, so the per-tick values exist before they are averaged.
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits_history, energy_history, targets):
        """Compute the combined loss over all ticks.

        Args:
            logits_history: Class logits of shape [B, Class, T].
            energy_history: Energies of shape [B, 1, T]; [B, T] is accepted
                as well.
            targets: Class indices of shape [B].

        Returns:
            A tuple ``(total_loss, stats)``. ``total_loss`` is a scalar
            tensor that carries gradients; ``stats`` is a dict of Python
            floats with keys ``"ce_loss"``, ``"energy_loss"`` and
            ``"avg_energy"``.
        """
        B, C, T = logits_history.shape

        # Flatten so that row b * T + t of every tensor refers to sample b at
        # tick t.
        logits_flat = logits_history.permute(0, 2, 1).reshape(B * T, C)
        # A plain reshape yields the same ordering as permute + reshape for
        # [B, 1, T] (the middle dim is a singleton) and also handles [B, T].
        energy_flat = energy_history.reshape(B * T)
        targets_expanded = targets.unsqueeze(1).repeat(1, T).reshape(B * T)

        # 1. Standard classification loss (cross-entropy) at every tick.
        ce_mean = self.ce_loss(logits_flat, targets_expanded).mean()

        # 2. A tick is "positive" when its argmax prediction matches the
        #    target. argmax is not differentiable, so this mask only selects
        #    which energy term applies.
        predictions = logits_flat.argmax(dim=1)
        is_correct = (predictions == targets_expanded).float()

        # 3. Energy objective: correct ticks are pulled to zero energy,
        #    incorrect ticks are pushed above the margin.
        loss_pos = energy_flat ** 2
        loss_neg = F.relu(self.margin - energy_flat) ** 2
        energy_objective = (is_correct * loss_pos) + ((1 - is_correct) * loss_neg)
        energy_mean = energy_objective.mean()

        total_loss = ce_mean + (self.energy_scale * energy_mean)

        return total_loss, {
            "ce_loss": ce_mean.item(),
            "energy_loss": energy_mean.item(),
            "avg_energy": energy_flat.mean().item()
        }
221
+
222
+
223
  def qamnist_loss(predictions, certainties, targets, use_most_certain=True):
224
  """
225
  Computes the qamnist loss over the last num_answer_steps steps.