import torch
import numpy as np
from models.ctm import ContinuousThoughtMachine
from models.modules import MNISTBackbone, QAMNISTIndexEmbeddings, QAMNISTOperatorEmbeddings
class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
def __init__(self,
iterations,
d_model,
d_input,
heads,
n_synch_out,
n_synch_action,
synapse_depth,
memory_length,
deep_nlms,
memory_hidden_dims,
do_layernorm_nlm,
out_dims,
iterations_per_digit,
iterations_per_question_part,
iterations_for_answering,
prediction_reshaper=[-1],
dropout=0,
neuron_select_type='first-last',
n_random_pairing_self=256
):
super().__init__(
iterations=iterations,
d_model=d_model,
d_input=d_input,
heads=heads,
n_synch_out=n_synch_out,
n_synch_action=n_synch_action,
synapse_depth=synapse_depth,
memory_length=memory_length,
deep_nlms=deep_nlms,
memory_hidden_dims=memory_hidden_dims,
do_layernorm_nlm=do_layernorm_nlm,
out_dims=out_dims,
prediction_reshaper=prediction_reshaper,
dropout=dropout,
neuron_select_type=neuron_select_type,
n_random_pairing_self=n_random_pairing_self,
backbone_type='none',
positional_embedding_type='none',
)
# --- Core Parameters ---
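        # The recurrent loop runs in three phases: observing the digits, reading the
        # question (alternating index/operator parts), and answering. The counts below
        # set how many internal ticks are spent per digit, per question part, and on
        # producing the answer.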
self.iterations_per_digit = iterations_per_digit
self.iterations_per_question_part = iterations_per_question_part
self.iterations_for_answering = iterations_for_answering
# --- Setup Methods ---
def set_initial_rgb(self):
"""Set the initial RGB values for the backbone."""
return None
def get_d_backbone(self):
"""Get the dimensionality of the backbone output."""
return self.d_input
def set_backbone(self):
"""Set the backbone module based on the specified type."""
self.backbone_digit = MNISTBackbone(self.d_input)
self.index_backbone = QAMNISTIndexEmbeddings(50, self.d_input)
self.operator_backbone = QAMNISTOperatorEmbeddings(2, self.d_input)
    # --- Utility Methods ---
def determine_step_type(self, total_iterations_for_digits, total_iterations_for_question, stepi: int):
"""Determine whether the current step is for digits, questions, or answers."""
is_digit_step = stepi < total_iterations_for_digits
is_question_step = total_iterations_for_digits <= stepi < total_iterations_for_digits + total_iterations_for_question
is_answer_step = stepi >= total_iterations_for_digits + total_iterations_for_question
return is_digit_step, is_question_step, is_answer_step
def determine_index_operator_step_type(self, total_iterations_for_digits, stepi: int):
"""Determine whether the current step is for index or operator."""
step_within_questions = stepi - total_iterations_for_digits
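        # The question phase alternates index and operator sub-phases, each lasting
        # iterations_per_question_part steps.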
if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
is_index_step = True
is_operator_step = False
else:
is_index_step = False
is_operator_step = True
return is_index_step, is_operator_step
def get_kv_for_step(self, total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input=None, prev_kv=None):
"""Get the key-value for the current step."""
is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
if is_digit_step:
current_input = x[:, stepi]
if prev_input is not None and torch.equal(current_input, prev_input):
return prev_kv, prev_input
kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
elif is_question_step:
offset = stepi - total_iterations_for_digits
current_input = z[:, offset]
if prev_input is not None and torch.equal(current_input, prev_input):
return prev_kv, prev_input
is_index_step, is_operator_step = self.determine_index_operator_step_type(total_iterations_for_digits, stepi)
if is_index_step:
kv = self.index_backbone(current_input)
elif is_operator_step:
kv = self.operator_backbone(current_input)
else:
raise ValueError("Invalid step type for question processing.")
elif is_answer_step:
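            # No external input during the answer phase: a zero placeholder stands in
            # for data and is fed, together with the activated state, into the synapses.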
current_input = None
kv = torch.zeros((x.size(0), self.d_input), device=x.device)
else:
raise ValueError("Invalid step type.")
return kv, current_input
def forward(self, x, z, track=False):
B = x.size(0)
device = x.device
# --- Tracking Initialization ---
pre_activations_tracking = []
post_activations_tracking = []
attention_tracking = []
embedding_tracking = []
total_iterations_for_digits = x.size(1)
total_iterations_for_question = z.size(1)
total_iterations = total_iterations_for_digits + total_iterations_for_question + self.iterations_for_answering
# --- Initialise Recurrent State ---
state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T)
activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H)
# --- Storage for outputs per iteration ---
predictions = torch.empty(B, self.out_dims, total_iterations, device=device, dtype=x.dtype)
certainties = torch.empty(B, 2, total_iterations, device=device, dtype=x.dtype)
# --- Initialise Recurrent Synch Values ---
decay_alpha_action, decay_beta_action = None, None
self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15) # Fix from github user: kuviki
self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
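        # Clamping to [0, 15] keeps the learned decay rates non-negative and bounded,
        # so the per-neuron factors r = exp(-decay) below stay within (exp(-15), 1].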
r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
_, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')
prev_input = None
prev_kv = None
# --- Recurrent Loop ---
for stepi in range(total_iterations):
is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)
kv, prev_input = self.get_kv_for_step(total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input, prev_kv)
prev_kv = kv
synchronization_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')
# --- Interact with Data via Attention ---
attn_weights = None
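            # Attention over the digit features is only used while observing digits;
            # question and answer steps feed their (embedded or zero) kv straight into
            # the synapse input.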
if is_digit_step:
q = self.q_proj(synchronization_action).unsqueeze(1)
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
attn_out = attn_out.squeeze(1)
pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
else:
kv = kv.squeeze(1)
pre_synapse_input = torch.concatenate((kv, activated_state), dim=-1)
# --- Apply Synapses ---
state = self.synapses(pre_synapse_input)
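            # Slide the memory window: drop the oldest pre-activation and append the new one.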
state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)
# --- Apply NLMs ---
activated_state = self.trace_processor(state_trace)
# --- Calculate Synchronisation for Output Predictions ---
synchronization_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')
# --- Get Predictions and Certainties ---
current_prediction = self.output_projector(synchronization_out)
current_certainty = self.compute_certainty(current_prediction)
predictions[..., stepi] = current_prediction
certainties[..., stepi] = current_certainty
# --- Tracking ---
if track:
pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy())
post_activations_tracking.append(activated_state.detach().cpu().numpy())
if attn_weights is not None:
attention_tracking.append(attn_weights.detach().cpu().numpy())
if is_question_step:
embedding_tracking.append(kv.detach().cpu().numpy())
# --- Return Values ---
if track:
return predictions, certainties, synchronization_out, np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
return predictions, certainties, synchronization_out
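

if __name__ == "__main__":
    # Minimal usage sketch (not part of the training pipeline). All hyperparameters
    # and tensor shapes below are illustrative assumptions for a quick smoke test;
    # the QAMNIST data loader defines the real input format.
    iterations_per_digit = 3
    iterations_per_question_part = 2
    iterations_for_answering = 3

    model = ContinuousThoughtMachineQAMNIST(
        iterations=15,  # passed through to the base class; forward() derives the count from x and z
        d_model=128,
        d_input=64,
        heads=4,
        n_synch_out=32,
        n_synch_action=32,
        synapse_depth=1,
        memory_length=10,
        deep_nlms=True,
        memory_hidden_dims=16,
        do_layernorm_nlm=False,
        out_dims=10,  # assumed answer space
        iterations_per_digit=iterations_per_digit,
        iterations_per_question_part=iterations_per_question_part,
        iterations_for_answering=iterations_for_answering,
    )

    B = 2
    # Two digit frames, each repeated for its iteration budget -> (B, 6, 1, 28, 28).
    digits = torch.randn(B, 2, 1, 28, 28)
    x = digits.repeat_interleave(iterations_per_digit, dim=1)
    # A question with three parts (index, operator, index); each token is repeated
    # iterations_per_question_part times. Index tokens must be < 50 and operator
    # tokens < 2, matching the embedding table sizes in set_backbone().
    question = torch.tensor([[0, 1, 1]]).repeat(B, 1)
    z = question.repeat_interleave(iterations_per_question_part, dim=1)

    predictions, certainties, synchronization_out = model(x, z)
    print(predictions.shape, certainties.shape, synchronization_out.shape)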