import torch
import torch.nn as nn
import torch.nn.functional as F  # Used for functional ops such as F.softmax in the positional encodings
import math
import numpy as np
# Assuming 'add_coord_dim' is defined in models.utils
from models.utils import add_coord_dim
# --- Basic Utility Modules ---
class Identity(nn.Module):
"""
Identity Module.
Returns the input tensor unchanged. Useful as a placeholder or a no-op layer
in nn.Sequential containers or conditional network parts.
"""
def __init__(self):
super().__init__()
def forward(self, x):
return x
class Squeeze(nn.Module):
"""
Squeeze Module.
Removes a specified dimension of size 1 from the input tensor.
Useful for incorporating tensor dimension squeezing within nn.Sequential.
Args:
dim (int): The dimension to squeeze.
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
return x.squeeze(self.dim)
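# Example (illustrative sketch, not part of the original pipeline): Squeeze lets a
# size-1 dimension be removed inside nn.Sequential, e.g. for a scalar prediction head:
#
#   head = nn.Sequential(nn.LazyLinear(1), Squeeze(-1))
#   y = head(torch.randn(8, 512))  # (8, 1) -> (8,)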
# --- Core CTM Component Modules ---
class SynapseUNET(nn.Module):
"""
UNET-style architecture for the Synapse Model (f_theta1 in the paper).
This module implements the connections between neurons in the CTM's latent
space. It processes the combined input (previous post-activation state z^t
and attention output o^t) to produce the pre-activations (a^t) for the
next internal tick (Eq. 1 in the paper).
While a simpler Linear or MLP layer can be used, the paper notes
that this U-Net structure empirically performed better, suggesting benefit
from more flexible synaptic connections[cite: 79, 80]. This implementation
uses `depth` points in linspace and creates `depth-1` down/up blocks.
    Args:
        out_dims (int): Number of output dimensions (d_model). The input width
            (d_model + d_input) is inferred lazily by the first LazyLinear layer.
        depth (int): Determines structure size; creates `depth-1` down/up blocks.
        minimum_width (int): Smallest channel width at the U-Net bottleneck.
        dropout (float): Dropout rate applied within down/up projections.
    """
def __init__(self,
out_dims,
depth,
minimum_width=16,
dropout=0.0):
super().__init__()
self.width_out = out_dims
self.n_deep = depth # Store depth just for reference if needed
# Define UNET structure based on depth
# Creates `depth` width values, leading to `depth-1` blocks
widths = np.linspace(out_dims, minimum_width, depth)
# Initial projection layer
self.first_projection = nn.Sequential(
nn.LazyLinear(int(widths[0])), # Project to the first width
nn.LayerNorm(int(widths[0])),
nn.SiLU()
)
# Downward path (encoding layers)
self.down_projections = nn.ModuleList()
self.up_projections = nn.ModuleList()
self.skip_lns = nn.ModuleList()
num_blocks = len(widths) - 1 # Number of down/up blocks created
for i in range(num_blocks):
# Down block: widths[i] -> widths[i+1]
self.down_projections.append(nn.Sequential(
nn.Dropout(dropout),
nn.Linear(int(widths[i]), int(widths[i+1])),
nn.LayerNorm(int(widths[i+1])),
nn.SiLU()
))
# Up block: widths[i+1] -> widths[i]
# Note: Up blocks are added in order matching down blocks conceptually,
# but applied in reverse order in the forward pass.
self.up_projections.append(nn.Sequential(
nn.Dropout(dropout),
nn.Linear(int(widths[i+1]), int(widths[i])),
nn.LayerNorm(int(widths[i])),
nn.SiLU()
))
# Skip connection LayerNorm operates on width[i]
self.skip_lns.append(nn.LayerNorm(int(widths[i])))
def forward(self, x):
# Initial projection
out_first = self.first_projection(x)
# Downward path, storing outputs for skip connections
outs_down = [out_first]
for layer in self.down_projections:
outs_down.append(layer(outs_down[-1]))
# outs_down contains [level_0, level_1, ..., level_depth-1=bottleneck] outputs
# Upward path, starting from the bottleneck output
outs_up = outs_down[-1] # Bottleneck activation
num_blocks = len(self.up_projections) # Should be depth - 1
for i in range(num_blocks):
# Apply up projection in reverse order relative to down blocks
# up_projection[num_blocks - 1 - i] processes deeper features first
up_layer_idx = num_blocks - 1 - i
out_up = self.up_projections[up_layer_idx](outs_up)
# Get corresponding skip connection from downward path
# skip_connection index = num_blocks - 1 - i (same as up_layer_idx)
# This matches the output width of the up_projection[up_layer_idx]
skip_idx = up_layer_idx
skip_connection = outs_down[skip_idx]
# Add skip connection and apply LayerNorm corresponding to this level
# skip_lns index also corresponds to the level = skip_idx
outs_up = self.skip_lns[skip_idx](out_up + skip_connection)
# The final output after all up-projections
return outs_up
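# Example (illustrative sketch; the d_model/d_input values are assumptions): the synapse
# model consumes the concatenated post-activation state z^t and attention output o^t
# and produces pre-activations a^t of width out_dims:
#
#   synapse = SynapseUNET(out_dims=512, depth=4, minimum_width=16)
#   z_t = torch.randn(2, 512)                     # previous post-activations (B, d_model)
#   o_t = torch.randn(2, 128)                     # attention output (B, d_input)
#   a_t = synapse(torch.cat([z_t, o_t], dim=-1))  # -> (2, 512)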
class SuperLinear(nn.Module):
"""
SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.
This layer is the core component enabling Neuron-Level Models (NLMs),
referred to as g_theta_d in the paper (Eq. 3). It applies N independent
linear transformations (or small MLPs when used sequentially) to corresponding
slices of the input tensor along a specified dimension (typically the neuron
or feature dimension).
How it works for NLMs:
- The input `x` is expected to be the pre-activation history for each neuron,
shaped (batch_size, n_neurons=N, history_length=in_dims).
- This layer holds unique weights (`w1`) and biases (`b1`) for *each* of the `N` neurons.
`w1` has shape (in_dims, out_dims, N), `b1` has shape (1, N, out_dims).
- `torch.einsum('bni,iog->bno', x, self.w1)` performs N independent matrix
multiplications in parallel (mapping from dim `i` to `o` for each neuron `n`):
- For each neuron `n` (from 0 to N-1):
- It takes the neuron's history `x[:, n, :]` (shape B, in_dims).
- Multiplies it by the neuron's unique weight matrix `self.w1[:, :, n]` (shape in_dims, out_dims).
- Resulting in `out[:, n, :]` (shape B, out_dims).
- The unique bias `self.b1[:, n, :]` is added.
- The result is squeezed on the last dim (if out_dims=1) and scaled by `T`.
This allows each neuron `d` to process its temporal history `A_d^t` using
its private parameters `theta_d` to produce the post-activation `z_d^{t+1}`,
enabling the fine-grained temporal dynamics central to the CTM[cite: 7, 30, 85].
It's typically used within the `trace_processor` module of the main CTM class.
Args:
in_dims (int): Input dimension (typically `memory_length`).
out_dims (int): Output dimension per neuron.
N (int): Number of independent linear models (typically `d_model`).
T (float): Initial value for learnable temperature/scaling factor applied to output.
do_norm (bool): Apply Layer Normalization to the input history before linear transform.
dropout (float): Dropout rate applied to the input.
"""
def __init__(self,
in_dims,
out_dims,
N,
T=1.0,
do_norm=False,
dropout=0):
super().__init__()
# N is the number of neurons (d_model), in_dims is the history length (memory_length)
self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
self.in_dims = in_dims # Corresponds to memory_length
# LayerNorm applied across the history dimension for each neuron independently
self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity()
self.do_norm = do_norm
# Initialize weights and biases
# w1 shape: (memory_length, out_dims, d_model)
self.register_parameter('w1', nn.Parameter(
torch.empty((in_dims, out_dims, N)).uniform_(
-1/math.sqrt(in_dims + out_dims),
1/math.sqrt(in_dims + out_dims)
), requires_grad=True)
)
# b1 shape: (1, d_model, out_dims)
self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
# Learnable temperature/scaler T
self.register_parameter('T', nn.Parameter(torch.Tensor([T])))
def forward(self, x):
"""
Args:
x (torch.Tensor): Input tensor, expected shape (B, N, in_dims)
where B=batch, N=d_model, in_dims=memory_length.
Returns:
torch.Tensor: Output tensor, shape (B, N) after squeeze(-1).
"""
# Input shape: (B, D, M) where D=d_model=N neurons in CTM, M=history/memory length
out = self.dropout(x)
# LayerNorm across the memory_length dimension (dim=-1)
out = self.layernorm(out) # Shape remains (B, N, M)
        # Apply the N independent linear models in parallel via einsum('BDM,MHD->BDH'):
        #   out: (B=batch, D=N neurons, M=history/memory length)
        #   w1:  (M, H=out_dims, D=N neurons) -- one private weight matrix per neuron
        #   b1:  (1, D, H), broadcast over the batch dimension
        # Result: (B, D, H)
        out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1
        # Squeeze the output dimension (when out_dims == 1) and scale by the learnable temperature T
        out = out.squeeze(-1) / self.T
return out
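# Example (illustrative sketch; sizes are assumptions): one private linear model per
# neuron, applied to that neuron's pre-activation history:
#
#   nlm = SuperLinear(in_dims=25, out_dims=1, N=512, do_norm=True)
#   history = torch.randn(8, 512, 25)  # (B, n_neurons, memory_length)
#   z_next = nlm(history)              # -> (8, 512) post-activations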
# --- Backbone Modules ---
class ParityBackbone(nn.Module):
    """
    Embedding backbone for the parity task: maps a sequence of +/-1 values to embeddings.
    """
    def __init__(self, n_embeddings, d_embedding):
        super(ParityBackbone, self).__init__()
        self.embedding = nn.Embedding(n_embeddings, d_embedding)
    def forward(self, x):
        """
        Maps -1 (negative parity) to embedding index 0 and +1 (positive parity) to index 1.
        """
        x = (x == 1).long()
        return self.embedding(x).transpose(1, 2)  # Transpose to (B, d_embedding, L) for compatibility with other backbones
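# Example (illustrative sketch): a length-4 parity sequence embedded to d_embedding=64:
#
#   backbone = ParityBackbone(n_embeddings=2, d_embedding=64)
#   out = backbone(torch.tensor([[1, -1, -1, 1]]))  # -> (1, 64, 4)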
class QAMNISTOperatorEmbeddings(nn.Module):
def __init__(self, num_operator_types, d_projection):
super(QAMNISTOperatorEmbeddings, self).__init__()
self.embedding = nn.Embedding(num_operator_types, d_projection)
    def forward(self, x):
        # Operators are encoded as -1 (plus) and -2 (minus); -x - 1 maps them to indices 0 and 1
        return self.embedding(-x - 1)
class QAMNISTIndexEmbeddings(torch.nn.Module):
def __init__(self, max_seq_length, embedding_dim):
super().__init__()
self.max_seq_length = max_seq_length
self.embedding_dim = embedding_dim
embedding = torch.zeros(max_seq_length, embedding_dim)
position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
embedding[:, 0::2] = torch.sin(position * div_term)
embedding[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('embedding', embedding)
def forward(self, x):
return self.embedding[x]
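# Example (illustrative sketch): the QAMNIST question phase embeds operator codes
# (-1 = plus, -2 = minus) and digit indices separately:
#
#   op_emb = QAMNISTOperatorEmbeddings(num_operator_types=2, d_projection=64)
#   ops = op_emb(torch.tensor([[-1, -2]]))          # -> (1, 2, 64)
#   idx_emb = QAMNISTIndexEmbeddings(max_seq_length=32, embedding_dim=64)
#   idx = idx_emb(torch.tensor([0, 5]))             # -> (2, 64), fixed sinusoidal codes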
class ThoughtSteps:
"""
Helper class for managing "thought steps" in the ctm_qamnist pipeline.
    Args:
        iterations_per_digit (int): Number of internal ticks spent on each digit.
        iterations_per_question_part (int): Number of internal ticks spent on each question part.
        total_iterations_for_answering (int): Total ticks allotted to the answering phase.
        total_iterations_for_digits (int): Total ticks allotted to the digit phase.
        total_iterations_for_question (int): Total ticks allotted to the question phase.
    """
def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
self.iterations_per_digit = iterations_per_digit
self.iterations_per_question_part = iterations_per_question_part
self.total_iterations_for_digits = total_iterations_for_digits
self.total_iterations_for_question = total_iterations_for_question
self.total_iterations_for_answering = total_iterations_for_answering
self.total_iterations = self.total_iterations_for_digits + self.total_iterations_for_question + self.total_iterations_for_answering
def determine_step_type(self, stepi: int):
is_digit_step = stepi < self.total_iterations_for_digits
is_question_step = self.total_iterations_for_digits <= stepi < self.total_iterations_for_digits + self.total_iterations_for_question
is_answer_step = stepi >= self.total_iterations_for_digits + self.total_iterations_for_question
return is_digit_step, is_question_step, is_answer_step
    def determine_answer_step_type(self, stepi: int):
        """Within the question phase, classify a tick as an index step or an operator step."""
        step_within_questions = stepi - self.total_iterations_for_digits
        is_index_step = step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part
        is_operator_step = not is_index_step
        return is_index_step, is_operator_step
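# Example (illustrative sketch; the tick counts are assumptions): with 3 digits at 10
# ticks each, a 40-tick question phase alternating 10-tick index/operator parts, and
# 10 answer ticks:
#
#   ts = ThoughtSteps(iterations_per_digit=10, iterations_per_question_part=10,
#                     total_iterations_for_answering=10, total_iterations_for_digits=30,
#                     total_iterations_for_question=40)
#   ts.determine_step_type(35)         # -> (False, True, False): a question step
#   ts.determine_answer_step_type(35)  # -> (True, False): an index step (5 % 20 < 10)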
class MNISTBackbone(nn.Module):
"""
Simple backbone for MNIST feature extraction.
"""
def __init__(self, d_input):
super(MNISTBackbone, self).__init__()
self.layers = nn.Sequential(
nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(d_input),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(d_input),
nn.ReLU(),
nn.MaxPool2d(2, 2),
)
def forward(self, x):
return self.layers(x)
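# Example (illustrative sketch): two stride-2 poolings reduce 28x28 MNIST to 7x7:
#
#   backbone = MNISTBackbone(d_input=64)
#   feats = backbone(torch.randn(8, 1, 28, 28))  # -> (8, 64, 7, 7)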
class MiniGridBackbone(nn.Module):
def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
super().__init__()
self.object_embedding = nn.Embedding(num_objects, embedding_dim)
self.color_embedding = nn.Embedding(num_colors, embedding_dim)
self.state_embedding = nn.Embedding(num_states, embedding_dim)
self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)
self.project_to_d_projection = nn.Sequential(
nn.Linear(embedding_dim * 4, d_input * 2),
nn.GLU(),
nn.LayerNorm(d_input),
nn.Linear(d_input, d_input * 2),
nn.GLU(),
nn.LayerNorm(d_input)
)
def forward(self, x):
x = x.long()
B, H, W, C = x.size()
object_idx = x[:,:,:, 0]
color_idx = x[:,:,:, 1]
state_idx = x[:,:,:, 2]
obj_embed = self.object_embedding(object_idx)
color_embed = self.color_embedding(color_idx)
state_embed = self.state_embedding(state_idx)
pos_idx = torch.arange(H * W, device=x.device).view(1, H, W).expand(B, -1, -1)
pos_embed = self.position_embedding(pos_idx)
out = self.project_to_d_projection(torch.cat([obj_embed, color_embed, state_embed, pos_embed], dim=-1))
return out
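# Example (illustrative sketch): a standard 7x7x3 MiniGrid observation, where the last
# axis holds (object, color, state) indices, embedded per cell with positions added:
#
#   backbone = MiniGridBackbone(d_input=128)
#   obs = torch.randint(0, 3, (8, 7, 7, 3))
#   out = backbone(obs)  # -> (8, 7, 7, 128)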
class ClassicControlBackbone(nn.Module):
def __init__(self, d_input):
super().__init__()
self.input_projector = nn.Sequential(
nn.Flatten(),
nn.LazyLinear(d_input * 2),
nn.GLU(),
nn.LayerNorm(d_input),
nn.LazyLinear(d_input * 2),
nn.GLU(),
nn.LayerNorm(d_input)
)
def forward(self, x):
return self.input_projector(x)
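# Example (illustrative sketch): a 4-dimensional CartPole-style state vector projected
# to d_input features (the input width is inferred lazily on the first call):
#
#   backbone = ClassicControlBackbone(d_input=64)
#   out = backbone(torch.randn(8, 4))  # -> (8, 64)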
class ShallowWide(nn.Module):
"""
Simple, wide, shallow convolutional backbone for image feature extraction.
Alternative to ResNet, uses grouped convolutions and GLU activations.
Fixed structure, useful for specific experiments.
"""
def __init__(self):
super(ShallowWide, self).__init__()
# LazyConv2d infers input channels
self.layers = nn.Sequential(
nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1), # Output channels = 4096
nn.GLU(dim=1), # Halves channels to 2048
nn.BatchNorm2d(2048),
# Grouped convolution maintains width but processes groups independently
nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32),
nn.GLU(dim=1), # Halves channels to 2048
nn.BatchNorm2d(2048)
)
def forward(self, x):
return self.layers(x)
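# Example (illustrative sketch): the stride-2 convolution halves spatial resolution and
# the GLUs halve the 4096 channels, so a 32x32 RGB image yields a 2048-channel map:
#
#   backbone = ShallowWide()
#   feats = backbone(torch.randn(1, 3, 32, 32))  # -> (1, 2048, 16, 16)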
class PretrainedResNetWrapper(nn.Module):
"""
Wrapper to use standard pre-trained ResNet models from torchvision.
Loads a specified ResNet architecture pre-trained on ImageNet, removes the
final classification layer (fc), average pooling, and optionally later layers
(e.g., layer4), allowing it to be used as a feature extractor backbone.
Args:
resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50').
fine_tune (bool): If False, freezes the weights of the pre-trained backbone.
"""
def __init__(self, resnet_type, fine_tune=True):
super(PretrainedResNetWrapper, self).__init__()
self.resnet_type = resnet_type
self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)
if not fine_tune:
for param in self.backbone.parameters():
param.requires_grad = False
# Remove final layers to use as feature extractor
self.backbone.avgpool = Identity()
self.backbone.fc = Identity()
# Keep layer4 by default, user can modify instance if needed
# self.backbone.layer4 = Identity()
def forward(self, x):
# Get features from the modified ResNet
out = self.backbone(x)
        # Reshape the flattened output back to (B, C, H, W). This is a heuristic:
        # with avgpool/fc replaced by Identity but layer4 kept, standard torchvision
        # ResNets output 512 channels (resnet18/34) or 2048 channels (resnet50 and deeper).
        # Adjust nc if later layers (e.g., layer4) are also removed.
        nc = 512 if ('18' in self.resnet_type or '34' in self.resnet_type) else 2048
        # Infer H, W assuming the output is a flattened, square C * H * W feature map
        num_features = out.shape[-1]
        wh = math.isqrt(num_features // nc) if num_features % nc == 0 else 0
        if wh * wh * nc != num_features:
            print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
            # Fall back to the flattened features if the reshape is ambiguous
            return out
        return out.reshape(x.size(0), nc, wh, wh)
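# Example (illustrative sketch; assumes layer4 is kept so resnet18 outputs 512 channels,
# and that the torchvision weights are available locally or downloadable):
#
#   wrapper = PretrainedResNetWrapper('resnet18', fine_tune=False)
#   feats = wrapper(torch.randn(1, 3, 224, 224))  # -> (1, 512, 7, 7)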
# --- Positional Encoding Modules ---
class LearnableFourierPositionalEncoding(nn.Module):
"""
Learnable Fourier Feature Positional Encoding.
Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
Provides positional information for 2D feature maps.
Args:
d_model (int): The output dimension of the positional encoding (D).
G (int): Positional groups (default 1).
M (int): Dimensionality of input coordinates (default 2 for H, W).
F_dim (int): Dimension of the Fourier features.
H_dim (int): Hidden dimension of the MLP.
gamma (float): Initialization scale for the Fourier projection weights (Wr).
"""
def __init__(self, d_model,
G=1, M=2,
F_dim=256,
H_dim=128,
gamma=1/2.5,
):
super().__init__()
self.G = G
self.M = M
self.F_dim = F_dim
self.H_dim = H_dim
self.D = d_model
self.gamma = gamma
self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
self.mlp = nn.Sequential(
nn.Linear(self.F_dim, self.H_dim, bias=True),
nn.GLU(), # Halves H_dim
nn.Linear(self.H_dim // 2, self.D // self.G),
nn.LayerNorm(self.D // self.G)
)
self.init_weights()
def init_weights(self):
nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)
def forward(self, x):
"""
Computes positional encodings for the input feature map x.
Args:
x (torch.Tensor): Input feature map, shape (B, C, H, W).
Returns:
torch.Tensor: Positional encoding tensor, shape (B, D, H, W).
"""
B, C, H, W = x.shape
# Creates coordinates based on (H, W) and repeats for batch B.
# Takes x[:,0] assuming channel dim isn't needed for coords.
x_coord = add_coord_dim(x[:,0]) # Expects (B, H, W) -> (B, H, W, 2)
# Compute Fourier features
projected = self.Wr(x_coord) # (B, H, W, F_dim // 2)
cosines = torch.cos(projected)
sines = torch.sin(projected)
        # Named fourier_feats to avoid shadowing the module-level functional import F
        fourier_feats = (1.0 / math.sqrt(self.F_dim)) * torch.cat([cosines, sines], dim=-1)  # (B, H, W, F_dim)
        # Project features through MLP
        Y = self.mlp(fourier_feats)  # (B, H, W, D // G)
        # Reshape to (B, D, H, W)
        PEx = Y.permute(0, 3, 1, 2)  # Assuming G=1
return PEx
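# Example (illustrative sketch): encoding positions of a 7x7 feature map; only the
# input's shape and device are used to build the coordinate grid:
#
#   pos_enc = LearnableFourierPositionalEncoding(d_model=128)
#   pe = pos_enc(torch.randn(2, 64, 7, 7))  # -> (2, 128, 7, 7)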
class MultiLearnableFourierPositionalEncoding(nn.Module):
"""
Combines multiple LearnableFourierPositionalEncoding modules with different
initialization scales (gamma) via a learnable weighted sum.
Allows the model to learn an optimal combination of positional frequencies.
Args:
d_model (int): Output dimension of the encoding.
G, M, F_dim, H_dim: Parameters passed to underlying LearnableFourierPositionalEncoding.
gamma_range (list[float]): Min and max gamma values for the linspace.
N (int): Number of parallel embedding modules to create.
"""
def __init__(self, d_model,
G=1, M=2,
F_dim=256,
H_dim=128,
gamma_range=[1.0, 0.1], # Default range
N=10,
):
super().__init__()
self.embedders = nn.ModuleList()
for gamma in np.linspace(gamma_range[0], gamma_range[1], N):
self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma))
        # Learnable weights for combining the N embedders (softmax-normalised in forward)
        self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
self.N = N
def forward(self, x):
"""
Computes combined positional encoding.
Args:
x (torch.Tensor): Input feature map, shape (B, C, H, W).
Returns:
torch.Tensor: Combined positional encoding tensor, shape (B, D, H, W).
"""
# Compute embeddings from all modules and stack: (N, B, D, H, W)
pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)
        # Softmax the combination weights and reshape for broadcasting: (N,) -> (N, 1, 1, 1, 1)
weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)
# Compute weighted sum over the N dimension
combined_emb = (pos_embs * weights).sum(0) # (B, D, H, W)
return combined_emb
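# Example (illustrative sketch): N=4 embedders spanning gamma in [1.0, 0.1], blended by
# a learned softmax over the N outputs:
#
#   pos_enc = MultiLearnableFourierPositionalEncoding(d_model=128, N=4)
#   pe = pos_enc(torch.randn(2, 64, 7, 7))  # -> (2, 128, 7, 7)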
class CustomRotationalEmbedding(nn.Module):
"""
Custom Rotational Positional Embedding.
Generates 2D positional embeddings based on rotating a fixed start vector.
The rotation angle for each grid position is determined primarily by its
horizontal position (width dimension). The resulting rotated vectors are
concatenated and projected.
Note: The current implementation derives angles only from the width dimension (`x.size(-1)`).
Args:
d_model (int): Dimensionality of the output embeddings.
"""
def __init__(self, d_model):
super(CustomRotationalEmbedding, self).__init__()
# Learnable 2D start vector
self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True))
# Projects the 4D concatenated rotated vectors to d_model
# Input size 4 comes from concatenating two 2D rotated vectors
self.projection = nn.Sequential(nn.Linear(4, d_model))
def forward(self, x):
"""
Computes rotational positional embeddings based on input width.
Args:
x (torch.Tensor): Input tensor (used for shape and device),
shape (batch_size, channels, height, width).
Returns:
Output tensor containing positional embeddings,
shape (1, d_model, height, width) - Batch dim is 1 as PE is same for all.
"""
B, C, H, W = x.shape
device = x.device
# --- Generate rotations based only on Width ---
# Angles derived from width dimension
theta_rad = torch.deg2rad(torch.linspace(0, 180, W, device=device)) # Angle per column
cos_theta = torch.cos(theta_rad)
sin_theta = torch.sin(theta_rad)
        # Create rotation matrices by stacking the two row vectors: shape (W, 2, 2)
        rotation_matrices = torch.stack([
            torch.stack([cos_theta, -sin_theta], dim=-1),  # Shape (W, 2)
            torch.stack([sin_theta, cos_theta], dim=-1)    # Shape (W, 2)
        ], dim=1)  # Shape (W, 2, 2)
# Rotate the start vector by column angle: Shape (W, 2)
rotated_vectors = torch.einsum('wij,j->wi', rotation_matrices, self.start_vector)
        # --- Create Grid Key ---
        # Pair the rotated vector at each "row" position with the one at each "column"
        # position; both axes reuse the width-derived vectors, giving a (W, W, 4) key.
        key = torch.cat((
            torch.repeat_interleave(rotated_vectors.unsqueeze(1), W, dim=1),  # (W, 1, 2) -> (W, W, 2)
            torch.repeat_interleave(rotated_vectors.unsqueeze(0), W, dim=0)   # (1, W, 2) -> (W, W, 2)
        ), dim=-1)  # Shape (W, W, 4)
# Project the 4D key vector to d_model: Shape (W, W, d_model)
pe_grid = self.projection(key)
        # Permute to (d_model, W, W) and add a batch dim -> (1, d_model, W, W);
        # the same encoding is shared across the batch.
        pe = pe_grid.permute(2, 0, 1).unsqueeze(0)
        # Note: because the grid is derived from the width dimension only, the output
        # is (1, d_model, W, W). If H != W, downstream code must resolve the mismatch
        # (e.g., via F.interpolate or cropping); this module returns the W x W grid as-is.
        return pe
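# Example (illustrative sketch; note the output is width-derived, so H should equal W):
#
#   pos_enc = CustomRotationalEmbedding(d_model=128)
#   pe = pos_enc(torch.randn(2, 3, 7, 7))  # -> (1, 128, 7, 7), shared across the batch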
class CustomRotationalEmbedding1D(nn.Module):
    """
    1D variant of CustomRotationalEmbedding.
    Rotates a fixed start vector by an angle determined by sequence position
    (taken from x.size(2)) and projects the rotated vectors to d_model.
    """
    def __init__(self, d_model):
        super(CustomRotationalEmbedding1D, self).__init__()
        self.projection = nn.Linear(2, d_model)
    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input of shape (B, C, L); only B, L, and device are used.
        Returns:
            torch.Tensor: Positional embeddings of shape (B, d_model, L).
        """
        start_vector = torch.tensor([0., 1.], device=x.device, dtype=torch.float)
        theta_rad = torch.deg2rad(torch.linspace(0, 180, x.size(2), device=x.device))
        cos_theta = torch.cos(theta_rad)
        sin_theta = torch.sin(theta_rad)
        cos_theta = cos_theta.unsqueeze(1)  # Shape: (L, 1)
        sin_theta = sin_theta.unsqueeze(1)  # Shape: (L, 1)
        # Create rotation matrices, one per sequence position
        rotation_matrices = torch.stack([
            torch.cat([cos_theta, -sin_theta], dim=1),
            torch.cat([sin_theta, cos_theta], dim=1)
        ], dim=1)  # Shape: (L, 2, 2)
        # Rotate the start vector at each position
        rotated_vectors = torch.einsum('lij,j->li', rotation_matrices, start_vector)
        pe = self.projection(rotated_vectors)  # (L, d_model)
        pe = torch.repeat_interleave(pe.unsqueeze(0), x.size(0), 0)  # (B, L, d_model)
        return pe.transpose(1, 2)  # Transpose to (B, d_model, L) for compatibility with other backbones
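# Example (illustrative sketch): positions are taken from the last axis of a (B, C, L) input:
#
#   pos_enc = CustomRotationalEmbedding1D(d_model=64)
#   pe = pos_enc(torch.randn(2, 32, 16))  # -> (2, 64, 16)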