import torch
import torch.nn as nn
import numpy as np
import math
from models.ctm import ContinuousThoughtMachine
from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET
from models.utils import compute_decay
from models.constants import VALID_NEURON_SELECT_TYPES


class ContinuousThoughtMachineRL(ContinuousThoughtMachine):

    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 n_synch_out,
                 synapse_depth,
                 memory_length,
                 deep_nlms,
                 memory_hidden_dims,
                 do_layernorm_nlm,
                 backbone_type,
                 prediction_reshaper=[-1],
                 dropout=0,
                 neuron_select_type='first-last',
                 ):
        super().__init__(
            iterations=iterations,
            d_model=d_model,
            d_input=d_input,
            heads=0,  # Setting heads to 0 means the parent builds no attention modules (they are left as None)
            n_synch_out=n_synch_out,
            n_synch_action=0,
            synapse_depth=synapse_depth,
            memory_length=memory_length,
            deep_nlms=deep_nlms,
            memory_hidden_dims=memory_hidden_dims,
            do_layernorm_nlm=do_layernorm_nlm,
            out_dims=0,
            prediction_reshaper=prediction_reshaper,
            dropout=dropout,
            neuron_select_type=neuron_select_type,
            backbone_type=backbone_type,
            n_random_pairing_self=0,
            positional_embedding_type='none',
        )

        # --- Use a minimal CTM without input (action) synchronisation ---
        self.neuron_select_type_action = None
        self.synch_representation_size_action = None

        # --- Start dynamics with a learned activated state trace ---
        # Uniform initialisation in [-b, b] with b = sqrt(1 / (d_model + memory_length)).
        self.register_parameter(
            'start_activated_trace',
            nn.Parameter(
                torch.zeros((d_model, memory_length)).uniform_(
                    -math.sqrt(1 / (d_model + memory_length)),
                    math.sqrt(1 / (d_model + memory_length))
                ),
                requires_grad=True
            )
        )
        self.start_activated_state = None

        # Upper-triangular (including the diagonal) mask used to select unique neuron pairs for output synchronisation.
        self.register_buffer('diagonal_mask_out', torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool)))

        self.attention = None  # Should already be None because super(..., heads=0, ...)
        self.q_proj = None  # Should already be None because super(..., heads=0, ...)
        self.kv_proj = None  # Should already be None because super(..., heads=0, ...)
        self.output_projector = None

    # --- Core CTM Methods ---

    def compute_synchronisation(self, activated_state_trace):
        """Compute the synchronisation between neurons."""
        assert self.neuron_select_type == "first-last", "Only 'first-last' neuron selection is supported here"
        # For RL tasks we track a sliding window of activations from which we compute synchronisation.
        S = activated_state_trace.permute(0, 2, 1)  # (B, memory_length, d_model)
        diagonal_mask = self.diagonal_mask_out.to(S.device)
        # Learned per-pair decay weights over the sliding window.
        decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4))
        # Outer products of the last n_synch_out neurons, keeping only the upper-triangular pairs.
        outer = (S[:, :, -self.n_synch_out:].unsqueeze(-1) * S[:, :, -self.n_synch_out:].unsqueeze(-2))[:, :, diagonal_mask]
        synchronisation = (decay.unsqueeze(0) * outer).sum(1) / torch.sqrt(decay.unsqueeze(0).sum(1))
        return synchronisation
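
    # Worked size example (illustrative; not from the original source): with n_synch_out = 4,
    # the upper-triangular mask keeps n_synch_out * (n_synch_out + 1) / 2 = 10 unique neuron
    # pairs, so compute_synchronisation returns a (batch, 10) tensor and decay_params_out
    # holds 10 learned decay parameters.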

    # --- Setup Methods ---

    def set_initial_rgb(self):
        """Set the initial RGB values for the backbone."""
        return None

    def get_d_backbone(self):
        """Get the dimensionality of the backbone output."""
        return self.d_input

    def set_backbone(self):
        """Set the backbone module based on the specified type."""
        if self.backbone_type == 'navigation-backbone':
            self.backbone = MiniGridBackbone(self.d_input)
        elif self.backbone_type == 'classic-control-backbone':
            self.backbone = ClassicControlBackbone(self.d_input)
        else:
            raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')

    def get_positional_embedding(self, d_backbone):
        """Get the positional embedding module."""
        return None

    def get_synapses(self, synapse_depth, d_model, dropout):
        """
        Get the synapse module.

        We found in our early experimentation that a single Linear, GLU and LayerNorm block performed worse than two blocks.
        For that reason we set the default synapse depth to two blocks.

        TODO: This is legacy and needs further experimentation to iron out.
        """
        if synapse_depth == 1:
            return nn.Sequential(
                nn.Dropout(dropout),
                nn.LazyLinear(d_model * 2),
                nn.GLU(),
                nn.LayerNorm(d_model),
                nn.LazyLinear(d_model * 2),
                nn.GLU(),
                nn.LayerNorm(d_model)
            )
        else:
            return SynapseUNET(d_model, synapse_depth, 16, dropout)
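
    # Illustrative note (not from the original source): synapse_depth == 1 yields the two-block
    # Linear/GLU/LayerNorm MLP above, while synapse_depth > 1 delegates to
    # SynapseUNET(d_model, synapse_depth, 16, dropout); the 16 is assumed (from the call
    # signature alone) to be the UNET's minimum width.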

    def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
        """Set the parameters for the synchronisation of neurons."""
        if synch_type == 'action':
            pass  # No action synchronisation is used in the RL variant.
        elif synch_type == 'out':
            left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self)
            self.register_buffer('out_neuron_indices_left', left)
            self.register_buffer('out_neuron_indices_right', right)
            self.register_parameter('decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True))
        else:
            raise ValueError(f"Invalid synch_type: {synch_type}")

    # --- Utility Methods ---

    def verify_args(self):
        """Verify the validity of the input arguments."""
        assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
            f"Invalid neuron selection type: {self.neuron_select_type}"
        assert self.neuron_select_type != 'random-pairing', \
            "Random pairing is not supported for RL."
        assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
            f"Invalid backbone_type: {self.backbone_type}"
        assert self.d_model >= self.n_synch_out, \
            "d_model must be >= n_synch_out for neuron subsets"

    def forward(self, x, hidden_states, track=False):
        """
        Run the recurrent CTM dynamics for self.iterations internal ticks.

        Args:
            x: environment observation, featurised by the backbone.
            hidden_states: tuple of (state_trace, activated_state_trace), each of shape (B, d_model, memory_length).
            track: if True, also return pre/post activation histories for analysis.

        Returns:
            synchronisation_out, updated hidden_states (and tracking arrays if track=True).
        """
        # --- Tracking Initialization ---
        pre_activations_tracking = []
        post_activations_tracking = []

        # --- Featurise Input Data ---
        features = self.backbone(x)

        # --- Get Recurrent State ---
        state_trace, activated_state_trace = hidden_states

        # --- Recurrent Loop ---
        for stepi in range(self.iterations):

            # Concatenate the flattened input features with the most recent activated state.
            pre_synapse_input = torch.concatenate((features.reshape(x.size(0), -1), activated_state_trace[:, :, -1]), -1)

            # --- Apply Synapses ---
            state = self.synapses(pre_synapse_input)
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # --- Apply NLMs ---
            activated_state = self.trace_processor(state_trace)
            activated_state_trace = torch.concatenate((activated_state_trace[:, :, 1:], activated_state.unsqueeze(-1)), -1)

            # --- Tracking ---
            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())

        hidden_states = (
            state_trace,
            activated_state_trace,
        )

        # --- Calculate Output Synchronisation ---
        synchronisation_out = self.compute_synchronisation(activated_state_trace)

        # --- Return Values ---
        if track:
            return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking)
        return synchronisation_out, hidden_states
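

# --------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# Assumptions: ClassicControlBackbone accepts flat observation vectors (the
# obs_dim=4 here is an arbitrary CartPole-like choice), all hyperparameters
# below are placeholders, and the caller is responsible for constructing the
# initial hidden state, e.g. from the learned start_activated_trace as shown.
# --------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, obs_dim = 2, 4
    model = ContinuousThoughtMachineRL(
        iterations=5,
        d_model=128,
        d_input=32,
        n_synch_out=16,
        synapse_depth=1,
        memory_length=10,
        deep_nlms=True,
        memory_hidden_dims=8,
        do_layernorm_nlm=False,
        backbone_type='classic-control-backbone',
    )

    # Initial recurrent state: zeros for the pre-activation trace and the learned
    # start_activated_trace (expanded over the batch) for the activated trace.
    state_trace = torch.zeros(batch_size, 128, 10)  # (B, d_model, memory_length)
    activated_state_trace = model.start_activated_trace.unsqueeze(0).expand(batch_size, -1, -1).contiguous()

    obs = torch.randn(batch_size, obs_dim)
    synchronisation_out, hidden_states = model(obs, (state_trace, activated_state_trace))
    print(synchronisation_out.shape)  # expected: (batch_size, n_synch_out * (n_synch_out + 1) // 2)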