import torch
import torch.nn as nn
import numpy as np
import math

from models.ctm import ContinuousThoughtMachine
from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET
from models.utils import compute_decay
from models.constants import VALID_NEURON_SELECT_TYPES


class ContinuousThoughtMachineRL(ContinuousThoughtMachine):

    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 n_synch_out,
                 synapse_depth,
                 memory_length,
                 deep_nlms,
                 memory_hidden_dims,
                 do_layernorm_nlm,
                 backbone_type,
                 prediction_reshaper=[-1],
                 dropout=0,
                 neuron_select_type='first-last',
                 ):
        super().__init__(
            iterations=iterations,
            d_model=d_model,
            d_input=d_input,
            heads=0,  # Setting heads to 0 makes the parent skip attention (modules are set to None)
            n_synch_out=n_synch_out,
            n_synch_action=0,
            synapse_depth=synapse_depth,
            memory_length=memory_length,
            deep_nlms=deep_nlms,
            memory_hidden_dims=memory_hidden_dims,
            do_layernorm_nlm=do_layernorm_nlm,
            out_dims=0,
            prediction_reshaper=prediction_reshaper,
            dropout=dropout,
            neuron_select_type=neuron_select_type,
            backbone_type=backbone_type,
            n_random_pairing_self=0,
            positional_embedding_type='none',
        )

        # --- Use a minimal CTM without input (action) synchronisation ---
        self.neuron_select_type_action = None
        self.synch_representation_size_action = None

        # --- Start dynamics with a learned activated state trace ---
        self.register_parameter(
            'start_activated_trace',
            nn.Parameter(
                torch.zeros((d_model, memory_length)).uniform_(
                    -math.sqrt(1 / (d_model + memory_length)),
                    math.sqrt(1 / (d_model + memory_length)),
                ),
                requires_grad=True,
            ),
        )
        self.start_activated_state = None

        # Upper-triangular (incl. diagonal) mask selecting the tracked neuron pairs.
        self.register_buffer('diagonal_mask_out', torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool)))

        self.attention = None  # Should already be None because super(..., heads=0, ...)
        self.q_proj = None  # Should already be None because super(..., heads=0, ...)
        self.kv_proj = None  # Should already be None because super(..., heads=0, ...)
        self.output_projector = None

    # --- Core CTM Methods ---

    def compute_synchronisation(self, activated_state_trace):
        """Compute the synchronisation between neurons."""
        assert self.neuron_select_type == "first-last", "Only first-last neuron selection is supported here"
        # For RL tasks we track a sliding window of activations from which we
        # compute synchronisation.
        S = activated_state_trace.permute(0, 2, 1)
        diagonal_mask = self.diagonal_mask_out.to(S.device)
        decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4))
        # For each tracked pair (i, j) of the last n_synch_out neurons, take the
        # decay-weighted sum over the window of the product of their activations,
        # normalised by the square root of the summed decay weights.
        synchronisation = (
            (decay.unsqueeze(0)
             * (S[:, :, -self.n_synch_out:].unsqueeze(-1)
                * S[:, :, -self.n_synch_out:].unsqueeze(-2))[:, :, diagonal_mask]).sum(1)
        ) / torch.sqrt(decay.unsqueeze(0).sum(1))
        return synchronisation

    # --- Setup Methods ---

    def set_initial_rgb(self):
        """Set the initial RGB values for the backbone."""
        return None

    def get_d_backbone(self):
        """Get the dimensionality of the backbone output."""
        return self.d_input

    def set_backbone(self):
        """Set the backbone module based on the specified type."""
        if self.backbone_type == 'navigation-backbone':
            self.backbone = MiniGridBackbone(self.d_input)
        elif self.backbone_type == 'classic-control-backbone':
            self.backbone = ClassicControlBackbone(self.d_input)
        else:
            raise NotImplementedError('The only backbones supported for RL are for navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')

    def get_positional_embedding(self, d_backbone):
        """Get the positional embedding module."""
        return None
    def get_synapses(self, synapse_depth, d_model, dropout):
        """
        Get the synapse module.

        We found in our early experimentation that a single Linear, GLU and
        LayerNorm block performed worse than two blocks. For that reason we set
        the default synapse depth to two blocks.

        TODO: This is legacy and needs further experimentation to iron out.
        """
        if synapse_depth == 1:
            return nn.Sequential(
                nn.Dropout(dropout),
                nn.LazyLinear(d_model * 2),
                nn.GLU(),
                nn.LayerNorm(d_model),
                nn.LazyLinear(d_model * 2),
                nn.GLU(),
                nn.LayerNorm(d_model),
            )
        else:
            return SynapseUNET(d_model, synapse_depth, 16, dropout)

    def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
        """Set the parameters for the synchronisation of neurons."""
        if synch_type == 'action':
            pass  # No action synchronisation in the RL variant
        elif synch_type == 'out':
            left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self)
            self.register_buffer('out_neuron_indices_left', left)
            self.register_buffer('out_neuron_indices_right', right)
            self.register_parameter('decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True))
        else:
            raise ValueError(f"Invalid synch_type: {synch_type}")

    # --- Utility Methods ---

    def verify_args(self):
        """Verify the validity of the input arguments."""
        assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
            f"Invalid neuron selection type: {self.neuron_select_type}"
        assert self.neuron_select_type != 'random-pairing', \
            "Random pairing is not supported for RL."
        assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
            f"Invalid backbone_type: {self.backbone_type}"
        assert self.d_model >= self.n_synch_out, \
            "d_model must be >= n_synch_out for neuron subsets"

    def forward(self, x, hidden_states, track=False):
        # --- Tracking Initialization ---
        pre_activations_tracking = []
        post_activations_tracking = []

        # --- Featurise Input Data ---
        features = self.backbone(x)

        # --- Get Recurrent State ---
        state_trace, activated_state_trace = hidden_states

        # --- Recurrent Loop ---
        for stepi in range(self.iterations):
            pre_synapse_input = torch.concatenate((features.reshape(x.size(0), -1), activated_state_trace[:, :, -1]), -1)

            # --- Apply Synapses ---
            state = self.synapses(pre_synapse_input)
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # --- Apply NLMs ---
            activated_state = self.trace_processor(state_trace)
            activated_state_trace = torch.concatenate((activated_state_trace[:, :, 1:], activated_state.unsqueeze(-1)), -1)

            # --- Tracking ---
            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())

        hidden_states = (state_trace, activated_state_trace)

        # --- Calculate Output Synchronisation ---
        synchronisation_out = self.compute_synchronisation(activated_state_trace)

        # --- Return Values ---
        if track:
            return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking)
        return synchronisation_out, hidden_states
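

if __name__ == '__main__':
    # Minimal smoke test -- an illustrative sketch only, not part of the model.
    # It assumes a classic-control task with flat 4-dimensional observations
    # (e.g. CartPole-style states); the hyperparameters and observation shape
    # below are arbitrary demonstration values, not settings from the paper.
    D_MODEL, MEMORY_LENGTH, BATCH = 64, 10, 2

    model = ContinuousThoughtMachineRL(
        iterations=5,
        d_model=D_MODEL,
        d_input=32,
        n_synch_out=16,
        synapse_depth=1,  # uses the visible two-block Sequential path above
        memory_length=MEMORY_LENGTH,
        deep_nlms=True,
        memory_hidden_dims=4,
        do_layernorm_nlm=False,
        backbone_type='classic-control-backbone',
    )

    # Initialise the recurrent state: zeros for the pre-activation trace and
    # the learned start trace (broadcast over the batch) for the activations.
    state_trace = torch.zeros(BATCH, D_MODEL, MEMORY_LENGTH)
    activated_state_trace = model.start_activated_trace.unsqueeze(0).expand(BATCH, -1, -1).contiguous()

    observations = torch.randn(BATCH, 4)  # assumed observation shape
    synchronisation_out, hidden_states = model(observations, (state_trace, activated_state_trace))
    print(synchronisation_out.shape)  # (BATCH, n_synch_out * (n_synch_out + 1) // 2) tracked pairs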