import torch
import numpy as np

from models.ctm import ContinuousThoughtMachine


class ContinuousThoughtMachineSORT(ContinuousThoughtMachine):
    """
    Slight adaptation of the CTM to work with the sort task: the input vector is
    fed straight into the synapse model (no backbone, no positional embedding),
    and the attention / action-synchronisation pathway is disabled.
    """
    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 heads,
                 n_synch_out,
                 n_synch_action,
                 synapse_depth,
                 memory_length,
                 deep_nlms,
                 memory_hidden_dims,
                 do_layernorm_nlm,
                 backbone_type,
                 positional_embedding_type,
                 out_dims,
                 prediction_reshaper=[-1],
                 dropout=0,
                 dropout_nlm=None,
                 neuron_select_type='random-pairing',
                 n_random_pairing_self=0,
                 ):
        super().__init__(
            iterations=iterations,
            d_model=d_model,
            d_input=d_input,
            heads=0,  # Attention is disabled regardless of the `heads` argument
            n_synch_out=n_synch_out,
            n_synch_action=0,  # No action synchronisation for the sort task
            synapse_depth=synapse_depth,
            memory_length=memory_length,
            deep_nlms=deep_nlms,
            memory_hidden_dims=memory_hidden_dims,
            do_layernorm_nlm=do_layernorm_nlm,
            backbone_type='none',  # No featurisation backbone for the sort task
            positional_embedding_type='none',
            out_dims=out_dims,
            prediction_reshaper=prediction_reshaper,
            dropout=dropout,
            dropout_nlm=dropout_nlm,
            neuron_select_type=neuron_select_type,
            n_random_pairing_self=n_random_pairing_self,
        )

        # --- Use a minimal CTM without input (action) synchronisation ---
        self.neuron_select_type_action = None
        self.synch_representation_size_action = None
        self.attention = None  # Should already be None because super().__init__ was called with heads=0
        self.q_proj = None  # Should already be None because super().__init__ was called with heads=0
        self.kv_proj = None  # Should already be None because super().__init__ was called with heads=0
    def forward(self, x, track=False):
        B = x.size(0)
        device = x.device

        # --- Tracking Initialisation ---
        pre_activations_tracking = []
        post_activations_tracking = []
        synch_out_tracking = []
        attention_tracking = []  # Stays empty: this variant has no attention module
        # --- For SORT: no need to featurise the data, the raw input is used directly ---

        # --- Initialise Recurrent State ---
        state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1)  # Shape: (B, H, T)
        activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1)  # Shape: (B, H)

        # --- Prepare Storage for Outputs per Iteration ---
        predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
        certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
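        # Note: the size-2 middle dimension of `certainties` mirrors the base CTM,
        # where compute_certainty returns two values per prediction (in models/ctm.py
        # this is the normalised entropy of the prediction stacked with one minus it;
        # see that method for the authoritative definition).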
        # --- Initialise Recurrent Synch Values ---
        # r_out is the learned per-neuron decay weighting for the output
        # synchronisation; the decay parameters are clamped before exponentiating
        # for numerical stability.
        r_out = torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
        _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
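        # A sketch of the recurrence that compute_synchronisation maintains (the
        # authoritative definition lives in models/ctm.py; this summary is an
        # interpretation of its usage here): for each selected neuron pair (i, j)
        # with decay r,
        #     alpha_t = r * alpha_{t-1} + z_t[i] * z_t[j]
        #     beta_t  = r * beta_{t-1} + 1
        #     S_t     = alpha_t / sqrt(beta_t)
        # Since the update only needs the current activated state z_t and the
        # running (alpha, beta) accumulators, no post-activation history has to
        # be stored (see the comment inside the loop below).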
        # --- Recurrent Loop ---
        for stepi in range(self.iterations):

            pre_synapse_input = torch.cat((x, activated_state), dim=-1)

            # --- Apply Synapses ---
            state = self.synapses(pre_synapse_input)
            # The 'state_trace' is the history of incoming pre-activations
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # --- Apply Neuron-Level Models ---
            activated_state = self.trace_processor(state_trace)
            # One could also keep an 'activated_state_trace' as the history of outgoing post-activations,
            # BUT this is unnecessary because the synchronisation calculation is fully linear and can be
            # done using only the current activated state (see the compute_synchronisation method for an explanation).
            # --- Calculate Synchronisation for Output Predictions ---
            synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')

            # --- Get Predictions and Certainties ---
            current_prediction = self.output_projector(synchronisation_out)
            current_certainty = self.compute_certainty(current_prediction)

            predictions[..., stepi] = current_prediction
            certainties[..., stepi] = current_certainty
            # --- Tracking ---
            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())
                synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())
        # --- Return Values ---
        if track:
            # attention_tracking is returned as an empty array, since SORT uses no attention
            return predictions, certainties, np.array(synch_out_tracking), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
        return predictions, certainties, synchronisation_out
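

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The hyperparameter values below are
# assumptions chosen for demonstration, not values taken from the repository's
# training configs. Because this variant has no backbone or attention, x is
# expected to already be a flat feature vector of size d_input, which forward()
# concatenates with the activated state before the synapse model.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    model = ContinuousThoughtMachineSORT(
        iterations=10,
        d_model=256,
        d_input=64,
        heads=0,                            # ignored: overridden to 0 above
        n_synch_out=32,
        n_synch_action=0,                   # ignored: overridden to 0 above
        synapse_depth=1,
        memory_length=10,
        deep_nlms=True,
        memory_hidden_dims=4,
        do_layernorm_nlm=False,
        backbone_type='none',               # ignored: forced to 'none' above
        positional_embedding_type='none',   # ignored: forced to 'none' above
        out_dims=64,
    )
    x = torch.randn(2, 64)                  # (batch, d_input)
    predictions, certainties, synch_out = model(x)
    print(predictions.shape)                # (B, out_dims, iterations) -> torch.Size([2, 64, 10])
    print(certainties.shape)                # (B, 2, iterations) -> torch.Size([2, 2, 10])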