import torch.nn as nn
import torch
import numpy as np
import math
from models.modules import (
    ParityBackbone, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding,
    CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide
)
from models.resnet import prepare_resnet_backbone
from models.utils import compute_normalized_entropy
from models.constants import (
VALID_BACKBONE_TYPES,
VALID_POSITIONAL_EMBEDDING_TYPES
)
class LSTMBaseline(nn.Module):
"""
LSTM Baseline
Args:
iterations (int): Number of internal 'thought' steps (T, in paper).
d_model (int): Core dimensionality of the latent space.
d_input (int): Dimensionality of projected attention outputs or direct input features.
heads (int): Number of attention heads.
        backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
        num_layers (int): Number of stacked LSTM layers.
        positional_embedding_type (str): Type of positional embedding for backbone features.
out_dims (int): Dimensionality of the final output projection.
prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific).
dropout (float): Dropout rate.
"""
def __init__(self,
iterations,
d_model,
d_input,
heads,
backbone_type,
num_layers,
positional_embedding_type,
out_dims,
prediction_reshaper=[-1],
dropout=0,
):
        super().__init__()
# --- Core Parameters ---
self.iterations = iterations
self.d_model = d_model
self.d_input = d_input
self.prediction_reshaper = prediction_reshaper
self.backbone_type = backbone_type
self.positional_embedding_type = positional_embedding_type
self.out_dims = out_dims
# --- Assertions ---
self.verify_args()
# --- Input Processing ---
d_backbone = self.get_d_backbone()
self.set_initial_rgb()
self.set_backbone()
self.positional_embedding = self.get_positional_embedding(d_backbone)
self.kv_proj = self.get_kv_proj()
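        # Note: nn.LSTM applies dropout only between stacked layers, so with num_layers=1 the
        # dropout argument has no effect (PyTorch emits a warning in that case).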
self.lstm = nn.LSTM(d_input, d_model, num_layers, batch_first=True, dropout=dropout)
self.q_proj = self.get_q_proj()
self.attention = self.get_attention(heads, dropout)
self.output_projector = nn.Sequential(nn.LazyLinear(out_dims))
        # --- Start States ---
        # Learned initial (h, c) states of shape (num_layers, d_model); expanded across the batch in forward().
        bound = math.sqrt(1 / d_model)
        self.register_parameter('start_hidden_state', nn.Parameter(
            torch.zeros((num_layers, d_model)).uniform_(-bound, bound), requires_grad=True))
        self.register_parameter('start_cell_state', nn.Parameter(
            torch.zeros((num_layers, d_model)).uniform_(-bound, bound), requires_grad=True))
# --- Core LSTM Methods ---
def compute_features(self, x):
"""Applies backbone and positional embedding to input."""
x = self.initial_rgb(x)
self.kv_features = self.backbone(x)
pos_emb = self.positional_embedding(self.kv_features)
combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2)
kv = self.kv_proj(combined_features)
return kv
def compute_certainty(self, current_prediction):
"""Compute the certainty of the current prediction."""
B = current_prediction.size(0)
        reshaped_pred = current_prediction.reshape([B] + self.prediction_reshaper)
        ne = compute_normalized_entropy(reshaped_pred)
        current_certainty = torch.stack((ne, 1 - ne), -1)
return current_certainty
# --- Setup Methods ---
def set_initial_rgb(self):
"""Set the initial RGB processing module based on the backbone type."""
if 'resnet' in self.backbone_type:
self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily
else:
self.initial_rgb = nn.Identity()
def get_d_backbone(self):
"""
Get the dimensionality of the backbone output, to be used for positional embedding setup.
This is a little bit complicated for resnets, but the logic should be easy enough to read below.
"""
if self.backbone_type == 'shallow-wide':
return 2048
elif self.backbone_type == 'parity_backbone':
return self.d_input
        elif 'resnet' in self.backbone_type:
            stage = self.backbone_type.split('-')[1]
            if '18' in self.backbone_type or '34' in self.backbone_type:
                # BasicBlock ResNets (18/34): per-stage channel widths
                stage_widths = {'1': 64, '2': 128, '3': 256, '4': 512}
            else:
                # Bottleneck ResNets (50/101/152): 4x wider per stage
                stage_widths = {'1': 256, '2': 512, '3': 1024, '4': 2048}
            if stage not in stage_widths:
                raise NotImplementedError
            return stage_widths[stage]
elif self.backbone_type == 'none':
return None
else:
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
def set_backbone(self):
"""Set the backbone module based on the specified type."""
if self.backbone_type == 'shallow-wide':
self.backbone = ShallowWide()
elif self.backbone_type == 'parity_backbone':
d_backbone = self.get_d_backbone()
self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone)
elif 'resnet' in self.backbone_type:
self.backbone = prepare_resnet_backbone(self.backbone_type)
elif self.backbone_type == 'none':
self.backbone = nn.Identity()
else:
raise ValueError(f"Invalid backbone_type: {self.backbone_type}")
def get_positional_embedding(self, d_backbone):
"""Get the positional embedding module."""
if self.positional_embedding_type == 'learnable-fourier':
return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5)
elif self.positional_embedding_type == 'multi-learnable-fourier':
return MultiLearnableFourierPositionalEncoding(d_backbone)
elif self.positional_embedding_type == 'custom-rotational':
return CustomRotationalEmbedding(d_backbone)
elif self.positional_embedding_type == 'custom-rotational-1d':
return CustomRotationalEmbedding1D(d_backbone)
elif self.positional_embedding_type == 'none':
return lambda x: 0 # Default no-op
else:
raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}")
def get_attention(self, heads, dropout):
"""Get the attention module."""
return nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True)
def get_kv_proj(self):
"""Get the key-value projection module."""
return nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input))
def get_q_proj(self):
"""Get the query projection module."""
return nn.LazyLinear(self.d_input)
def verify_args(self):
"""Verify the validity of the input arguments."""
assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \
f"Invalid backbone_type: {self.backbone_type}"
assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \
f"Invalid positional_embedding_type: {self.positional_embedding_type}"
        if self.backbone_type == 'none' and self.positional_embedding_type != 'none':
            raise AssertionError("There should be no positional embedding if there is no backbone.")
def forward(self, x, track=False):
"""
Forward pass - Reverted to structure closer to user's working version.
Executes T=iterations steps.
"""
B = x.size(0)
device = x.device
# --- Tracking Initialization ---
activations_tracking = []
attention_tracking = []
# --- Featurise Input Data ---
kv = self.compute_features(x)
        # --- Initialise Recurrent State ---
        # Expand learned start states across the batch: (num_layers, d_model) -> (num_layers, B, d_model)
        hn = torch.repeat_interleave(self.start_hidden_state.unsqueeze(1), B, 1)
        cn = torch.repeat_interleave(self.start_cell_state.unsqueeze(1), B, 1)
        state_trace = [hn[-1]]
# --- Prepare Storage for Outputs per Iteration ---
predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype)
certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype)
# --- Recurrent Loop ---
for stepi in range(self.iterations):
# --- Interact with Data via Attention ---
            q = self.q_proj(hn[-1].unsqueeze(1))  # Query from the top LSTM layer's hidden state: (B, 1, d_input)
attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True)
lstm_input = attn_out
# --- Apply LSTM ---
            hidden_state, (hn, cn) = self.lstm(lstm_input, (hn, cn))
            hidden_state = hidden_state.squeeze(1)  # (B, 1, d_model) -> (B, d_model)
state_trace.append(hidden_state)
# --- Get Predictions and Certainties ---
current_prediction = self.output_projector(hidden_state)
current_certainty = self.compute_certainty(current_prediction)
predictions[..., stepi] = current_prediction
certainties[..., stepi] = current_certainty
# --- Tracking ---
if track:
                activations_tracking.append(hidden_state.detach().cpu().numpy())
attention_tracking.append(attn_weights.detach().cpu().numpy())
# --- Return Values ---
        if track:
            # The zeros array is a placeholder so the tracked return signature stays aligned with the CTM's.
            return predictions, certainties, None, np.zeros_like(activations_tracking), np.array(activations_tracking), np.array(attention_tracking)
return predictions, certainties, None
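
if __name__ == '__main__':
    # Minimal smoke-test sketch, assuming 'resnet18-2' and 'learnable-fourier' are registered in
    # models.constants and that 32x32 RGB images are a valid backbone input; adjust to your setup.
    model = LSTMBaseline(
        iterations=10,
        d_model=256,
        d_input=128,
        heads=4,
        backbone_type='resnet18-2',
        num_layers=2,
        positional_embedding_type='learnable-fourier',
        out_dims=10,
        prediction_reshaper=[10],  # hypothetical: plain 10-way classification
    )
    x = torch.randn(8, 3, 32, 32)  # (B, C, H, W)
    predictions, certainties, _ = model(x)
    print(predictions.shape)  # expected: (8, 10, 10) = (B, out_dims, iterations)
    print(certainties.shape)  # expected: (8, 2, 10)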