LukeDarlow's picture
Welcome to the CTM. This is the first commit of the public repo. Enjoy!
68b32f4
raw
history blame
28.9 kB
import torch
import torch.nn as nn
import torch.nn.functional as F # Used for GLU
import math
import numpy as np
# Assuming 'add_coord_dim' is defined in models.utils
from models.utils import add_coord_dim
# --- Basic Utility Modules ---
class Identity(nn.Module):
    """
    No-op module.

    `forward` hands its argument straight back (the very same object), which
    makes this handy as a placeholder inside nn.Sequential or wherever a layer
    is conditionally disabled.
    """

    def forward(self, x):
        return x
class Squeeze(nn.Module):
    """
    nn.Module wrapper around `torch.squeeze(x, dim)`.

    Drops dimension `dim` when it has size 1 (otherwise the tensor's shape is
    left as is), so squeezing can be expressed as a layer in nn.Sequential.

    Args:
        dim (int): The dimension to squeeze.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # axis handed to torch.squeeze on every call

    def forward(self, x):
        return torch.squeeze(x, self.dim)
# --- Core CTM Component Modules ---
class SynapseUNET(nn.Module):
    """
    UNET-style architecture for the Synapse Model (f_theta1 in the paper).

    This module implements the connections between neurons in the CTM's latent
    space. It processes the combined input (previous post-activation state z^t
    and attention output o^t) to produce the pre-activations (a^t) for the
    next internal tick (Eq. 1 in the paper).

    While a simpler Linear or MLP layer can be used, the paper notes that this
    U-Net structure empirically performed better. Widths are
    `int(np.linspace(out_dims, minimum_width, depth))`: level 0 has width
    `out_dims` and level `depth-1` (the bottleneck) has width `minimum_width`,
    giving `depth-1` down/up blocks.

    The input width is NOT a constructor argument: the first projection is an
    `nn.LazyLinear`, so it is inferred on the first forward call.

    Args:
        out_dims (int): Number of output dimensions (d_model); also the width
            of the outermost level.
        depth (int): Number of levels (must be >= 1); creates `depth-1`
            down/up blocks. With depth == 1 only the first projection is used.
        minimum_width (int): Channel width at the U-Net bottleneck.
        dropout (float): Dropout rate applied before every down/up Linear.

    Raises:
        ValueError: If `depth` < 1.
    """

    def __init__(self,
                 out_dims,
                 depth,
                 minimum_width=16,
                 dropout=0.0):
        super().__init__()
        if depth < 1:
            # Previously this surfaced as an IndexError on widths[0].
            raise ValueError(f"depth must be >= 1, got {depth}")
        self.width_out = out_dims
        self.n_deep = depth  # stored for reference only
        # `depth` widths, truncated to ints, from out_dims down to minimum_width.
        widths = [int(w) for w in np.linspace(out_dims, minimum_width, depth)]

        # Initial projection: (..., in_dims) -> (..., widths[0]); in_dims is lazy.
        self.first_projection = nn.Sequential(
            nn.LazyLinear(widths[0]),
            nn.LayerNorm(widths[0]),
            nn.SiLU()
        )

        # Block i connects level i (width `wide`) and level i+1 (width `narrow`).
        # Modules are created in the same order as before so that RNG-driven
        # initialisation and state_dict keys are unchanged.
        self.down_projections = nn.ModuleList()
        self.up_projections = nn.ModuleList()
        self.skip_lns = nn.ModuleList()
        for wide, narrow in zip(widths[:-1], widths[1:]):
            self.down_projections.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(wide, narrow),
                nn.LayerNorm(narrow),
                nn.SiLU()
            ))
            self.up_projections.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(narrow, wide),
                nn.LayerNorm(wide),
                nn.SiLU()
            ))
            # LayerNorm applied after adding the level-i skip connection.
            self.skip_lns.append(nn.LayerNorm(wide))

    def forward(self, x):
        # Encoder: keep every level's activation for the skip connections.
        # outs_down[i] has width widths[i]; outs_down[-1] is the bottleneck.
        outs_down = [self.first_projection(x)]
        for down in self.down_projections:
            outs_down.append(down(outs_down[-1]))

        # Decoder: walk the levels deepest-first. up_projections[i] maps level
        # i+1 back to the width of level i, so it pairs with outs_down[i] and
        # skip_lns[i].
        out = outs_down[-1]
        for level in reversed(range(len(self.up_projections))):
            out = self.skip_lns[level](self.up_projections[level](out) + outs_down[level])
        return out
class SuperLinear(nn.Module):
    """
    SuperLinear Layer: Neuron-Level Models (NLMs) for the CTM.

    Applies N independent linear maps in parallel, one per slice along dim 1 of
    the input (g_theta_d in the paper, Eq. 3). Slice `n` has its own weight
    matrix `w1[:, :, n]` of shape (in_dims, out_dims) and bias `b1[0, n, :]`,
    so each neuron processes its own pre-activation history with private
    parameters. Stacking several of these gives per-neuron MLPs.

    Forward pass, for input x of shape (B, N, in_dims):
      1. dropout (only if `dropout > 0`), then LayerNorm over the last
         dimension (only if `do_norm`);
      2. out[b, n, :] = x[b, n, :] @ w1[:, :, n] + b1[0, n, :]
         -> shape (B, N, out_dims), via einsum('BDM,MHD->BDH');
      3. `squeeze(-1)` (removes the last axis only when out_dims == 1), then
         DIVIDE by the learnable temperature `T`.

    Parameters: `w1` (in_dims, out_dims, N) ~ U(-1/sqrt(in_dims+out_dims),
    +1/sqrt(in_dims+out_dims)); `b1` (1, N, out_dims) zeros; `T` (1,)
    initialised to the `T` argument.

    Args:
        in_dims (int): Input dimension per neuron (typically `memory_length`).
        out_dims (int): Output dimension per neuron.
        N (int): Number of independent linear models (typically `d_model`).
        T (float): Initial value of the learnable divisor applied to the output.
        do_norm (bool): Apply LayerNorm over the last input dim before the map.
        dropout (float): Dropout rate applied to the input.
    """

    def __init__(self,
                 in_dims,
                 out_dims,
                 N,
                 T=1.0,
                 do_norm=False,
                 dropout=0):
        super().__init__()
        self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
        self.in_dims = in_dims  # history length (memory_length) when used as an NLM
        # LayerNorm over the last (history) dimension, per neuron.
        self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity()
        self.do_norm = do_norm

        # w1: (in_dims, out_dims, N); uniform init scaled by fan-in + fan-out.
        bound = 1 / math.sqrt(in_dims + out_dims)
        self.register_parameter('w1', nn.Parameter(
            torch.empty((in_dims, out_dims, N)).uniform_(-bound, bound),
            requires_grad=True,
        ))
        # b1: (1, N, out_dims), broadcast over the batch dimension.
        self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
        # Learnable temperature; the output is divided by it.
        self.register_parameter('T', nn.Parameter(torch.Tensor([T])))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input of shape (B, N, in_dims).

        Returns:
            torch.Tensor: (B, N) when out_dims == 1, otherwise (B, N, out_dims),
            divided by `T`.
        """
        out = self.dropout(x)
        out = self.layernorm(out)  # shape unchanged: (B, N, in_dims)
        # B=batch, D=N neurons, M=in_dims, H=out_dims; one matmul per neuron D.
        out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1
        # squeeze(-1) only removes the axis when out_dims == 1.
        return out.squeeze(-1) / self.T
# --- Backbone Modules ---
class ParityBackbone(nn.Module):
    """
    Embedding backbone for the parity task.

    Every input equal to 1 is looked up as embedding row 1 and every other
    value (e.g. -1) as row 0, so `n_embeddings` should be at least 2.

    Args:
        n_embeddings (int): Rows in the embedding table.
        d_embedding (int): Embedding width.
    """

    def __init__(self, n_embeddings, d_embedding):
        super().__init__()
        self.embedding = nn.Embedding(n_embeddings, d_embedding)

    def forward(self, x):
        """
        Maps -1 (negative parity) to 0 and 1 (positive) to 1
        """
        bits = x.eq(1).long()
        embedded = self.embedding(bits)  # (B, L) -> (B, L, d_embedding)
        # Channels-first, for compatibility with the other backbones.
        return embedded.transpose(1, 2)
class QAMNISTOperatorEmbeddings(nn.Module):
    """
    Learned embeddings for QAMNIST operator tokens.

    Operator codes are negative integers: -1 is looked up as row 0 and -2 as
    row 1 (in general code c maps to row -c - 1).

    Args:
        num_operator_types (int): Number of operator rows in the table.
        d_projection (int): Embedding width.
    """

    def __init__(self, num_operator_types, d_projection):
        super().__init__()
        self.embedding = nn.Embedding(num_operator_types, d_projection)

    def forward(self, x):
        # -1 for plus and -2 for minus -> rows 0 and 1.
        row_index = -(x + 1)
        return self.embedding(row_index)
class QAMNISTIndexEmbeddings(torch.nn.Module):
    """
    Fixed (non-learned) sinusoidal embeddings for QAMNIST index tokens.

    Row p of the table holds sin(p * f_k) in even columns and cos(p * f_k) in
    odd columns, with f_k = 10000^(-2k / embedding_dim). The table is a buffer,
    so it follows `.to(device)` and is saved in the state_dict but not trained.

    Args:
        max_seq_length (int): Number of rows (largest index + 1).
        embedding_dim (int): Embedding width; may be odd.
    """

    def __init__(self, max_seq_length, embedding_dim):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.embedding_dim = embedding_dim
        embedding = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        embedding[:, 0::2] = torch.sin(position * div_term)
        # For odd embedding_dim there is one fewer odd column than entries in
        # div_term; slicing avoids the shape-mismatch crash and is a no-op for
        # even widths.
        embedding[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        self.register_buffer('embedding', embedding)

    def forward(self, x):
        # Plain row lookup; x holds integer indices < max_seq_length.
        return self.embedding[x]
class ThoughtSteps:
    """
    Helper class for managing "thought steps" in the ctm_qamnist pipeline.

    Internal ticks are split into three consecutive phases: digits
    (`[0, D)`), question (`[D, D + Q)`) and answering (`>= D + Q`), where
    D = total_iterations_for_digits and Q = total_iterations_for_question.

    Args:
        iterations_per_digit (int): Number of iterations for each digit.
        iterations_per_question_part (int): Number of iterations for each question part.
        total_iterations_for_answering (int): Total number of iterations for answering.
        total_iterations_for_digits (int): Total number of iterations for digits.
        total_iterations_for_question (int): Total number of iterations for question.
    """

    def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
        self.iterations_per_digit = iterations_per_digit
        self.iterations_per_question_part = iterations_per_question_part
        self.total_iterations_for_answering = total_iterations_for_answering
        self.total_iterations_for_digits = total_iterations_for_digits
        self.total_iterations_for_question = total_iterations_for_question
        self.total_iterations = (
            total_iterations_for_digits
            + total_iterations_for_question
            + total_iterations_for_answering
        )

    def determine_step_type(self, stepi: int):
        """Return (is_digit_step, is_question_step, is_answer_step) for tick `stepi`."""
        digits_end = self.total_iterations_for_digits
        question_end = digits_end + self.total_iterations_for_question
        return stepi < digits_end, digits_end <= stepi < question_end, stepi >= question_end

    def determine_answer_step_type(self, stepi: int):
        """
        Return (is_index_step, is_operator_step) for tick `stepi`.

        Counting from the end of the digit phase, ticks alternate in blocks of
        `iterations_per_question_part`: an index block first, then an operator
        block. NOTE(review): despite the name this indexes into the question
        phase; presumably callers only use it during question steps.
        """
        part = self.iterations_per_question_part
        offset = stepi - self.total_iterations_for_digits
        in_index_block = offset % (2 * part) < part
        return in_index_block, not in_index_block
class MNISTBackbone(nn.Module):
    """
    Simple backbone for MNIST feature extraction.

    Two identical stages of [3x3 conv (same padding) -> BatchNorm -> ReLU ->
    2x2 max-pool], each with `d_input` output channels. Each stage halves
    H and W (floor). Input channels are inferred lazily.

    Args:
        d_input (int): Channel width of both conv stages.
    """

    def __init__(self, d_input):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [
                nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(d_input),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
        # Flat Sequential keeps the layers.0 ... layers.7 parameter names.
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        return self.layers(x)
class MiniGridBackbone(nn.Module):
    """
    Backbone for MiniGrid-style symbolic observations.

    Expects a 4-D integer grid (B, H, W, C) whose last axis holds the object,
    color and state codes in channels 0, 1 and 2. Each code is embedded, a
    learned embedding of the flattened cell position (0 .. H*W-1) is added as
    a fourth feature, and the concatenation (4 * embedding_dim) is projected
    to `d_input` through two GLU + LayerNorm stages. Output is (B, H, W, d_input).

    NOTE(review): position ids run up to H*W-1, so H*W must not exceed
    grid_size**2 or the position lookup fails.

    Args:
        d_input (int): Output feature width.
        grid_size (int): Side length used to size the position table.
        num_objects, num_colors, num_states (int): Vocabulary sizes.
        embedding_dim (int): Width of each individual embedding.
    """

    def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
        super().__init__()
        # Creation order matters for RNG-driven init; keep it fixed.
        self.object_embedding = nn.Embedding(num_objects, embedding_dim)
        self.color_embedding = nn.Embedding(num_colors, embedding_dim)
        self.state_embedding = nn.Embedding(num_states, embedding_dim)
        self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)
        self.project_to_d_projection = nn.Sequential(
            nn.Linear(embedding_dim * 4, d_input * 2),
            nn.GLU(),
            nn.LayerNorm(d_input),
            nn.Linear(d_input, d_input * 2),
            nn.GLU(),
            nn.LayerNorm(d_input)
        )

    def forward(self, x):
        grid = x.long()
        batch, height, width, _ = grid.size()  # also enforces a 4-D input
        positions = torch.arange(height * width, device=grid.device).view(1, height, width).expand(batch, -1, -1)
        features = [
            self.object_embedding(grid[..., 0]),
            self.color_embedding(grid[..., 1]),
            self.state_embedding(grid[..., 2]),
            self.position_embedding(positions),
        ]
        return self.project_to_d_projection(torch.cat(features, dim=-1))
class ClassicControlBackbone(nn.Module):
    """
    MLP backbone for low-dimensional (classic-control) observations.

    Flattens everything after the batch dimension, then applies two
    [LazyLinear(2*d) -> GLU -> LayerNorm(d)] stages. Output is (B, d_input).

    Args:
        d_input (int): Output feature width.
    """

    def __init__(self, d_input):
        super().__init__()

        def glu_stage(width):
            # GLU halves the last dim, so project to 2*width first.
            return [nn.LazyLinear(width * 2), nn.GLU(), nn.LayerNorm(width)]

        self.input_projector = nn.Sequential(
            nn.Flatten(),
            *glu_stage(d_input),
            *glu_stage(d_input),
        )

    def forward(self, x):
        return self.input_projector(x)
class ShallowWide(nn.Module):
    """
    Simple, wide, shallow convolutional backbone for image feature extraction.

    Alternative to ResNet, uses grouped convolutions and GLU activations:
      1. LazyConv2d -> 4096 channels, 3x3, stride 2 (halves H and W, ceil);
      2. GLU over channels -> 2048, BatchNorm;
      3. 3x3 grouped conv (32 groups) 2048 -> 4096;
      4. GLU over channels -> 2048, BatchNorm.
    Output: (B, 2048, H', W'). Fixed structure, useful for specific experiments.
    """

    def __init__(self):
        super().__init__()
        stem = [
            nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1),  # input channels inferred
            nn.GLU(dim=1),
            nn.BatchNorm2d(2048),
        ]
        grouped = [
            nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32),
            nn.GLU(dim=1),
            nn.BatchNorm2d(2048),
        ]
        self.layers = nn.Sequential(*stem, *grouped)

    def forward(self, x):
        return self.layers(x)
class PretrainedResNetWrapper(nn.Module):
    """
    Wrapper to use standard pre-trained ResNet models from torchvision.

    Loads `resnet_type` with
    `torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)`,
    so construction needs the hub cache or network access. It replaces
    `avgpool` and `fc` with `Identity` and keeps every other stage (layer4
    included; set `self.backbone.layer4 = Identity()` on the instance to drop it).

    `forward` then tries to reshape the backbone output to (B, C, H, W). It
    assumes the output is a flattened C*H*W vector with a square spatial map
    and a channel count guessed from the model name. When
    `num_features == C * k * k` has no integer solution k, a warning is printed
    and the output is returned unreshaped.

    Args:
        resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50').
        fine_tune (bool): If False, freezes the weights of the pre-trained backbone.
    """

    def __init__(self, resnet_type, fine_tune=True):
        super(PretrainedResNetWrapper, self).__init__()
        self.resnet_type = resnet_type
        self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)
        if not fine_tune:
            for param in self.backbone.parameters():
                param.requires_grad = False
        # Remove final layers to use as feature extractor
        self.backbone.avgpool = Identity()
        self.backbone.fc = Identity()
        # Keep layer4 by default, user can modify instance if needed
        # self.backbone.layer4 = Identity()

    def forward(self, x):
        out = self.backbone(x)

        # NOTE(review): name -> channel-count table is a heuristic (the
        # original comment calls it "approx for layer3/4"); verify it against
        # the layers actually kept on this instance.
        if '18' in self.resnet_type or '34' in self.resnet_type:
            nc = 256
        elif '50' in self.resnet_type:
            nc = 512
        elif '101' in self.resnet_type:
            nc = 1024
        else:
            nc = 2048

        num_features = out.shape[-1]
        # The reshape needs num_features == nc * wh * wh for an integer wh.
        # Checking only that num_features / nc is an integer let non-square
        # quotients (e.g. 2) through, and reshape then raised instead of
        # taking the fallback below.
        wh_squared, remainder = divmod(num_features, nc)
        wh = math.isqrt(wh_squared) if remainder == 0 else None
        if wh is None or wh * wh != wh_squared:
            print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
            # Return potentially flattened features if reshape fails
            return out
        return out.reshape(x.size(0), nc, wh, wh)
# --- Positional Encoding Modules ---
class LearnableFourierPositionalEncoding(nn.Module):
    """
    Learnable Fourier Feature Positional Encoding.

    Follows Algorithm 1 of "Learnable Fourier Features for Multi-Dimensional
    Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf):
    coordinates -> linear map `Wr` (no bias) -> [cos, sin] scaled by
    1/sqrt(F_dim) -> Linear -> GLU -> Linear -> LayerNorm.

    Args:
        d_model (int): The output dimension of the positional encoding (D).
        G (int): Positional groups (default 1). `forward` assumes G == 1.
        M (int): Dimensionality of input coordinates (default 2 for H, W).
        F_dim (int): Dimension of the Fourier features.
        H_dim (int): Hidden dimension of the MLP (halved by the GLU).
        gamma (float): Scale for the `Wr` initialisation.
    """

    def __init__(self, d_model,
                 G=1, M=2,
                 F_dim=256,
                 H_dim=128,
                 gamma=1/2.5,
                 ):
        super().__init__()
        self.G = G
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = d_model
        self.gamma = gamma

        # Wr is built before the MLP so RNG-driven initialisation order is kept.
        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(self.F_dim, self.H_dim, bias=True),
            nn.GLU(),  # H_dim -> H_dim // 2
            nn.Linear(self.H_dim // 2, self.D // self.G),
            nn.LayerNorm(self.D // self.G)
        )
        self.init_weights()

    def init_weights(self):
        # NOTE(review): std is gamma ** -2 here; check against Algorithm 1 of
        # the cited paper whether gamma ** -2 is meant as the std or the variance.
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Feature map of shape (B, C, H, W); only its shape
                and channel 0 are used to build coordinates.

        Returns:
            torch.Tensor: Positional encoding of shape (B, D, H, W) (for G == 1).
        """
        _batch, _channels, _height, _width = x.shape  # enforces a 4-D input
        # add_coord_dim (models.utils) builds per-pixel coordinates; assumed
        # to return (B, H, W, M) -- TODO confirm against models.utils.
        coords = add_coord_dim(x[:, 0])
        phase = self.Wr(coords)  # (B, H, W, F_dim // 2)
        fourier_features = (1.0 / math.sqrt(self.F_dim)) * torch.cat(
            [torch.cos(phase), torch.sin(phase)], dim=-1
        )  # (B, H, W, F_dim)
        encoded = self.mlp(fourier_features)  # (B, H, W, D // G)
        # Channels-first; only a valid (B, D, H, W) layout when G == 1.
        return encoded.permute(0, 3, 1, 2)
class MultiLearnableFourierPositionalEncoding(nn.Module):
    """
    Learnable mixture of LearnableFourierPositionalEncoding modules.

    Builds `N` encoders whose `gamma` values are evenly spaced (np.linspace)
    from `gamma_range[0]` to `gamma_range[1]`, and combines their outputs with
    softmax-normalised learnable weights.

    Args:
        d_model (int): Output dimension of the encoding.
        G, M, F_dim, H_dim: Passed positionally to every underlying
            LearnableFourierPositionalEncoding.
        gamma_range (sequence of 2 floats): Start and end gamma values for the
            linspace. Defaults to (1.0, 0.1); any indexable pair works.
        N (int): Number of parallel encoders.
    """

    def __init__(self, d_model,
                 G=1, M=2,
                 F_dim=256,
                 H_dim=128,
                 gamma_range=(1.0, 0.1),  # immutable default (was a shared list)
                 N=10,
                 ):
        super().__init__()
        self.embedders = nn.ModuleList()
        for gamma in np.linspace(gamma_range[0], gamma_range[1], N):
            self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma))
        # Unnormalised mixing logits, softmax-ed in forward. The name
        # 'combination' is part of the state_dict and must not change.
        self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
        self.N = N

    def forward(self, x):
        """
        Computes combined positional encoding.

        Args:
            x (torch.Tensor): Input feature map, shape (B, C, H, W).

        Returns:
            torch.Tensor: Weighted sum of the N encodings, shape (B, D, H, W).
        """
        # (N, B, D, H, W)
        pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)
        # (N,) -> (N, 1, 1, 1, 1) so the weights broadcast over each encoding.
        weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)
        return (pos_embs * weights).sum(0)
class CustomRotationalEmbedding(nn.Module):
    """
    Custom Rotational Positional Embedding.

    Builds a 2D positional embedding by rotating a learnable 2D start vector.
    Row r is rotated by the r-th of H angles evenly spaced over [0, 180]
    degrees, and column c by the c-th of W such angles. For cell (r, c) the
    two rotated vectors are concatenated into a 4-vector and linearly
    projected to `d_model`.

    When H == W this matches the earlier width-only construction exactly.
    When H != W the earlier code returned a (1, d_model, W, W) tensor that did
    not fit the input grid; the height axis now has its own angle schedule.

    Args:
        d_model (int): Dimensionality of the output embeddings.
    """

    def __init__(self, d_model):
        super(CustomRotationalEmbedding, self).__init__()
        # Learnable 2D start vector
        self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True))
        # Projects the 4D concatenated rotated vectors (row ++ column) to d_model.
        self.projection = nn.Sequential(nn.Linear(4, d_model))

    def _rotated_start_vectors(self, n, device):
        """Rotate `start_vector` by n angles evenly spaced over [0, 180] degrees; returns (n, 2)."""
        theta_rad = torch.deg2rad(torch.linspace(0, 180, n, device=device))
        cos_theta = torch.cos(theta_rad)
        sin_theta = torch.sin(theta_rad)
        # Standard 2D rotation matrices, one per angle: (n, 2, 2)
        rotation_matrices = torch.stack([
            torch.stack([cos_theta, -sin_theta], dim=-1),
            torch.stack([sin_theta, cos_theta], dim=-1),
        ], dim=1)
        return torch.einsum('nij,j->ni', rotation_matrices, self.start_vector)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Used only for its shape and device;
                shape (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Positional embeddings of shape (1, d_model, height, width).
                The batch dim is 1 because the embedding is identical for every sample.
        """
        _, _, H, W = x.shape
        row_vectors = self._rotated_start_vectors(H, x.device)  # (H, 2)
        col_vectors = self._rotated_start_vectors(W, x.device)  # (W, 2)
        # key[r, c] = row_vectors[r] ++ col_vectors[c] -> (H, W, 4)
        key = torch.cat((
            row_vectors.unsqueeze(1).expand(H, W, 2),
            col_vectors.unsqueeze(0).expand(H, W, 2),
        ), dim=-1)
        pe_grid = self.projection(key)  # (H, W, d_model)
        return pe_grid.permute(2, 0, 1).unsqueeze(0)
class CustomRotationalEmbedding1D(nn.Module):
    """
    1D rotational positional embedding.

    For an input of shape (B, C, L), the fixed vector [0, 1] is rotated by L
    angles evenly spaced over [0, 180] degrees. Each rotated 2-vector is
    projected to `d_model`, and the result is repeated across the batch.
    Output shape is (B, d_model, L).

    Args:
        d_model (int): Output embedding width.
    """

    def __init__(self, d_model):
        super().__init__()
        self.projection = nn.Linear(2, d_model)

    def forward(self, x):
        device = x.device
        length = x.size(2)
        batch = x.size(0)
        base = torch.tensor([0., 1.], device=device, dtype=torch.float)
        angles = torch.deg2rad(torch.linspace(0, 180, length, device=device))
        c, s = torch.cos(angles), torch.sin(angles)
        # One 2x2 rotation matrix per position: (L, 2, 2)
        rot = torch.stack([
            torch.stack([c, -s], dim=1),
            torch.stack([s, c], dim=1),
        ], dim=1)
        vecs = torch.einsum('bij,j->bi', rot, base)  # (L, 2)
        pe = self.projection(vecs)  # (L, d_model)
        pe = pe.unsqueeze(0).repeat_interleave(batch, dim=0)  # (B, L, d_model)
        # Channels-first, for compatibility with the other backbones.
        return pe.transpose(1, 2)