import torch import torch.nn as nn import torch.nn.functional as F # Used for GLU import math import numpy as np # Assuming 'add_coord_dim' is defined in models.utils from models.utils import add_coord_dim # --- Basic Utility Modules --- class Identity(nn.Module): """ Identity Module. Returns the input tensor unchanged. Useful as a placeholder or a no-op layer in nn.Sequential containers or conditional network parts. """ def __init__(self): super().__init__() def forward(self, x): return x class Squeeze(nn.Module): """ Squeeze Module. Removes a specified dimension of size 1 from the input tensor. Useful for incorporating tensor dimension squeezing within nn.Sequential. Args: dim (int): The dimension to squeeze. """ def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): return x.squeeze(self.dim) # --- Core CTM Component Modules --- class SynapseUNET(nn.Module): """ UNET-style architecture for the Synapse Model (f_theta1 in the paper). This module implements the connections between neurons in the CTM's latent space. It processes the combined input (previous post-activation state z^t and attention output o^t) to produce the pre-activations (a^t) for the next internal tick (Eq. 1 in the paper). While a simpler Linear or MLP layer can be used, the paper notes that this U-Net structure empirically performed better, suggesting benefit from more flexible synaptic connections[cite: 79, 80]. This implementation uses `depth` points in linspace and creates `depth-1` down/up blocks. Args: in_dims (int): Number of input dimensions (d_model + d_input). out_dims (int): Number of output dimensions (d_model). depth (int): Determines structure size; creates `depth-1` down/up blocks. minimum_width (int): Smallest channel width at the U-Net bottleneck. dropout (float): Dropout rate applied within down/up projections. """ def __init__(self, out_dims, depth, minimum_width=16, dropout=0.0): super().__init__() self.width_out = out_dims self.n_deep = depth # Store depth just for reference if needed # Define UNET structure based on depth # Creates `depth` width values, leading to `depth-1` blocks widths = np.linspace(out_dims, minimum_width, depth) # Initial projection layer self.first_projection = nn.Sequential( nn.LazyLinear(int(widths[0])), # Project to the first width nn.LayerNorm(int(widths[0])), nn.SiLU() ) # Downward path (encoding layers) self.down_projections = nn.ModuleList() self.up_projections = nn.ModuleList() self.skip_lns = nn.ModuleList() num_blocks = len(widths) - 1 # Number of down/up blocks created for i in range(num_blocks): # Down block: widths[i] -> widths[i+1] self.down_projections.append(nn.Sequential( nn.Dropout(dropout), nn.Linear(int(widths[i]), int(widths[i+1])), nn.LayerNorm(int(widths[i+1])), nn.SiLU() )) # Up block: widths[i+1] -> widths[i] # Note: Up blocks are added in order matching down blocks conceptually, # but applied in reverse order in the forward pass. self.up_projections.append(nn.Sequential( nn.Dropout(dropout), nn.Linear(int(widths[i+1]), int(widths[i])), nn.LayerNorm(int(widths[i])), nn.SiLU() )) # Skip connection LayerNorm operates on width[i] self.skip_lns.append(nn.LayerNorm(int(widths[i]))) def forward(self, x): # Initial projection out_first = self.first_projection(x) # Downward path, storing outputs for skip connections outs_down = [out_first] for layer in self.down_projections: outs_down.append(layer(outs_down[-1])) # outs_down contains [level_0, level_1, ..., level_depth-1=bottleneck] outputs # Upward path, starting from the bottleneck output outs_up = outs_down[-1] # Bottleneck activation num_blocks = len(self.up_projections) # Should be depth - 1 for i in range(num_blocks): # Apply up projection in reverse order relative to down blocks # up_projection[num_blocks - 1 - i] processes deeper features first up_layer_idx = num_blocks - 1 - i out_up = self.up_projections[up_layer_idx](outs_up) # Get corresponding skip connection from downward path # skip_connection index = num_blocks - 1 - i (same as up_layer_idx) # This matches the output width of the up_projection[up_layer_idx] skip_idx = up_layer_idx skip_connection = outs_down[skip_idx] # Add skip connection and apply LayerNorm corresponding to this level # skip_lns index also corresponds to the level = skip_idx outs_up = self.skip_lns[skip_idx](out_up + skip_connection) # The final output after all up-projections return outs_up class SuperLinear(nn.Module): """ SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM. This layer is the core component enabling Neuron-Level Models (NLMs), referred to as g_theta_d in the paper (Eq. 3). It applies N independent linear transformations (or small MLPs when used sequentially) to corresponding slices of the input tensor along a specified dimension (typically the neuron or feature dimension). How it works for NLMs: - The input `x` is expected to be the pre-activation history for each neuron, shaped (batch_size, n_neurons=N, history_length=in_dims). - This layer holds unique weights (`w1`) and biases (`b1`) for *each* of the `N` neurons. `w1` has shape (in_dims, out_dims, N), `b1` has shape (1, N, out_dims). - `torch.einsum('bni,iog->bno', x, self.w1)` performs N independent matrix multiplications in parallel (mapping from dim `i` to `o` for each neuron `n`): - For each neuron `n` (from 0 to N-1): - It takes the neuron's history `x[:, n, :]` (shape B, in_dims). - Multiplies it by the neuron's unique weight matrix `self.w1[:, :, n]` (shape in_dims, out_dims). - Resulting in `out[:, n, :]` (shape B, out_dims). - The unique bias `self.b1[:, n, :]` is added. - The result is squeezed on the last dim (if out_dims=1) and scaled by `T`. This allows each neuron `d` to process its temporal history `A_d^t` using its private parameters `theta_d` to produce the post-activation `z_d^{t+1}`, enabling the fine-grained temporal dynamics central to the CTM[cite: 7, 30, 85]. It's typically used within the `trace_processor` module of the main CTM class. Args: in_dims (int): Input dimension (typically `memory_length`). out_dims (int): Output dimension per neuron. N (int): Number of independent linear models (typically `d_model`). T (float): Initial value for learnable temperature/scaling factor applied to output. do_norm (bool): Apply Layer Normalization to the input history before linear transform. dropout (float): Dropout rate applied to the input. """ def __init__(self, in_dims, out_dims, N, T=1.0, do_norm=False, dropout=0): super().__init__() # N is the number of neurons (d_model), in_dims is the history length (memory_length) self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity() self.in_dims = in_dims # Corresponds to memory_length # LayerNorm applied across the history dimension for each neuron independently self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity() self.do_norm = do_norm # Initialize weights and biases # w1 shape: (memory_length, out_dims, d_model) self.register_parameter('w1', nn.Parameter( torch.empty((in_dims, out_dims, N)).uniform_( -1/math.sqrt(in_dims + out_dims), 1/math.sqrt(in_dims + out_dims) ), requires_grad=True) ) # b1 shape: (1, d_model, out_dims) self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True)) # Learnable temperature/scaler T self.register_parameter('T', nn.Parameter(torch.Tensor([T]))) def forward(self, x): """ Args: x (torch.Tensor): Input tensor, expected shape (B, N, in_dims) where B=batch, N=d_model, in_dims=memory_length. Returns: torch.Tensor: Output tensor, shape (B, N) after squeeze(-1). """ # Input shape: (B, D, M) where D=d_model=N neurons in CTM, M=history/memory length out = self.dropout(x) # LayerNorm across the memory_length dimension (dim=-1) out = self.layernorm(out) # Shape remains (B, N, M) # Apply N independent linear models using einsum # einsum('BDM,MHD->BDH', ...) # x: (B=batch size, D=N neurons, one NLM per each of these, M=history/memory length) # w1: (M, H=hidden dims if using MLP, otherwise output, D=N neurons, parallel) # b1: (1, D=N neurons, H) # einsum result: (B, D, H) # Applying bias requires matching shapes, b1 is broadcasted. out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1 # Squeeze the output dimension (assumed to be 1 usually) and scale by T # This matches the original code's structure exactly. out = out.squeeze(-1) / self.T return out # --- Backbone Modules --- class ParityBackbone(nn.Module): def __init__(self, n_embeddings, d_embedding): super(ParityBackbone, self).__init__() self.embedding = nn.Embedding(n_embeddings, d_embedding) def forward(self, x): """ Maps -1 (negative parity) to 0 and 1 (positive) to 1 """ x = (x == 1).long() return self.embedding(x.long()).transpose(1, 2) # Transpose for compatibility with other backbones class QAMNISTOperatorEmbeddings(nn.Module): def __init__(self, num_operator_types, d_projection): super(QAMNISTOperatorEmbeddings, self).__init__() self.embedding = nn.Embedding(num_operator_types, d_projection) def forward(self, x): # -1 for plus and -2 for minus return self.embedding(-x - 1) class QAMNISTIndexEmbeddings(torch.nn.Module): def __init__(self, max_seq_length, embedding_dim): super().__init__() self.max_seq_length = max_seq_length self.embedding_dim = embedding_dim embedding = torch.zeros(max_seq_length, embedding_dim) position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1) div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim)) embedding[:, 0::2] = torch.sin(position * div_term) embedding[:, 1::2] = torch.cos(position * div_term) self.register_buffer('embedding', embedding) def forward(self, x): return self.embedding[x] class ThoughtSteps: """ Helper class for managing "thought steps" in the ctm_qamnist pipeline. Args: iterations_per_digit (int): Number of iterations for each digit. iterations_per_question_part (int): Number of iterations for each question part. total_iterations_for_answering (int): Total number of iterations for answering. total_iterations_for_digits (int): Total number of iterations for digits. total_iterations_for_question (int): Total number of iterations for question. """ def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question): self.iterations_per_digit = iterations_per_digit self.iterations_per_question_part = iterations_per_question_part self.total_iterations_for_digits = total_iterations_for_digits self.total_iterations_for_question = total_iterations_for_question self.total_iterations_for_answering = total_iterations_for_answering self.total_iterations = self.total_iterations_for_digits + self.total_iterations_for_question + self.total_iterations_for_answering def determine_step_type(self, stepi: int): is_digit_step = stepi < self.total_iterations_for_digits is_question_step = self.total_iterations_for_digits <= stepi < self.total_iterations_for_digits + self.total_iterations_for_question is_answer_step = stepi >= self.total_iterations_for_digits + self.total_iterations_for_question return is_digit_step, is_question_step, is_answer_step def determine_answer_step_type(self, stepi: int): step_within_questions = stepi - self.total_iterations_for_digits if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part: is_index_step = True is_operator_step = False else: is_index_step = False is_operator_step = True return is_index_step, is_operator_step class MNISTBackbone(nn.Module): """ Simple backbone for MNIST feature extraction. """ def __init__(self, d_input): super(MNISTBackbone, self).__init__() self.layers = nn.Sequential( nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(d_input), nn.ReLU(), nn.MaxPool2d(2, 2), nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(d_input), nn.ReLU(), nn.MaxPool2d(2, 2), ) def forward(self, x): return self.layers(x) class MiniGridBackbone(nn.Module): def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8): super().__init__() self.object_embedding = nn.Embedding(num_objects, embedding_dim) self.color_embedding = nn.Embedding(num_colors, embedding_dim) self.state_embedding = nn.Embedding(num_states, embedding_dim) self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim) self.project_to_d_projection = nn.Sequential( nn.Linear(embedding_dim * 4, d_input * 2), nn.GLU(), nn.LayerNorm(d_input), nn.Linear(d_input, d_input * 2), nn.GLU(), nn.LayerNorm(d_input) ) def forward(self, x): x = x.long() B, H, W, C = x.size() object_idx = x[:,:,:, 0] color_idx = x[:,:,:, 1] state_idx = x[:,:,:, 2] obj_embed = self.object_embedding(object_idx) color_embed = self.color_embedding(color_idx) state_embed = self.state_embedding(state_idx) pos_idx = torch.arange(H * W, device=x.device).view(1, H, W).expand(B, -1, -1) pos_embed = self.position_embedding(pos_idx) out = self.project_to_d_projection(torch.cat([obj_embed, color_embed, state_embed, pos_embed], dim=-1)) return out class ClassicControlBackbone(nn.Module): def __init__(self, d_input): super().__init__() self.input_projector = nn.Sequential( nn.Flatten(), nn.LazyLinear(d_input * 2), nn.GLU(), nn.LayerNorm(d_input), nn.LazyLinear(d_input * 2), nn.GLU(), nn.LayerNorm(d_input) ) def forward(self, x): return self.input_projector(x) class ShallowWide(nn.Module): """ Simple, wide, shallow convolutional backbone for image feature extraction. Alternative to ResNet, uses grouped convolutions and GLU activations. Fixed structure, useful for specific experiments. """ def __init__(self): super(ShallowWide, self).__init__() # LazyConv2d infers input channels self.layers = nn.Sequential( nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1), # Output channels = 4096 nn.GLU(dim=1), # Halves channels to 2048 nn.BatchNorm2d(2048), # Grouped convolution maintains width but processes groups independently nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32), nn.GLU(dim=1), # Halves channels to 2048 nn.BatchNorm2d(2048) ) def forward(self, x): return self.layers(x) class PretrainedResNetWrapper(nn.Module): """ Wrapper to use standard pre-trained ResNet models from torchvision. Loads a specified ResNet architecture pre-trained on ImageNet, removes the final classification layer (fc), average pooling, and optionally later layers (e.g., layer4), allowing it to be used as a feature extractor backbone. Args: resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50'). fine_tune (bool): If False, freezes the weights of the pre-trained backbone. """ def __init__(self, resnet_type, fine_tune=True): super(PretrainedResNetWrapper, self).__init__() self.resnet_type = resnet_type self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True) if not fine_tune: for param in self.backbone.parameters(): param.requires_grad = False # Remove final layers to use as feature extractor self.backbone.avgpool = Identity() self.backbone.fc = Identity() # Keep layer4 by default, user can modify instance if needed # self.backbone.layer4 = Identity() def forward(self, x): # Get features from the modified ResNet out = self.backbone(x) # Reshape output to (B, C, H, W) - This is heuristic based on original comment. # User might need to adjust this based on which layers are kept/removed. # Infer C based on ResNet type (example values) nc = 256 if ('18' in self.resnet_type or '34' in self.resnet_type) else 512 if '50' in self.resnet_type else 1024 if '101' in self.resnet_type else 2048 # Approx for layer3/4 output channel numbers # Infer H, W assuming output is flattened C * H * W num_features = out.shape[-1] # This calculation assumes nc is correct and feature map is square wh_squared = num_features / nc if wh_squared < 0 or not float(wh_squared).is_integer(): print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}") # Return potentially flattened features if reshape fails return out wh = int(np.sqrt(wh_squared)) return out.reshape(x.size(0), nc, wh, wh) # --- Positional Encoding Modules --- class LearnableFourierPositionalEncoding(nn.Module): """ Learnable Fourier Feature Positional Encoding. Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf). Provides positional information for 2D feature maps. Args: d_model (int): The output dimension of the positional encoding (D). G (int): Positional groups (default 1). M (int): Dimensionality of input coordinates (default 2 for H, W). F_dim (int): Dimension of the Fourier features. H_dim (int): Hidden dimension of the MLP. gamma (float): Initialization scale for the Fourier projection weights (Wr). """ def __init__(self, d_model, G=1, M=2, F_dim=256, H_dim=128, gamma=1/2.5, ): super().__init__() self.G = G self.M = M self.F_dim = F_dim self.H_dim = H_dim self.D = d_model self.gamma = gamma self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False) self.mlp = nn.Sequential( nn.Linear(self.F_dim, self.H_dim, bias=True), nn.GLU(), # Halves H_dim nn.Linear(self.H_dim // 2, self.D // self.G), nn.LayerNorm(self.D // self.G) ) self.init_weights() def init_weights(self): nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2) def forward(self, x): """ Computes positional encodings for the input feature map x. Args: x (torch.Tensor): Input feature map, shape (B, C, H, W). Returns: torch.Tensor: Positional encoding tensor, shape (B, D, H, W). """ B, C, H, W = x.shape # Creates coordinates based on (H, W) and repeats for batch B. # Takes x[:,0] assuming channel dim isn't needed for coords. x_coord = add_coord_dim(x[:,0]) # Expects (B, H, W) -> (B, H, W, 2) # Compute Fourier features projected = self.Wr(x_coord) # (B, H, W, F_dim // 2) cosines = torch.cos(projected) sines = torch.sin(projected) F = (1.0 / math.sqrt(self.F_dim)) * torch.cat([cosines, sines], dim=-1) # (B, H, W, F_dim) # Project features through MLP Y = self.mlp(F) # (B, H, W, D // G) # Reshape to (B, D, H, W) PEx = Y.permute(0, 3, 1, 2) # Assuming G=1 return PEx class MultiLearnableFourierPositionalEncoding(nn.Module): """ Combines multiple LearnableFourierPositionalEncoding modules with different initialization scales (gamma) via a learnable weighted sum. Allows the model to learn an optimal combination of positional frequencies. Args: d_model (int): Output dimension of the encoding. G, M, F_dim, H_dim: Parameters passed to underlying LearnableFourierPositionalEncoding. gamma_range (list[float]): Min and max gamma values for the linspace. N (int): Number of parallel embedding modules to create. """ def __init__(self, d_model, G=1, M=2, F_dim=256, H_dim=128, gamma_range=[1.0, 0.1], # Default range N=10, ): super().__init__() self.embedders = nn.ModuleList() for gamma in np.linspace(gamma_range[0], gamma_range[1], N): self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma)) # Renamed parameter from 'combination' to 'combination_weights' for clarity only in comments # Actual registered name remains 'combination' as in original code self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True)) self.N = N def forward(self, x): """ Computes combined positional encoding. Args: x (torch.Tensor): Input feature map, shape (B, C, H, W). Returns: torch.Tensor: Combined positional encoding tensor, shape (B, D, H, W). """ # Compute embeddings from all modules and stack: (N, B, D, H, W) pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0) # Compute combination weights using softmax # Use registered parameter name 'combination' # Reshape weights for broadcasting: (N,) -> (N, 1, 1, 1, 1) weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1) # Compute weighted sum over the N dimension combined_emb = (pos_embs * weights).sum(0) # (B, D, H, W) return combined_emb class CustomRotationalEmbedding(nn.Module): """ Custom Rotational Positional Embedding. Generates 2D positional embeddings based on rotating a fixed start vector. The rotation angle for each grid position is determined primarily by its horizontal position (width dimension). The resulting rotated vectors are concatenated and projected. Note: The current implementation derives angles only from the width dimension (`x.size(-1)`). Args: d_model (int): Dimensionality of the output embeddings. """ def __init__(self, d_model): super(CustomRotationalEmbedding, self).__init__() # Learnable 2D start vector self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True)) # Projects the 4D concatenated rotated vectors to d_model # Input size 4 comes from concatenating two 2D rotated vectors self.projection = nn.Sequential(nn.Linear(4, d_model)) def forward(self, x): """ Computes rotational positional embeddings based on input width. Args: x (torch.Tensor): Input tensor (used for shape and device), shape (batch_size, channels, height, width). Returns: Output tensor containing positional embeddings, shape (1, d_model, height, width) - Batch dim is 1 as PE is same for all. """ B, C, H, W = x.shape device = x.device # --- Generate rotations based only on Width --- # Angles derived from width dimension theta_rad = torch.deg2rad(torch.linspace(0, 180, W, device=device)) # Angle per column cos_theta = torch.cos(theta_rad) sin_theta = torch.sin(theta_rad) # Create rotation matrices: Shape (W, 2, 2) # Use unsqueeze(1) to allow stacking along dim 1 rotation_matrices = torch.stack([ torch.stack([cos_theta, -sin_theta], dim=-1), # Shape (W, 2) torch.stack([sin_theta, cos_theta], dim=-1) # Shape (W, 2) ], dim=1) # Stacks along dim 1 -> Shape (W, 2, 2) # Rotate the start vector by column angle: Shape (W, 2) rotated_vectors = torch.einsum('wij,j->wi', rotation_matrices, self.start_vector) # --- Create Grid Key --- # Original code uses repeats based on rotated_vectors.shape[0] (which is W) for both dimensions. # This creates a (W, W, 4) key tensor. key = torch.cat(( torch.repeat_interleave(rotated_vectors.unsqueeze(1), W, dim=1), # (W, 1, 2) -> (W, W, 2) torch.repeat_interleave(rotated_vectors.unsqueeze(0), W, dim=0) # (1, W, 2) -> (W, W, 2) ), dim=-1) # Shape (W, W, 4) # Project the 4D key vector to d_model: Shape (W, W, d_model) pe_grid = self.projection(key) # Reshape to (1, d_model, W, W) and then select/resize to target H, W? # Original code permutes to (d_model, W, W) and unsqueezes to (1, d_model, W, W) pe = pe_grid.permute(2, 0, 1).unsqueeze(0) # If H != W, this needs adjustment. Assuming H=W or cropping/padding happens later. # Let's return the (1, d_model, W, W) tensor as generated by the original logic. # If H != W, downstream code must handle the mismatch or this PE needs modification. if H != W: # Simple interpolation/cropping could be added, but sticking to original logic: # Option 1: Interpolate # pe = F.interpolate(pe, size=(H, W), mode='bilinear', align_corners=False) # Option 2: Crop/Pad (e.g., crop if W > W_target, pad if W < W_target) # Sticking to original: return shape (1, d_model, W, W) pass return pe class CustomRotationalEmbedding1D(nn.Module): def __init__(self, d_model): super(CustomRotationalEmbedding1D, self).__init__() self.projection = nn.Linear(2, d_model) def forward(self, x): start_vector = torch.tensor([0., 1.], device=x.device, dtype=torch.float) theta_rad = torch.deg2rad(torch.linspace(0, 180, x.size(2), device=x.device)) cos_theta = torch.cos(theta_rad) sin_theta = torch.sin(theta_rad) cos_theta = cos_theta.unsqueeze(1) # Shape: (height, 1) sin_theta = sin_theta.unsqueeze(1) # Shape: (height, 1) # Create rotation matrices rotation_matrices = torch.stack([ torch.cat([cos_theta, -sin_theta], dim=1), torch.cat([sin_theta, cos_theta], dim=1) ], dim=1) # Shape: (height, 2, 2) # Rotate the start vector rotated_vectors = torch.einsum('bij,j->bi', rotation_matrices, start_vector) pe = self.projection(rotated_vectors) pe = torch.repeat_interleave(pe.unsqueeze(0), x.size(0), 0) return pe.transpose(1, 2) # Transpose for compatibility with other backbones