import torch
import numpy as np

from models.ctm import ContinuousThoughtMachine


class ContinuousThoughtMachineSORT(ContinuousThoughtMachine):
    """
    Slight adaptation of the CTM to work with the sort task.
    """

    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 heads,
                 n_synch_out,
                 n_synch_action,
                 synapse_depth,
                 memory_length,
                 deep_nlms,
                 memory_hidden_dims,
                 do_layernorm_nlm,
                 backbone_type,
                 positional_embedding_type,
                 out_dims,
                 prediction_reshaper=[-1],
                 dropout=0,
                 dropout_nlm=None,
                 neuron_select_type='random-pairing',
                 n_random_pairing_self=0,
                 ):
        super().__init__(
            iterations=iterations,
            d_model=d_model,
            d_input=d_input,
            heads=0,
            n_synch_out=n_synch_out,
            n_synch_action=0,
            synapse_depth=synapse_depth,
            memory_length=memory_length,
            deep_nlms=deep_nlms,
            memory_hidden_dims=memory_hidden_dims,
            do_layernorm_nlm=do_layernorm_nlm,
            backbone_type='none',
            positional_embedding_type='none',
            out_dims=out_dims,
            prediction_reshaper=prediction_reshaper,
            dropout=dropout,
            dropout_nlm=dropout_nlm,
            neuron_select_type=neuron_select_type,
            n_random_pairing_self=n_random_pairing_self,
        )

        # --- Use a minimal CTM w/out input (action) synch ---
        self.neuron_select_type_action = None
        self.synch_representation_size_action = None
        self.attention = None  # Should already be None because super(... heads=0 ...)
        self.q_proj = None     # Should already be None because super(... heads=0 ...)
        self.kv_proj = None    # Should already be None because super(... heads=0 ...)

    def forward(self, x, track=False):
        B = x.size(0)
        device = x.device

        # --- Tracking Initialization ---
        pre_activations_tracking = []
        post_activations_tracking = []
        synch_out_tracking = []
        attention_tracking = []

        # --- For SORT: no need to featurise data ---

        # --- Initialise Recurrent State ---
        state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1)  # Shape: (B, H, T)
        activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1)  # Shape: (B, H)

        # --- Prepare Storage for Outputs per Iteration ---
        predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
        certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)

        # --- Initialise Recurrent Synch Values ---
        r_out = torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1)
        # Compute learned weighting for synchronisation
        _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')

        # --- Recurrent Loop ---
        for stepi in range(self.iterations):

            pre_synapse_input = torch.concatenate((x, activated_state), dim=-1)

            # --- Apply Synapses ---
            state = self.synapses(pre_synapse_input)
            # The 'state_trace' is the history of incoming pre-activations
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # --- Apply Neuron-Level Models ---
            activated_state = self.trace_processor(state_trace)
            # One would also keep an 'activated_state_trace' as the history of outgoing post-activations
            # BUT, this is unnecessary because the synchronisation calculation is fully linear and can be
            # done using only the current activated state (see compute_synchronisation method for explanation)

            # --- Calculate Synchronisation for Output Predictions ---
            synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')

            # --- Get Predictions and Certainties ---
            current_prediction = self.output_projector(synchronisation_out)
            current_certainty = self.compute_certainty(current_prediction)

            predictions[..., stepi] = current_prediction
            certainties[..., stepi] = current_certainty

            # --- Tracking ---
            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())
                synch_out_tracking.append(synchronisation_out.detach().cpu().numpy())

        # --- Return Values ---
        if track:
            return predictions, certainties, np.array(synch_out_tracking), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking)
        return predictions, certainties, synchronisation_out
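

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The hyperparameter values below are
# assumptions chosen for a quick smoke test, not the repository's training
# configuration. It assumes the input is a flat (B, d_input) vector (e.g. the
# values to be sorted) and that `out_dims` matches whatever size the task's
# loss expects for the per-step prediction.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    model = ContinuousThoughtMachineSORT(
        iterations=10,             # number of internal recurrent ticks
        d_model=128,               # width of the recurrent state
        d_input=8,                 # length of the input vector (assumed)
        heads=0,                   # ignored: forced to 0 (no attention) in __init__
        n_synch_out=32,            # neurons used for output synchronisation
        n_synch_action=0,          # ignored: forced to 0 in __init__
        synapse_depth=1,
        memory_length=10,          # length T of the pre-activation history
        deep_nlms=True,
        memory_hidden_dims=4,
        do_layernorm_nlm=False,
        backbone_type='none',      # ignored: forced to 'none' in __init__
        positional_embedding_type='none',  # ignored: forced to 'none' in __init__
        out_dims=8,                # assumed per-step output size
    )

    x = torch.rand(4, 8)           # batch of 4 unsorted value vectors (assumed input format)
    predictions, certainties, synchronisation_out = model(x)
    print(predictions.shape)       # (4, out_dims, iterations)
    print(certainties.shape)       # (4, 2, iterations)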