File size: 8,095 Bytes
68b32f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
import torch
import torch.nn as nn
import numpy as np
import math
from models.ctm import ContinuousThoughtMachine
from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET
from models.utils import compute_decay
from models.constants import VALID_NEURON_SELECT_TYPES
class ContinuousThoughtMachineRL(ContinuousThoughtMachine):
    """Continuous Thought Machine variant for reinforcement learning.

    Differences from the base CTM (as visible here):
      - Built with ``heads=0`` so the parent creates no attention / q_proj /
        kv_proj modules, and with ``n_synch_action=0`` so there is no action
        synchronisation — only output synchronisation is computed.
      - Dynamics start from a learned *activated* state trace
        (``start_activated_trace``) instead of a learned start state.
      - ``forward`` threads recurrent hidden state (a sliding window of
        pre- and post-activations) across calls, rather than re-initialising
        each time, which suits step-by-step RL interaction.
    """

    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 n_synch_out,
                 synapse_depth,
                 memory_length,
                 deep_nlms,
                 memory_hidden_dims,
                 do_layernorm_nlm,
                 backbone_type,
                 prediction_reshaper=[-1],  # NOTE(review): mutable default kept for interface compatibility; assumed never mutated downstream — TODO confirm
                 dropout=0,
                 neuron_select_type='first-last',
                 ):
        """Build a minimal CTM for RL.

        Args:
            iterations: Number of internal recurrent steps per forward call.
            d_model: Width of the latent state.
            d_input: Backbone output feature dimensionality.
            n_synch_out: Number of neurons used for output synchronisation.
            synapse_depth: Depth of the synapse module (1 = sequential MLP,
                >1 = SynapseUNET).
            memory_length: Length of the sliding activation window.
            deep_nlms: Passed through to the parent NLM construction.
            memory_hidden_dims: Hidden width of the NLMs.
            do_layernorm_nlm: Whether NLMs apply layer normalisation.
            backbone_type: One of 'navigation-backbone' or
                'classic-control-backbone'.
            prediction_reshaper: Passed through to the parent.
            dropout: Dropout probability for the synapse module.
            neuron_select_type: Only 'first-last' is supported for RL.
        """
        super().__init__(
            iterations=iterations,
            d_model=d_model,
            d_input=d_input,
            heads=0,  # heads=0 makes the parent skip attention entirely
            n_synch_out=n_synch_out,
            n_synch_action=0,
            synapse_depth=synapse_depth,
            memory_length=memory_length,
            deep_nlms=deep_nlms,
            memory_hidden_dims=memory_hidden_dims,
            do_layernorm_nlm=do_layernorm_nlm,
            out_dims=0,
            prediction_reshaper=prediction_reshaper,
            dropout=dropout,
            neuron_select_type=neuron_select_type,
            backbone_type=backbone_type,
            n_random_pairing_self=0,
            positional_embedding_type='none',
        )

        # --- Use a minimal CTM without input (action) synchronisation ---
        self.neuron_select_type_action = None
        self.synch_representation_size_action = None

        # --- Start dynamics with a learned activated state trace ---
        # Uniform init in +/- sqrt(1/(d_model+memory_length)) (Xavier-style bound).
        bound = math.sqrt(1 / (d_model + memory_length))
        self.register_parameter(
            'start_activated_trace',
            nn.Parameter(
                torch.zeros((d_model, memory_length)).uniform_(-bound, bound),
                requires_grad=True,
            ),
        )
        self.start_activated_state = None

        # Boolean upper-triangular mask (incl. diagonal) selecting the unique
        # neuron pairs used by compute_synchronisation.
        self.register_buffer(
            'diagonal_mask_out',
            torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool)),
        )

        # These should already be None because super() was called with heads=0;
        # set explicitly for clarity/safety.
        self.attention = None
        self.q_proj = None
        self.kv_proj = None
        self.output_projector = None

    # --- Core CTM Methods ---

    def compute_synchronisation(self, activated_state_trace):
        """Compute decay-weighted synchronisation between neuron pairs.

        For RL tasks we track a sliding window of activations from which we
        compute synchronisation over the last ``n_synch_out`` neurons.

        Args:
            activated_state_trace: Tensor indexed as (batch, neuron, time)
                — assumed (B, d_model, memory_length); TODO confirm.

        Returns:
            Tensor of per-pair synchronisation values, one row per batch
            element (size ``synch_representation_size_out`` per row).
        """
        assert self.neuron_select_type == "first-last", "only first-last neuron selection is supported here"
        S = activated_state_trace.permute(0, 2, 1)  # -> (batch, time, neuron)
        diagonal_mask = self.diagonal_mask_out.to(S.device)
        # Per-timestep decay weights; clamped to keep the exponent stable.
        decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4))
        tail = S[:, :, -self.n_synch_out:]
        # Outer product per timestep, masked down to unique (i <= j) pairs.
        pairs = (tail.unsqueeze(-1) * tail.unsqueeze(-2))[:, :, diagonal_mask]
        synchronisation = (decay.unsqueeze(0) * pairs).sum(1) / torch.sqrt(decay.unsqueeze(0).sum(1))
        return synchronisation

    # --- Setup Methods ---

    def set_initial_rgb(self):
        """No initial RGB preprocessing is needed for the RL backbones."""
        return None

    def get_d_backbone(self):
        """Get the dimensionality of the backbone output."""
        return self.d_input

    def set_backbone(self):
        """Set the backbone module based on the specified type.

        Raises:
            NotImplementedError: For any unsupported backbone type.
        """
        if self.backbone_type == 'navigation-backbone':
            self.backbone = MiniGridBackbone(self.d_input)
        elif self.backbone_type == 'classic-control-backbone':
            self.backbone = ClassicControlBackbone(self.d_input)
        else:
            # BUGFIX: was `raise NotImplemented(...)` — NotImplemented is a
            # constant, not an exception; calling it raised TypeError.
            raise NotImplementedError('The only backbone supported for RL are for navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')

    def get_positional_embedding(self, d_backbone):
        """No positional embedding is used for RL."""
        return None

    def get_synapses(self, synapse_depth, d_model, dropout):
        """
        Get the synapse module.

        We found in our early experimentation that a single Linear, GLU and
        LayerNorm block performed worse than two blocks. For that reason we
        set the default synapse depth to two blocks.

        TODO: This is legacy and needs further experimentation to iron out.
        """
        if synapse_depth == 1:
            # Despite depth==1, this intentionally stacks two GLU blocks (see above).
            return nn.Sequential(
                nn.Dropout(dropout),
                nn.LazyLinear(d_model * 2),
                nn.GLU(),
                nn.LayerNorm(d_model),
                nn.LazyLinear(d_model * 2),
                nn.GLU(),
                nn.LayerNorm(d_model),
            )
        return SynapseUNET(d_model, synapse_depth, 16, dropout)

    def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0):
        """Set the parameters for the synchronisation of neurons.

        'action' synchronisation is deliberately a no-op for the RL variant;
        only 'out' synchronisation parameters are created.

        Raises:
            ValueError: If ``synch_type`` is neither 'action' nor 'out'.
        """
        if synch_type == 'action':
            pass  # No action synchronisation in the RL variant.
        elif synch_type == 'out':
            left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self)
            self.register_buffer('out_neuron_indices_left', left)
            self.register_buffer('out_neuron_indices_right', right)
            self.register_parameter('decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True))
        else:
            raise ValueError(f"Invalid synch_type: {synch_type}")

    # --- Utility Methods ---

    def verify_args(self):
        """Verify the validity of the input arguments."""
        assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \
            f"Invalid neuron selection type: {self.neuron_select_type}"
        assert self.neuron_select_type != 'random-pairing', \
            "Random pairing is not supported for RL."
        assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
            f"Invalid backbone_type: {self.backbone_type}"
        assert self.d_model >= self.n_synch_out, \
            "d_model must be >= n_synch_out for neuron subsets"

    def forward(self, x, hidden_states, track=False):
        """Run `iterations` recurrent steps and return output synchronisation.

        Args:
            x: Observation batch consumed by the backbone.
            hidden_states: Tuple ``(state_trace, activated_state_trace)`` —
                sliding windows of pre- and post-activations carried across calls.
            track: If True, also return stacked per-step activations as numpy
                arrays (detached, on CPU) for analysis.

        Returns:
            ``(synchronisation_out, hidden_states)`` and, when ``track`` is
            True, additionally ``(pre_activations, post_activations)``.
        """
        # --- Tracking Initialization ---
        pre_activations_tracking = []
        post_activations_tracking = []

        # --- Featurise Input Data ---
        features = self.backbone(x)
        # Hoisted out of the loop: the flattened features are loop-invariant.
        flat_features = features.reshape(x.size(0), -1)

        # --- Get Recurrent State ---
        state_trace, activated_state_trace = hidden_states

        # --- Recurrent Loop ---
        for _ in range(self.iterations):
            pre_synapse_input = torch.cat((flat_features, activated_state_trace[:, :, -1]), -1)

            # --- Apply Synapses ---
            state = self.synapses(pre_synapse_input)
            # Slide the window: drop the oldest step, append the newest.
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # --- Apply NLMs ---
            activated_state = self.trace_processor(state_trace)
            activated_state_trace = torch.cat((activated_state_trace[:, :, 1:], activated_state.unsqueeze(-1)), -1)

            # --- Tracking ---
            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())

        hidden_states = (state_trace, activated_state_trace)

        # --- Calculate Output Synchronisation ---
        synchronisation_out = self.compute_synchronisation(activated_state_trace)

        # --- Return Values ---
        if track:
            return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking)
        return synchronisation_out, hidden_states