import torch.nn as nn
import torch
import torch.nn.functional as F # Used for GLU if not in modules
import numpy as np
import math
# Local imports (Assuming these contain necessary custom modules)
from models.modules import *
from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here
class LSTMBaseline(nn.Module):
"""
LSTM Baseline
Args:
iterations (int): Number of internal 'thought' steps (T, in paper).
d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
d_input (int): Dimensionality of projected attention outputs or direct input features.
heads (int): Number of attention heads.
n_synch_out (int): Number of neurons used for output synchronisation (No, in paper).
n_synch_action (int): Number of neurons used for action/attention synchronisation (Ni, in paper).
synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
memory_length (int): History length for Neuron-Level Models (M, in paper).
deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
memory_hidden_dims (int): Hidden dimension size for deep NLMs.
do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
positional_embedding_type (str): Type of positional embedding for backbone features.
out_dims (int): Dimensionality of the final output projection.
prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
dropout (float): Dropout rate.
"""
def __init__(self,
iterations,
d_model,
d_input,
heads,
out_dims,
iterations_per_digit,
iterations_per_question_part,
iterations_for_answering,
prediction_reshaper=[-1],
dropout=0,
):
super(LSTMBaseline, self).__init__()
# --- Core Parameters ---
self.iterations = iterations
self.d_model = d_model
self.prediction_reshaper = prediction_reshaper
self.out_dims = out_dims
self.d_input = d_input
self.backbone_type = 'qamnist_backbone'
self.iterations_per_digit = iterations_per_digit
self.iterations_per_question_part = iterations_per_question_part
self.total_iterations_for_answering = iterations_for_answering
# --- Backbone / Feature Extraction ---
self.backbone_digit = MNISTBackbone(d_input)
self.index_backbone = QAMNISTIndexEmbeddings(50, d_input)
self.operator_backbone = QAMNISTOperatorEmbeddings(2, d_input)
        # --- Core Recurrent Modules ---
self.lstm_cell = nn.LSTMCell(d_input, d_model)
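        # Learnable initial hidden/cell states, initialised uniformly in
        # +/- sqrt(1/d_model) and broadcast across the batch in forward().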
self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True))
# Attention
self.q_proj = nn.LazyLinear(d_input)
self.kv_proj = nn.Sequential(nn.LazyLinear(d_input), nn.LayerNorm(d_input))
self.attention = nn.MultiheadAttention(d_input, heads, dropout, batch_first=True)
# Output Projection
self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
def compute_certainty(self, current_prediction):
"""Compute the certainty of the current prediction."""
B = current_prediction.size(0)
        reshaped_pred = current_prediction.reshape([B] + self.prediction_reshaper)
ne = compute_normalized_entropy(reshaped_pred)
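        # Certainty is encoded as (normalized entropy, 1 - normalized entropy), so the
        # second channel grows as the prediction becomes more confident.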
current_certainty = torch.stack((ne, 1-ne), -1)
return current_certainty
def get_kv_for_step(self, stepi, x, z, thought_steps, prev_input=None, prev_kv=None):
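        """
        Build the attention key/value tensor for thought step `stepi`.

        Digit steps project backbone features of the current digit image, question
        steps embed the current index/operator token, and answer steps return a zero
        vector. When the raw input is unchanged from the previous step, the cached
        projection `prev_kv` is reused instead of recomputing the backbone.
        """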
is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
if is_digit_step:
current_input = x[:, stepi]
if prev_input is not None and torch.equal(current_input, prev_input):
return prev_kv, prev_input
kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))
elif is_question_step:
offset = stepi - thought_steps.total_iterations_for_digits
current_input = z[:, offset].squeeze(0)
if prev_input is not None and torch.equal(current_input, prev_input):
return prev_kv, prev_input
is_index_step, is_operator_step = thought_steps.determine_answer_step_type(stepi)
if is_index_step:
kv = self.kv_proj(self.index_backbone(current_input))
elif is_operator_step:
kv = self.kv_proj(self.operator_backbone(current_input))
else:
raise ValueError("Invalid step type for question processing.")
elif is_answer_step:
current_input = None
kv = torch.zeros((x.size(0), self.d_input), device=x.device)
else:
raise ValueError("Invalid step type.")
return kv, current_input
def forward(self, x, z, track=False):
"""
Forward pass - Reverted to structure closer to user's working version.
Executes T=iterations steps.
"""
B = x.size(0) # Batch size
# --- Tracking Initialization ---
activations_tracking = []
attention_tracking = [] # Note: reshaping this correctly requires knowing num_heads
embedding_tracking = []
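        # ThoughtSteps partitions the iteration budget into digit, question and answer
        # phases from the per-phase iteration counts and the digit/question sequence lengths.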
thought_steps = ThoughtSteps(self.iterations_per_digit, self.iterations_per_question_part, self.total_iterations_for_answering, x.size(1), z.size(1))
        # --- Initialise Recurrent State ---
hidden_state = torch.repeat_interleave(self.start_hidden_state.unsqueeze(0), x.size(0), 0)
cell_state = torch.repeat_interleave(self.start_cell_state.unsqueeze(0), x.size(0), 0)
state_trace = [hidden_state]
device = hidden_state.device
# Storage for outputs per iteration
predictions = torch.empty(B, self.out_dims, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
certainties = torch.empty(B, 2, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed
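        # Cache the previous raw input and its key/value projection so that consecutive
        # steps revisiting the same digit image or question token skip the backbone pass.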
prev_input = None
prev_kv = None
# --- Recurrent Loop (T=iterations steps) ---
for stepi in range(thought_steps.total_iterations):
is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
kv, prev_input = self.get_kv_for_step(stepi, x, z, thought_steps, prev_input, prev_kv)
prev_kv = kv
# --- Interact with Data via Attention ---
attn_weights = None
if is_digit_step:
q = self.q_proj(hidden_state).unsqueeze(1)
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
lstm_input = attn_out.squeeze(1)
else:
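                # Question embeddings and the answer-step zero vector are fed to the
                # LSTM directly, without cross-attention.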
lstm_input = kv
            hidden_state, cell_state = self.lstm_cell(lstm_input, (hidden_state, cell_state))
state_trace.append(hidden_state)
# --- Get Predictions and Certainties ---
current_prediction = self.output_projector(hidden_state)
current_certainty = self.compute_certainty(current_prediction)
predictions[..., stepi] = current_prediction
certainties[..., stepi] = current_certainty
# --- Tracking ---
if track:
                activations_tracking.append(hidden_state.detach().cpu().numpy())
if attn_weights is not None:
attention_tracking.append(attn_weights.detach().cpu().numpy())
if is_question_step:
embedding_tracking.append(kv.detach().cpu().numpy())
# --- Return Values ---
if track:
            # The same activation trace fills both activation slots of the tracking tuple
            # (pre- and post-activations are not tracked separately in this baseline).
            return predictions, certainties, None, np.array(activations_tracking), np.array(activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
        return predictions, certainties, None
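

if __name__ == "__main__":
    # Standalone illustration of the certainty measure used in compute_certainty above.
    # This is a sketch only: compute_normalized_entropy lives in models.utils, and the
    # helper below is an assumed equivalent written for demonstration rather than the
    # repository implementation.
    def _normalized_entropy_sketch(logits):
        probs = F.softmax(logits, dim=-1)
        entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(-1)
        return entropy / math.log(logits.size(-1))  # scale entropy into [0, 1]

    logits = torch.tensor([[10.0, 0.0, 0.0],   # confident prediction
                           [1.0, 1.0, 1.0]])   # maximally uncertain prediction
    ne = _normalized_entropy_sketch(logits)
    certainty = torch.stack((ne, 1 - ne), -1)  # second column is high when confident
    print(certainty)  # approximately [[0.00, 1.00], [1.00, 0.00]]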