import torch
import torch.nn as nn
import numpy as np
import math

from models.ctm import ContinuousThoughtMachine
from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET
from models.utils import compute_decay
from models.constants import VALID_NEURON_SELECT_TYPES


class ContinuousThoughtMachineRL(ContinuousThoughtMachine):
    """Continuous Thought Machine variant for reinforcement learning.

    Cross-attention over input features and action synchronisation are removed.
    The backbone features are concatenated with the most recent post-activation
    state and fed directly to the synapses, and the recurrent traces are carried
    between environment steps as an external hidden state rather than being
    re-initialised on every forward pass.
    """

    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 n_synch_out,
                 synapse_depth,
                 memory_length,
                 deep_nlms,
                 memory_hidden_dims,
                 do_layernorm_nlm,
                 backbone_type,
                 prediction_reshaper=[-1],
                 dropout=0,
                 neuron_select_type='first-last',
                 ):
        super().__init__(
            iterations=iterations,
            d_model=d_model,
            d_input=d_input,
            heads=0,
            n_synch_out=n_synch_out,
            n_synch_action=0,
            synapse_depth=synapse_depth,
            memory_length=memory_length,
            deep_nlms=deep_nlms,
            memory_hidden_dims=memory_hidden_dims,
            do_layernorm_nlm=do_layernorm_nlm,
            out_dims=0,
            prediction_reshaper=prediction_reshaper,
            dropout=dropout,
            neuron_select_type=neuron_select_type,
            backbone_type=backbone_type,
            n_random_pairing_self=0,
            positional_embedding_type='none',
        )

        # No action head in the RL variant, so action-synchronisation bookkeeping is disabled.
        self.neuron_select_type_action = None
        self.synch_representation_size_action = None

        # Learnable initialisation for the post-activation trace, shape (d_model, memory_length).
        self.register_parameter(
            'start_activated_trace',
            nn.Parameter(
                torch.zeros((d_model, memory_length)).uniform_(
                    -math.sqrt(1 / (d_model + memory_length)),
                    math.sqrt(1 / (d_model + memory_length)),
                ),
                requires_grad=True,
            ),
        )
        # The parent's single start state is not used; the full trace above replaces it.
        self.start_activated_state = None

        # Upper-triangular mask selecting the (i, j) neuron pairs used for output synchronisation.
        self.register_buffer('diagonal_mask_out', torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool)))

        # No attention over input features: the backbone output is fed straight to the synapses.
        self.attention = None
        self.q_proj = None
        self.kv_proj = None
        self.output_projector = None

    def compute_synchronisation(self, activated_state_trace):
        """Compute the synchronisation between neurons."""
        assert self.neuron_select_type == "first-last", "Only first-last neuron selection is supported here."

        # S: (B, memory_length, d_model); only the last n_synch_out neurons enter the output synchronisation.
        S = activated_state_trace.permute(0, 2, 1)
        diagonal_mask = self.diagonal_mask_out.to(S.device)
        decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4))
        synchronisation = ((decay.unsqueeze(0) * (S[:, :, -self.n_synch_out:].unsqueeze(-1) * S[:, :, -self.n_synch_out:].unsqueeze(-2))[:, :, diagonal_mask]).sum(1)) / torch.sqrt(decay.unsqueeze(0).sum(1))
        return synchronisation
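
    # Shape sketch of the one-liner above, for readers unpacking it (illustrative
    # comments only, never executed). The (T, P) shape assumed for `decay` is
    # inferred from the broadcasting in this method, not from compute_decay itself.
    # With B = batch, T = memory_length, K = n_synch_out and P = K * (K + 1) // 2
    # upper-triangular pairs picked out by `diagonal_mask`:
    #
    #   outer = S[:, :, -K:].unsqueeze(-1) * S[:, :, -K:].unsqueeze(-2)   # (B, T, K, K)
    #   pairs = outer[:, :, diagonal_mask]                                # (B, T, P)
    #   sync  = (decay[None] * pairs).sum(1) / decay.sum(0).sqrt()        # (B, P)
    #
    # i.e. each output entry is a decay-weighted inner product of two neuron traces,
    # normalised by the square root of the total decay weight for that pair.
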
    def set_initial_rgb(self):
        """No initial RGB block is needed for the RL backbones."""
        return None

    def get_d_backbone(self):
        """Get the dimensionality of the backbone output."""
        return self.d_input

    def set_backbone(self):
        """Set the backbone module based on the specified type."""
        if self.backbone_type == 'navigation-backbone':
            self.backbone = MiniGridBackbone(self.d_input)
        elif self.backbone_type == 'classic-control-backbone':
            self.backbone = ClassicControlBackbone(self.d_input)
        else:
            raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')

    def get_positional_embedding(self, d_backbone):
        """No positional embedding is used for the RL backbones."""
        return None

    def get_synapses(self, synapse_depth, d_model, dropout):
        """
        Get the synapse module.

        We found in our early experimentation that a single Linear, GLU and LayerNorm block performed worse than two blocks.
        For that reason we set the default synapse depth to two blocks.

        TODO: This is legacy and needs further experimentation to iron out.
        """
        if synapse_depth == 1:
            return nn.Sequential(
                nn.Dropout(dropout),
                nn.LazyLinear(d_model * 2),
                nn.GLU(),
                nn.LayerNorm(d_model),
                nn.LazyLinear(d_model * 2),
                nn.GLU(),
                nn.LayerNorm(d_model),
            )
        else:
            return SynapseUNET(d_model, synapse_depth, 16, dropout)

    def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
        """Set the parameters for the synchronisation of neurons."""
        if synch_type == 'action':
            # Action synchronisation is not used in the RL variant, so there is nothing to set up.
            pass
        elif synch_type == 'out':
            left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self)
            self.register_buffer('out_neuron_indices_left', left)
            self.register_buffer('out_neuron_indices_right', right)
            self.register_parameter('decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True))
        else:
            raise ValueError(f"Invalid synch_type: {synch_type}")

    def verify_args(self):
        """Verify the validity of the input arguments."""
        assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
            f"Invalid neuron selection type: {self.neuron_select_type}"
        assert self.neuron_select_type != 'random-pairing', \
            "Random pairing is not supported for RL."
        assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
            f"Invalid backbone_type: {self.backbone_type}"
        assert self.d_model >= self.n_synch_out, \
            "d_model must be >= n_synch_out for neuron subsets"

    def forward(self, x, hidden_states, track=False):
        """Run the internal recurrence for one environment step.

        Args:
            x: observation batch fed to the backbone.
            hidden_states: tuple of (state_trace, activated_state_trace), each of
                shape (B, d_model, memory_length), carried over from the previous step.
            track: if True, also return per-iteration pre/post activations for analysis.

        Returns:
            The output synchronisation representation and the updated hidden states
            (plus tracking arrays when track is True).
        """
        pre_activations_tracking = []
        post_activations_tracking = []

        # Featurise the observation.
        features = self.backbone(x)

        state_trace, activated_state_trace = hidden_states

        for stepi in range(self.iterations):

            # Concatenate the flattened features with the most recent post-activation state.
            pre_synapse_input = torch.concatenate((features.reshape(x.size(0), -1), activated_state_trace[:, :, -1]), -1)

            # Apply the synapses and push the new pre-activation onto the rolling trace.
            state = self.synapses(pre_synapse_input)
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # Apply the neuron-level models and update the post-activation trace.
            activated_state = self.trace_processor(state_trace)
            activated_state_trace = torch.concatenate((activated_state_trace[:, :, 1:], activated_state.unsqueeze(-1)), -1)

            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())

        hidden_states = (
            state_trace,
            activated_state_trace,
        )

        # Output synchronisation over the updated post-activation trace.
        synchronisation_out = self.compute_synchronisation(activated_state_trace)

        if track:
            return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking)
        return synchronisation_out, hidden_states
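

# --- Illustrative usage sketch (not part of the module) ---
# The hidden-state initialisation below is an assumption made for illustration:
# the pre-activation trace is zero-initialised and the post-activation trace is
# tiled from the learned `start_activated_trace`. The RL training harness in this
# repo may build the initial hidden state differently, and the example values for
# d_model, d_input, etc. are arbitrary.
#
#   model = ContinuousThoughtMachineRL(
#       iterations=2, d_model=128, d_input=64, n_synch_out=16, synapse_depth=1,
#       memory_length=10, deep_nlms=True, memory_hidden_dims=4,
#       do_layernorm_nlm=False, backbone_type='classic-control-backbone',
#   )
#   obs = torch.randn(8, 4)                                  # (B, obs_dim) batch of observations
#   state_trace = torch.zeros(8, 128, 10)                    # (B, d_model, memory_length)
#   activated_state_trace = model.start_activated_trace.unsqueeze(0).expand(8, -1, -1)
#   sync_out, hidden_states = model(obs, (state_trace, activated_state_trace))
#   # sync_out: (B, n_synch_out * (n_synch_out + 1) // 2); hidden_states feeds the next env step.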