import torch
import numpy as np
from models.ctm import ContinuousThoughtMachine
from models.modules import MNISTBackbone, QAMNISTIndexEmbeddings, QAMNISTOperatorEmbeddings

class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 heads,
                 n_synch_out,
                 n_synch_action,
                 synapse_depth,
                 memory_length,
                 deep_nlms,
                 memory_hidden_dims,
                 do_layernorm_nlm,
                 out_dims,
                 iterations_per_digit,
                 iterations_per_question_part,
                 iterations_for_answering,
                 prediction_reshaper=[-1],
                 dropout=0,
                 neuron_select_type='first-last',
                 n_random_pairing_self=256
                 ):
        super().__init__(
            iterations=iterations,
            d_model=d_model,
            d_input=d_input,
            heads=heads,
            n_synch_out=n_synch_out,
            n_synch_action=n_synch_action,
            synapse_depth=synapse_depth,
            memory_length=memory_length,
            deep_nlms=deep_nlms,
            memory_hidden_dims=memory_hidden_dims,
            do_layernorm_nlm=do_layernorm_nlm,
            out_dims=out_dims,
            prediction_reshaper=prediction_reshaper,
            dropout=dropout,
            neuron_select_type=neuron_select_type,
            n_random_pairing_self=n_random_pairing_self,
            backbone_type='none',
            positional_embedding_type='none',
        )

        # --- Core Parameters ---
        self.iterations_per_digit = iterations_per_digit
        self.iterations_per_question_part = iterations_per_question_part
        self.iterations_for_answering = iterations_for_answering
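        # The forward pass runs in three phases: one block of iterations per
        # shown digit, alternating index/operator blocks while the question is
        # presented, and finally iterations_for_answering steps with a zeroed
        # input during which the model settles on its answer.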

    # --- Setup Methods ---

    def set_initial_rgb(self):
        """No initial RGB adapter is needed; QAMNIST defines its own backbones."""
        return None

    def get_d_backbone(self):
        """Get the dimensionality of the backbone output."""
        return self.d_input

    def set_backbone(self):
        """Set the backbone modules for digits, indices, and operators."""
        self.backbone_digit = MNISTBackbone(self.d_input)
        self.index_backbone = QAMNISTIndexEmbeddings(50, self.d_input)
        self.operator_backbone = QAMNISTOperatorEmbeddings(2, self.d_input)

    # --- Utility Methods ---

    def determine_step_type(self, total_iterations_for_digits, total_iterations_for_question, stepi: int):
        """Determine whether the current step is for digits, questions, or answers."""
        is_digit_step = stepi < total_iterations_for_digits
        is_question_step = total_iterations_for_digits <= stepi < total_iterations_for_digits + total_iterations_for_question
        is_answer_step = stepi >= total_iterations_for_digits + total_iterations_for_question
        return is_digit_step, is_question_step, is_answer_step

    def determine_index_operator_step_type(self, total_iterations_for_digits, stepi: int):
        """Determine whether the current question step is for an index or an operator."""
        step_within_questions = stepi - total_iterations_for_digits
        is_index_step = step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part
        is_operator_step = not is_index_step
        return is_index_step, is_operator_step
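    # Schedule example: with iterations_per_question_part = 3, the first three
    # question steps are index steps, the next three are operator steps, then
    # three index steps again, and so on; each step reads one token from z.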

    def get_kv_for_step(self, total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input=None, prev_kv=None):
        """Get the key-value features for the current step, reusing the cached
        projection when the input is unchanged from the previous step."""
        is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)

        if is_digit_step:
            current_input = x[:, stepi]
            if prev_input is not None and torch.equal(current_input, prev_input):
                return prev_kv, prev_input
            kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
        elif is_question_step:
            offset = stepi - total_iterations_for_digits
            current_input = z[:, offset]
            if prev_input is not None and torch.equal(current_input, prev_input):
                return prev_kv, prev_input
            is_index_step, is_operator_step = self.determine_index_operator_step_type(total_iterations_for_digits, stepi)
            if is_index_step:
                kv = self.index_backbone(current_input)
            elif is_operator_step:
                kv = self.operator_backbone(current_input)
            else:
                raise ValueError("Invalid step type for question processing.")
        elif is_answer_step:
            current_input = None
            kv = torch.zeros((x.size(0), self.d_input), device=x.device)
        else:
            raise ValueError("Invalid step type.")
        return kv, current_input

    def forward(self, x, z, track=False):
        B = x.size(0)
        device = x.device

        # --- Tracking Initialization ---
        pre_activations_tracking = []
        post_activations_tracking = []
        attention_tracking = []
        embedding_tracking = []

        total_iterations_for_digits = x.size(1)
        total_iterations_for_question = z.size(1)
        total_iterations = total_iterations_for_digits + total_iterations_for_question + self.iterations_for_answering
        # --- Initialise Recurrent State ---
        state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1)  # Shape: (B, H, T)
        activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1)  # Shape: (B, H)

        # --- Storage for outputs per iteration ---
        predictions = torch.empty(B, self.out_dims, total_iterations, device=device, dtype=x.dtype)
        certainties = torch.empty(B, 2, total_iterations, device=device, dtype=x.dtype)

        # --- Initialise Recurrent Synch Values ---
        decay_alpha_action, decay_beta_action = None, None
        self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15)  # Fix from GitHub user kuviki
        self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
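        # Clamping keeps the learned decay rates non-negative, so the
        # exponential factors r = exp(-decay) below lie in (0, 1] and the
        # recursive synchronisation accumulators remain bounded.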
        r_action = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1)
        r_out = torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
        _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')

        prev_input = None
        prev_kv = None

        # --- Recurrent Loop ---
        for stepi in range(total_iterations):
            is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)

            kv, prev_input = self.get_kv_for_step(total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input, prev_kv)
            prev_kv = kv

            synchronization_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')

            # --- Interact with Data via Attention ---
            attn_weights = None
            if is_digit_step:
                q = self.q_proj(synchronization_action).unsqueeze(1)
                attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
                attn_out = attn_out.squeeze(1)
                pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
            else:
                # Question and answer steps bypass attention: the single
                # embedding (or zero vector) is fed to the synapses directly.
                kv = kv.squeeze(1)
                pre_synapse_input = torch.concatenate((kv, activated_state), dim=-1)

            # --- Apply Synapses ---
            state = self.synapses(pre_synapse_input)
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # --- Apply NLMs ---
            activated_state = self.trace_processor(state_trace)

            # --- Calculate Synchronisation for Output Predictions ---
            synchronization_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')

            # --- Get Predictions and Certainties ---
            current_prediction = self.output_projector(synchronization_out)
            current_certainty = self.compute_certainty(current_prediction)

            predictions[..., stepi] = current_prediction
            certainties[..., stepi] = current_certainty

            # --- Tracking ---
            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())
                if attn_weights is not None:
                    attention_tracking.append(attn_weights.detach().cpu().numpy())
                if is_question_step:
                    embedding_tracking.append(kv.detach().cpu().numpy())

        # --- Return Values ---
        if track:
            return predictions, certainties, synchronization_out, np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
        return predictions, certainties, synchronization_out
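

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original training setup: every
    # hyperparameter value and input shape below is an assumption chosen to
    # exercise the three phases; see the repository's QAMNIST task code for
    # the real data pipeline.
    model = ContinuousThoughtMachineQAMNIST(
        iterations=1,  # assumed unused here: forward() derives the schedule from x and z
        d_model=128,
        d_input=64,
        heads=4,
        n_synch_out=32,
        n_synch_action=32,
        synapse_depth=1,
        memory_length=5,
        deep_nlms=True,
        memory_hidden_dims=4,
        do_layernorm_nlm=False,
        out_dims=10,
        iterations_per_digit=2,
        iterations_per_question_part=2,
        iterations_for_answering=5,
    )

    B = 2
    # Assumed layout: x holds one 1x28x28 MNIST frame per digit iteration
    # (here 2 digits shown for 2 iterations each) and z one integer token per
    # question iteration (2 index steps followed by 2 operator steps).
    x = torch.randn(B, 4, 1, 28, 28)
    z = torch.randint(0, 2, (B, 4))

    predictions, certainties, synchronization_out = model(x, z)
    print(predictions.shape)  # (B, out_dims, total_iterations) = (2, 10, 13)
    print(certainties.shape)  # (B, 2, total_iterations) = (2, 2, 13)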