Apply the clamp of the decay params to .data so that gradients remain valid going forward. Fix suggested by GitHub user kuviki
451276c
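For context, here is a minimal standalone sketch (not part of the file below; the parameter name is a hypothetical stand-in for decay_params_action / decay_params_out) of the pattern the commit relies on: writing the clamped values into a parameter's .data keeps the clamp out of the autograd graph, so the parameter stays a leaf tensor, its stored values stay inside [0, 15], and later uses such as exp(-p) still receive ordinary gradients.

import torch

# Hypothetical stand-in for a decay parameter of the model.
p = torch.nn.Parameter(torch.tensor([-1.0, 5.0, 20.0]))

# In-place clamp on .data: no autograd edge is recorded for the clamp itself,
# and p remains a leaf parameter that the optimiser keeps updating.
p.data = torch.clamp(p, 0, 15)

r = torch.exp(-p)        # downstream use of the (now clamped) values
r.sum().backward()

print(p)       # values clamped into [0, 15]
print(p.grad)  # -exp(-p): a well-defined gradient for every element

The file as committed follows.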
import torch
import numpy as np

from models.ctm import ContinuousThoughtMachine
from models.modules import MNISTBackbone, QAMNISTIndexEmbeddings, QAMNISTOperatorEmbeddings


class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine):
    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 heads,
                 n_synch_out,
                 n_synch_action,
                 synapse_depth,
                 memory_length,
                 deep_nlms,
                 memory_hidden_dims,
                 do_layernorm_nlm,
                 out_dims,
                 iterations_per_digit,
                 iterations_per_question_part,
                 iterations_for_answering,
                 prediction_reshaper=[-1],
                 dropout=0,
                 neuron_select_type='first-last',
                 n_random_pairing_self=256
                 ):
        super().__init__(
            iterations=iterations,
            d_model=d_model,
            d_input=d_input,
            heads=heads,
            n_synch_out=n_synch_out,
            n_synch_action=n_synch_action,
            synapse_depth=synapse_depth,
            memory_length=memory_length,
            deep_nlms=deep_nlms,
            memory_hidden_dims=memory_hidden_dims,
            do_layernorm_nlm=do_layernorm_nlm,
            out_dims=out_dims,
            prediction_reshaper=prediction_reshaper,
            dropout=dropout,
            neuron_select_type=neuron_select_type,
            n_random_pairing_self=n_random_pairing_self,
            backbone_type='none',
            positional_embedding_type='none',
        )

        # --- Core Parameters ---
        self.iterations_per_digit = iterations_per_digit
        self.iterations_per_question_part = iterations_per_question_part
        self.iterations_for_answering = iterations_for_answering

    # --- Setup Methods ---

    def set_initial_rgb(self):
        """Set the initial RGB values for the backbone."""
        return None

    def get_d_backbone(self):
        """Get the dimensionality of the backbone output."""
        return self.d_input

    def set_backbone(self):
        """Set the backbone module based on the specified type."""
        self.backbone_digit = MNISTBackbone(self.d_input)
        self.index_backbone = QAMNISTIndexEmbeddings(50, self.d_input)
        self.operator_backbone = QAMNISTOperatorEmbeddings(2, self.d_input)
    # --- Utility Methods ---
    def determine_step_type(self, total_iterations_for_digits, total_iterations_for_question, stepi: int):
        """Determine whether the current step is for digits, questions, or answers."""
        is_digit_step = stepi < total_iterations_for_digits
        is_question_step = total_iterations_for_digits <= stepi < total_iterations_for_digits + total_iterations_for_question
        is_answer_step = stepi >= total_iterations_for_digits + total_iterations_for_question
        return is_digit_step, is_question_step, is_answer_step

    def determine_index_operator_step_type(self, total_iterations_for_digits, stepi: int):
        """Determine whether the current step is for index or operator."""
        step_within_questions = stepi - total_iterations_for_digits
        if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
            is_index_step = True
            is_operator_step = False
        else:
            is_index_step = False
            is_operator_step = True
        return is_index_step, is_operator_step

    def get_kv_for_step(self, total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input=None, prev_kv=None):
        """Get the key-value for the current step."""
        is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)

        if is_digit_step:
            current_input = x[:, stepi]
            if prev_input is not None and torch.equal(current_input, prev_input):
                return prev_kv, prev_input
            kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
        elif is_question_step:
            offset = stepi - total_iterations_for_digits
            current_input = z[:, offset]
            if prev_input is not None and torch.equal(current_input, prev_input):
                return prev_kv, prev_input
            is_index_step, is_operator_step = self.determine_index_operator_step_type(total_iterations_for_digits, stepi)
            if is_index_step:
                kv = self.index_backbone(current_input)
            elif is_operator_step:
                kv = self.operator_backbone(current_input)
            else:
                raise ValueError("Invalid step type for question processing.")
        elif is_answer_step:
            current_input = None
            kv = torch.zeros((x.size(0), self.d_input), device=x.device)
        else:
            raise ValueError("Invalid step type.")

        return kv, current_input

    def forward(self, x, z, track=False):
        B = x.size(0)
        device = x.device

        # --- Tracking Initialization ---
        pre_activations_tracking = []
        post_activations_tracking = []
        attention_tracking = []
        embedding_tracking = []

        total_iterations_for_digits = x.size(1)
        total_iterations_for_question = z.size(1)
        total_iterations = total_iterations_for_digits + total_iterations_for_question + self.iterations_for_answering

        # --- Initialise Recurrent State ---
        state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1)  # Shape: (B, H, T)
        activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1)  # Shape: (B, H)

        # --- Storage for outputs per iteration ---
        predictions = torch.empty(B, self.out_dims, total_iterations, device=device, dtype=x.dtype)
        certainties = torch.empty(B, 2, total_iterations, device=device, dtype=x.dtype)

        # --- Initialise Recurrent Synch Values ---
        decay_alpha_action, decay_beta_action = None, None
        self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15)  # Fix from github user: kuviki
        self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15)
        r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1)
        _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out')

        prev_input = None
        prev_kv = None

        # --- Recurrent Loop ---
        for stepi in range(total_iterations):
            is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi)

            kv, prev_input = self.get_kv_for_step(total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input, prev_kv)
            prev_kv = kv

            synchronization_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action')

            # --- Interact with Data via Attention ---
            attn_weights = None
            if is_digit_step:
                q = self.q_proj(synchronization_action).unsqueeze(1)
                attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
                attn_out = attn_out.squeeze(1)
                pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1)
            else:
                kv = kv.squeeze(1)
                pre_synapse_input = torch.concatenate((kv, activated_state), dim=-1)

            # --- Apply Synapses ---
            state = self.synapses(pre_synapse_input)
            state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1)

            # --- Apply NLMs ---
            activated_state = self.trace_processor(state_trace)

            # --- Calculate Synchronisation for Output Predictions ---
            synchronization_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out')

            # --- Get Predictions and Certainties ---
            current_prediction = self.output_projector(synchronization_out)
            current_certainty = self.compute_certainty(current_prediction)

            predictions[..., stepi] = current_prediction
            certainties[..., stepi] = current_certainty

            # --- Tracking ---
            if track:
                pre_activations_tracking.append(state_trace[:, :, -1].detach().cpu().numpy())
                post_activations_tracking.append(activated_state.detach().cpu().numpy())
                if attn_weights is not None:
                    attention_tracking.append(attn_weights.detach().cpu().numpy())
                if is_question_step:
                    embedding_tracking.append(kv.detach().cpu().numpy())

        # --- Return Values ---
        if track:
            return predictions, certainties, synchronization_out, np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
        return predictions, certainties, synchronization_out
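To make the phase layout of forward() easier to follow, here is a small standalone sketch (toy numbers, not part of the repository) that reproduces the step scheduling implemented by determine_step_type and determine_index_operator_step_type: digit steps come from x's second dimension, question steps from z's second dimension in alternating index/operator blocks of iterations_per_question_part steps each, and the final iterations_for_answering steps run with a zero key-value input.

def step_schedule(num_digit_steps, num_question_steps, iterations_per_question_part, iterations_for_answering):
    # Mirrors the phase logic of the forward loop above.
    total = num_digit_steps + num_question_steps + iterations_for_answering
    schedule = []
    for stepi in range(total):
        if stepi < num_digit_steps:
            schedule.append('digit')
        elif stepi < num_digit_steps + num_question_steps:
            within = stepi - num_digit_steps
            if within % (2 * iterations_per_question_part) < iterations_per_question_part:
                schedule.append('index')
            else:
                schedule.append('operator')
        else:
            schedule.append('answer')
    return schedule

# e.g. 3 digit frames shown for 2 steps each, one (index, operator) question pair
# shown for 2 steps each, and 4 answering steps:
print(step_schedule(num_digit_steps=6, num_question_steps=4,
                    iterations_per_question_part=2, iterations_for_answering=4))
# ['digit', 'digit', 'digit', 'digit', 'digit', 'digit',
#  'index', 'index', 'operator', 'operator',
#  'answer', 'answer', 'answer', 'answer']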