import math

import numpy as np
import torch
import torch.nn as nn

# Local imports (assumed to provide MNISTBackbone, QAMNISTIndexEmbeddings,
# QAMNISTOperatorEmbeddings, ThoughtSteps, and compute_normalized_entropy).
from models.modules import *
from models.utils import *

class LSTMBaseline(nn.Module):
    """
    LSTM Baseline

    Args:
        iterations (int): Number of internal 'thought' steps (T, in paper).
        d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
        d_input (int): Dimensionality of projected attention outputs or direct input features.
        heads (int): Number of attention heads.
        n_synch_out (int): Number of neurons used for output synchronisation (No, in paper).
        n_synch_action (int): Number of neurons used for action/attention synchronisation (Ni, in paper).
        synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP).
        memory_length (int): History length for Neuron-Level Models (M, in paper).
        deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear.
        memory_hidden_dims (int): Hidden dimension size for deep NLMs.
        do_layernorm_nlm (bool): Apply LayerNorm within NLMs.
        backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
        positional_embedding_type (str): Type of positional embedding for backbone features.
        out_dims (int): Dimensionality of the final output projection.
        prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
        dropout (float): Dropout rate.
    """

    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 heads,
                 out_dims,
                 iterations_per_digit,
                 iterations_per_question_part,
                 iterations_for_answering,
                 prediction_reshaper=[-1],
                 dropout=0,
                 ):
        super().__init__()

        # --- Core Parameters ---
        self.iterations = iterations
        self.d_model = d_model
        self.prediction_reshaper = prediction_reshaper
        self.out_dims = out_dims
        self.d_input = d_input
        self.backbone_type = 'qamnist_backbone'
        self.iterations_per_digit = iterations_per_digit
        self.iterations_per_question_part = iterations_per_question_part
        self.total_iterations_for_answering = iterations_for_answering

        # --- Backbone / Feature Extraction ---
        self.backbone_digit = MNISTBackbone(d_input)
        self.index_backbone = QAMNISTIndexEmbeddings(50, d_input)
        self.operator_backbone = QAMNISTOperatorEmbeddings(2, d_input)

        # --- Core Recurrent (LSTM) Modules ---
        self.lstm_cell = nn.LSTMCell(d_input, d_model)
        bound = math.sqrt(1 / d_model)
        self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros(d_model).uniform_(-bound, bound), requires_grad=True))
        self.register_parameter('start_cell_state', nn.Parameter(torch.zeros(d_model).uniform_(-bound, bound), requires_grad=True))
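        # Note: the start states above are learned parameters, initialised from
        # the same U(-1/sqrt(d_model), 1/sqrt(d_model)) range that nn.LSTMCell
        # uses for its own weights and biases.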
        
        # Attention
        self.q_proj = nn.LazyLinear(d_input)
        self.kv_proj = nn.Sequential(nn.LazyLinear(d_input), nn.LayerNorm(d_input))
        self.attention = nn.MultiheadAttention(d_input, heads, dropout, batch_first=True)
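        # Queries are projected from the LSTM hidden state; keys/values are
        # projected backbone features (built per step in get_kv_for_step).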
   
        # Output Projection
        self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))

    def compute_certainty(self, current_prediction):
        """Compute per-example certainty as (normalized entropy, 1 - normalized entropy)."""
        B = current_prediction.size(0)
        reshaped_pred = current_prediction.reshape([B] + self.prediction_reshaper)
        ne = compute_normalized_entropy(reshaped_pred)
        current_certainty = torch.stack((ne, 1 - ne), -1)
        return current_certainty

    def get_kv_for_step(self, stepi, x, z, thought_steps, prev_input=None, prev_kv=None):
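        """
        Build the key/value tensor for thought step `stepi`.

        Digit steps encode the current image frame with the digit backbone;
        question steps embed the current index or operator token; answer steps
        use a zero tensor. When the raw input is unchanged from the previous
        step, the cached `prev_kv` is reused instead of re-running a backbone.
        """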
        is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)

        if is_digit_step:
            current_input = x[:, stepi]
            if prev_input is not None and torch.equal(current_input, prev_input):
                return prev_kv, prev_input
            kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1))

        elif is_question_step:
            offset = stepi - thought_steps.total_iterations_for_digits
            current_input = z[:, offset].squeeze(0)
            if prev_input is not None and torch.equal(current_input, prev_input):
                return prev_kv, prev_input
            is_index_step, is_operator_step = thought_steps.determine_answer_step_type(stepi)
            if is_index_step:
                kv = self.kv_proj(self.index_backbone(current_input))
            elif is_operator_step:
                kv = self.kv_proj(self.operator_backbone(current_input))
            else:
                raise ValueError("Invalid step type for question processing.")

        elif is_answer_step:
            current_input = None
            kv = torch.zeros((x.size(0), self.d_input), device=x.device)

        else:
            raise ValueError("Invalid step type.")

        return kv, current_input

    def forward(self, x, z, track=False):
        """
        Forward pass - Reverted to structure closer to user's working version.
        Executes T=iterations steps.
        """
        B = x.size(0) # Batch size

        # --- Tracking Initialization ---
        activations_tracking = []
        attention_tracking = [] # Note: reshaping this correctly requires knowing num_heads
        embedding_tracking = []

        thought_steps = ThoughtSteps(self.iterations_per_digit, self.iterations_per_question_part, self.total_iterations_for_answering, x.size(1), z.size(1))
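        # ThoughtSteps schedules the loop into three phases: viewing each
        # digit, reading each question part, and finally answering.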

        # --- Initialise Recurrent State ---
        hidden_state = torch.repeat_interleave(self.start_hidden_state.unsqueeze(0), B, 0)
        cell_state = torch.repeat_interleave(self.start_cell_state.unsqueeze(0), B, 0)

        state_trace = [hidden_state]

        device = hidden_state.device

        # Per-iteration outputs: predictions are (B, out_dims, T) and
        # certainties are (B, 2, T), where T = thought_steps.total_iterations.
        predictions = torch.empty(B, self.out_dims, thought_steps.total_iterations, device=device, dtype=x.dtype)
        certainties = torch.empty(B, 2, thought_steps.total_iterations, device=device, dtype=x.dtype)

        prev_input = None
        prev_kv = None

        # --- Recurrent Loop (T=iterations steps) ---
        for stepi in range(thought_steps.total_iterations):

            is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi)
            kv, prev_input = self.get_kv_for_step(stepi, x, z, thought_steps, prev_input, prev_kv)
            prev_kv = kv

            # --- Interact with Data via Attention ---
            attn_weights = None
            if is_digit_step:
                q = self.q_proj(hidden_state).unsqueeze(1)
                attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
                lstm_input = attn_out.squeeze(1)
            else:
                lstm_input = kv

            hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state))
            state_trace.append(hidden_state)

            # --- Get Predictions and Certainties ---
            current_prediction = self.output_projector(hidden_state)
            current_certainty = self.compute_certainty(current_prediction)

            predictions[..., stepi] = current_prediction
            certainties[..., stepi] = current_certainty

            # --- Tracking ---
            if track:
                activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy())
                if attn_weights is not None:
                    attention_tracking.append(attn_weights.detach().cpu().numpy())
                if is_question_step:
                    embedding_tracking.append(kv.detach().cpu().numpy())

        # --- Return Values ---
        if track:
            # Activations are returned twice so the tuple matches the CTM
            # model's tracking interface (pre- and post-activations); the LSTM
            # baseline has only a single hidden-state trace.
            return predictions, certainties, None, np.array(activations_tracking), np.array(activations_tracking), np.array(attention_tracking), np.array(embedding_tracking)
        return predictions, certainties, None
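

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch. This is not part of the training pipeline; it
# only illustrates the call signature, assuming the repo's models.modules and
# models.utils packages are importable. All hyperparameter values, input
# shapes, and token ranges below are illustrative assumptions inferred from
# the forward pass above, not the real QAMNIST data format.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    B, num_digits, num_question_parts = 4, 3, 2  # hypothetical sizes

    model = LSTMBaseline(
        iterations=10,                 # kept for interface parity with the CTM
        d_model=256,
        d_input=64,
        heads=4,
        out_dims=10,
        iterations_per_digit=5,
        iterations_per_question_part=2,
        iterations_for_answering=5,
    )

    # Assumed shapes: x holds one MNIST-style frame per digit, z holds integer
    # question tokens (values kept in {0, 1} so they are valid indices for both
    # the index and operator embedding tables).
    x = torch.randn(B, num_digits, 1, 28, 28)
    z = torch.randint(0, 2, (B, num_question_parts))

    predictions, certainties, _ = model(x, z)
    print(predictions.shape)   # expected: (B, out_dims, total_iterations)
    print(certainties.shape)   # expected: (B, 2, total_iterations)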