import torch
import torch.nn as nn
import torch.nn.functional as F  # Used for softmax in MultiLearnableFourierPositionalEncoding
import math
import numpy as np

# Assuming 'add_coord_dim' is defined in models.utils
from models.utils import add_coord_dim

# --- Basic Utility Modules ---

class Identity(nn.Module):
    """
    Identity Module.

    Returns the input tensor unchanged. Useful as a placeholder or a no-op layer
    in nn.Sequential containers or conditional network parts.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class Squeeze(nn.Module):
    """
    Squeeze Module.

    Removes a specified dimension of size 1 from the input tensor.
    Useful for incorporating tensor dimension squeezing within nn.Sequential.

    Args:
      dim (int): The dimension to squeeze.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x.squeeze(self.dim)

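# Illustrative usage sketch (not part of the original module): shows how Squeeze can
# drop a trailing size-1 dimension inside nn.Sequential. The sizes are assumptions.
def _example_squeeze_usage():
    head = nn.Sequential(nn.Linear(16, 1), Squeeze(-1))
    x = torch.randn(4, 16)
    return head(x)  # shape (4,) rather than (4, 1)
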
# --- Core CTM Component Modules ---

class SynapseUNET(nn.Module):
    """
    UNET-style architecture for the Synapse Model (f_theta1 in the paper).

    This module implements the connections between neurons in the CTM's latent
    space. It processes the combined input (previous post-activation state z^t
    and attention output o^t) to produce the pre-activations (a^t) for the
    next internal tick (Eq. 1 in the paper).

    While a simpler Linear or MLP layer can be used, the paper notes that this
    U-Net structure empirically performed better, suggesting a benefit from more
    flexible synaptic connections. This implementation uses `depth` points in a
    linspace of widths and creates `depth-1` down/up blocks. The input dimension
    (d_model + d_input) is not passed explicitly; it is inferred lazily by the
    first projection (nn.LazyLinear).

    Args:
      out_dims (int): Number of output dimensions (d_model).
      depth (int): Determines structure size; creates `depth-1` down/up blocks.
      minimum_width (int): Smallest channel width at the U-Net bottleneck.
      dropout (float): Dropout rate applied within down/up projections.
    """
    def __init__(self,
                 out_dims,
                 depth,
                 minimum_width=16,
                 dropout=0.0):
        super().__init__()
        self.width_out = out_dims
        self.n_deep = depth # Store depth just for reference if needed

        # Define UNET structure based on depth
        # Creates `depth` width values, leading to `depth-1` blocks
        widths = np.linspace(out_dims, minimum_width, depth)

        # Initial projection layer
        self.first_projection = nn.Sequential(
            nn.LazyLinear(int(widths[0])), # Project to the first width
            nn.LayerNorm(int(widths[0])),
            nn.SiLU()
        )

        # Downward path (encoding layers)
        self.down_projections = nn.ModuleList()
        self.up_projections = nn.ModuleList()
        self.skip_lns = nn.ModuleList()
        num_blocks = len(widths) - 1 # Number of down/up blocks created

        for i in range(num_blocks):
            # Down block: widths[i] -> widths[i+1]
            self.down_projections.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(int(widths[i]), int(widths[i+1])),
                nn.LayerNorm(int(widths[i+1])),
                nn.SiLU()
            ))
            # Up block: widths[i+1] -> widths[i]
            # Note: Up blocks are added in order matching down blocks conceptually,
            # but applied in reverse order in the forward pass.
            self.up_projections.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(int(widths[i+1]), int(widths[i])),
                nn.LayerNorm(int(widths[i])),
                nn.SiLU()
            ))
            # Skip connection LayerNorm operates on width[i]
            self.skip_lns.append(nn.LayerNorm(int(widths[i])))

    def forward(self, x):
        # Initial projection
        out_first = self.first_projection(x)

        # Downward path, storing outputs for skip connections
        outs_down = [out_first]
        for layer in self.down_projections:
            outs_down.append(layer(outs_down[-1]))
        # outs_down contains [level_0, level_1, ..., level_depth-1=bottleneck] outputs

        # Upward path, starting from the bottleneck output
        outs_up = outs_down[-1] # Bottleneck activation
        num_blocks = len(self.up_projections) # Should be depth - 1

        for i in range(num_blocks):
            # Apply up projection in reverse order relative to down blocks
            # up_projection[num_blocks - 1 - i] processes deeper features first
            up_layer_idx = num_blocks - 1 - i
            out_up = self.up_projections[up_layer_idx](outs_up)

            # Get corresponding skip connection from downward path
            # skip_connection index = num_blocks - 1 - i (same as up_layer_idx)
            # This matches the output width of the up_projection[up_layer_idx]
            skip_idx = up_layer_idx
            skip_connection = outs_down[skip_idx]

            # Add skip connection and apply LayerNorm corresponding to this level
            # skip_lns index also corresponds to the level = skip_idx
            outs_up = self.skip_lns[skip_idx](out_up + skip_connection)

        # The final output after all up-projections
        return outs_up

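# Illustrative usage sketch for SynapseUNET (not from the original repo); the sizes
# below (batch=8, d_model=64, d_input=32, depth=4) are assumptions, not config values.
def _example_synapse_unet_usage():
    d_model, d_input = 64, 32
    synapse = SynapseUNET(out_dims=d_model, depth=4, minimum_width=16, dropout=0.0)
    # Concatenated previous post-activations z^t and attention output o^t
    z_and_o = torch.randn(8, d_model + d_input)
    pre_activations = synapse(z_and_o)  # a^t for the next tick, shape (8, d_model)
    return pre_activations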

class SuperLinear(nn.Module):
    """
    SuperLinear Layer: Implements Neuron-Level Models (NLMs) for the CTM.

    This layer is the core component enabling Neuron-Level Models (NLMs),
    referred to as g_theta_d in the paper (Eq. 3). It applies N independent
    linear transformations (or small MLPs when used sequentially) to corresponding
    slices of the input tensor along a specified dimension (typically the neuron
    or feature dimension).

    How it works for NLMs:
    - The input `x` is expected to be the pre-activation history for each neuron,
      shaped (batch_size, n_neurons=N, history_length=in_dims).
    - This layer holds unique weights (`w1`) and biases (`b1`) for *each* of the `N` neurons.
      `w1` has shape (in_dims, out_dims, N), `b1` has shape (1, N, out_dims).
    - `torch.einsum('BDM,MHD->BDH', x, self.w1)` performs N independent matrix
      multiplications in parallel (mapping the history dim `M` to the output dim
      `H` for each neuron `D`):
        - For each neuron `n` (from 0 to N-1):
        - It takes the neuron's history `x[:, n, :]` (shape B, in_dims).
        - Multiplies it by the neuron's unique weight matrix `self.w1[:, :, n]` (shape in_dims, out_dims).
        - Resulting in `out[:, n, :]` (shape B, out_dims).
    - The unique bias `self.b1[:, n, :]` is added.
    - The result is squeezed on the last dim (if out_dims=1) and divided by the
      learnable temperature `T`.

    This allows each neuron `d` to process its temporal history `A_d^t` using
    its private parameters `theta_d` to produce the post-activation `z_d^{t+1}`,
    enabling the fine-grained temporal dynamics central to the CTM.
    It's typically used within the `trace_processor` module of the main CTM class.

    Args:
      in_dims (int): Input dimension (typically `memory_length`).
      out_dims (int): Output dimension per neuron.
      N (int): Number of independent linear models (typically `d_model`).
      T (float): Initial value for the learnable temperature; the output is divided by T.
      do_norm (bool): Apply Layer Normalization to the input history before linear transform.
      dropout (float): Dropout rate applied to the input.
    """
    def __init__(self,
                 in_dims,
                 out_dims,
                 N,
                 T=1.0,
                 do_norm=False,
                 dropout=0):
        super().__init__()
        # N is the number of neurons (d_model), in_dims is the history length (memory_length)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else Identity()
        self.in_dims = in_dims # Corresponds to memory_length
        # LayerNorm applied across the history dimension for each neuron independently
        self.layernorm = nn.LayerNorm(in_dims, elementwise_affine=True) if do_norm else Identity()
        self.do_norm = do_norm

        # Initialize weights and biases
        # w1 shape: (memory_length, out_dims, d_model)
        self.register_parameter('w1', nn.Parameter(
            torch.empty((in_dims, out_dims, N)).uniform_(
                -1/math.sqrt(in_dims + out_dims),
                 1/math.sqrt(in_dims + out_dims)
            ), requires_grad=True)
        )
        # b1 shape: (1, d_model, out_dims)
        self.register_parameter('b1', nn.Parameter(torch.zeros((1, N, out_dims)), requires_grad=True))
        # Learnable temperature/scaler T
        self.register_parameter('T', nn.Parameter(torch.Tensor([T]))) 

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor, expected shape (B, N, in_dims)
                              where B=batch, N=d_model, in_dims=memory_length.
        Returns:
            torch.Tensor: Output tensor, shape (B, N) after squeeze(-1).
        """
        # Input shape: (B, D, M) where D=d_model=N neurons in CTM, M=history/memory length
        out = self.dropout(x)
        # LayerNorm across the memory_length dimension (dim=-1)
        out = self.layernorm(out) # Shape remains (B, N, M)

        # Apply N independent linear models in parallel using einsum:
        #   x:  (B=batch, D=N neurons (one NLM each), M=history/memory length)
        #   w1: (M, H=output (or hidden) dims per neuron, D=N neurons)
        #   b1: (1, D, H), broadcast over the batch
        #   result: (B, D, H)
        out = torch.einsum('BDM,MHD->BDH', out, self.w1) + self.b1

        # Squeeze the output dimension (usually 1) and divide by the learnable temperature T
        out = out.squeeze(-1) / self.T
        return out

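# Illustrative usage sketch for SuperLinear as a neuron-level model (not from the
# original repo); batch=8, d_model=512 and memory_length=25 are assumed sizes.
def _example_superlinear_usage():
    nlm = SuperLinear(in_dims=25, out_dims=1, N=512, do_norm=True)
    history = torch.randn(8, 512, 25)   # (B, N neurons, memory_length) pre-activation history
    z = nlm(history)                    # post-activations, shape (8, 512)
    return z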

# --- Backbone Modules ---

class ParityBackbone(nn.Module):
    """
    Embedding backbone for the parity task.

    Maps each element of a +/-1 input sequence to a learned embedding vector.
    """
    def __init__(self, n_embeddings, d_embedding):
        super(ParityBackbone, self).__init__()
        self.embedding = nn.Embedding(n_embeddings, d_embedding)

    def forward(self, x):
        """
        Maps -1 (negative parity) to index 0 and 1 (positive parity) to index 1.
        """
        x = (x == 1).long()
        return self.embedding(x).transpose(1, 2) # Transpose for compatibility with other backbones

class QAMNISTOperatorEmbeddings(nn.Module):
    def __init__(self, num_operator_types, d_projection):
        super(QAMNISTOperatorEmbeddings, self).__init__()
        self.embedding = nn.Embedding(num_operator_types, d_projection)

    def forward(self, x):
        # -1 for plus and -2 for minus
        return self.embedding(-x - 1)

class QAMNISTIndexEmbeddings(torch.nn.Module):
    def __init__(self, max_seq_length, embedding_dim):
        super().__init__()
        self.max_seq_length = max_seq_length
        self.embedding_dim = embedding_dim

        embedding = torch.zeros(max_seq_length, embedding_dim)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        
        embedding[:, 0::2] = torch.sin(position * div_term)
        embedding[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('embedding', embedding)

    def forward(self, x):
        return self.embedding[x]
    
class ThoughtSteps:
    """
    Helper class for managing "thought steps" in the ctm_qamnist pipeline.

    Args:
        iterations_per_digit (int): Number of iterations for each digit.
        iterations_per_question_part (int): Number of iterations for each question part.
        total_iterations_for_answering (int): Total number of iterations for answering.
        total_iterations_for_digits (int): Total number of iterations for digits.
        total_iterations_for_question (int): Total number of iterations for question.
    """
    def __init__(self, iterations_per_digit, iterations_per_question_part, total_iterations_for_answering, total_iterations_for_digits, total_iterations_for_question):
        self.iterations_per_digit = iterations_per_digit
        self.iterations_per_question_part = iterations_per_question_part
        self.total_iterations_for_digits = total_iterations_for_digits
        self.total_iterations_for_question = total_iterations_for_question
        self.total_iterations_for_answering = total_iterations_for_answering
        self.total_iterations = self.total_iterations_for_digits + self.total_iterations_for_question + self.total_iterations_for_answering

    def determine_step_type(self, stepi: int):
        is_digit_step = stepi < self.total_iterations_for_digits
        is_question_step = self.total_iterations_for_digits <= stepi < self.total_iterations_for_digits + self.total_iterations_for_question
        is_answer_step = stepi >= self.total_iterations_for_digits + self.total_iterations_for_question
        return is_digit_step, is_question_step, is_answer_step

    def determine_answer_step_type(self, stepi: int):
        step_within_questions = stepi - self.total_iterations_for_digits
        if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part:
            is_index_step = True
            is_operator_step = False
        else:
            is_index_step = False
            is_operator_step = True
        return is_index_step, is_operator_step

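# Illustrative usage sketch for ThoughtSteps (not from the original repo); the
# iteration counts below are assumptions, not values from the qamnist config.
def _example_thought_steps_usage():
    steps = ThoughtSteps(
        iterations_per_digit=10,
        iterations_per_question_part=5,
        total_iterations_for_answering=10,
        total_iterations_for_digits=30,    # e.g. 3 digits x 10 iterations each
        total_iterations_for_question=20,  # e.g. 2 (index, operator) pairs x 2 x 5
    )
    # One (is_digit_step, is_question_step, is_answer_step) tuple per internal tick
    return [steps.determine_step_type(i) for i in range(steps.total_iterations)]
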
class MNISTBackbone(nn.Module):
    """
    Simple backbone for MNIST feature extraction.
    """
    def __init__(self, d_input):
        super(MNISTBackbone, self).__init__()
        self.layers = nn.Sequential(
            nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(d_input),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.LazyConv2d(d_input, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(d_input),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

    def forward(self, x):
        return self.layers(x)


class MiniGridBackbone(nn.Module):
    def __init__(self, d_input, grid_size=7, num_objects=11, num_colors=6, num_states=3, embedding_dim=8):
        super().__init__()
        self.object_embedding = nn.Embedding(num_objects, embedding_dim)
        self.color_embedding = nn.Embedding(num_colors, embedding_dim)
        self.state_embedding = nn.Embedding(num_states, embedding_dim)
        
        self.position_embedding = nn.Embedding(grid_size * grid_size, embedding_dim)

        self.project_to_d_projection = nn.Sequential(
            nn.Linear(embedding_dim * 4, d_input * 2),
            nn.GLU(),
            nn.LayerNorm(d_input),
            nn.Linear(d_input, d_input * 2),
            nn.GLU(),
            nn.LayerNorm(d_input)
        )

    def forward(self, x):
        x = x.long()
        B, H, W, C = x.size()

        object_idx = x[:,:,:, 0]
        color_idx =  x[:,:,:, 1]
        state_idx =  x[:,:,:, 2]

        obj_embed = self.object_embedding(object_idx)
        color_embed = self.color_embedding(color_idx)
        state_embed = self.state_embedding(state_idx)
        
        pos_idx = torch.arange(H * W, device=x.device).view(1, H, W).expand(B, -1, -1)
        pos_embed = self.position_embedding(pos_idx)

        out = self.project_to_d_projection(torch.cat([obj_embed, color_embed, state_embed, pos_embed], dim=-1))
        return out

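# Illustrative usage sketch for MiniGridBackbone (not from the original repo); the
# observation is a (B, 7, 7, 3) grid of (object, color, state) indices. d_input=64
# and the random index range are assumptions.
def _example_minigrid_backbone_usage():
    backbone = MiniGridBackbone(d_input=64, grid_size=7)
    obs = torch.randint(0, 3, (2, 7, 7, 3))  # indices must stay below each vocabulary size
    return backbone(obs)  # per-cell features, shape (2, 7, 7, 64)
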
class ClassicControlBackbone(nn.Module):
    def __init__(self, d_input):
        super().__init__()
        self.input_projector = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(d_input * 2),
            nn.GLU(),
            nn.LayerNorm(d_input),
            nn.LazyLinear(d_input * 2),
            nn.GLU(),
            nn.LayerNorm(d_input)
        )

    def forward(self, x):
        return self.input_projector(x)


class ShallowWide(nn.Module):
    """
    Simple, wide, shallow convolutional backbone for image feature extraction.

    Alternative to ResNet, uses grouped convolutions and GLU activations.
    Fixed structure, useful for specific experiments.
    """
    def __init__(self):
        super(ShallowWide, self).__init__()
        # LazyConv2d infers input channels
        self.layers = nn.Sequential(
            nn.LazyConv2d(4096, kernel_size=3, stride=2, padding=1), # Output channels = 4096
            nn.GLU(dim=1), # Halves channels to 2048
            nn.BatchNorm2d(2048),
            # Grouped convolution maintains width but processes groups independently
            nn.Conv2d(2048, 4096, kernel_size=3, stride=1, padding=1, groups=32),
            nn.GLU(dim=1), # Halves channels to 2048
            nn.BatchNorm2d(2048)
        )
    def forward(self, x):
        return self.layers(x)


class PretrainedResNetWrapper(nn.Module):
    """
    Wrapper to use standard pre-trained ResNet models from torchvision.

    Loads a specified ResNet architecture pre-trained on ImageNet, removes the
    final classification layer (fc), average pooling, and optionally later layers
    (e.g., layer4), allowing it to be used as a feature extractor backbone.

    Args:
        resnet_type (str): Name of the ResNet model (e.g., 'resnet18', 'resnet50').
        fine_tune (bool): If False, freezes the weights of the pre-trained backbone.
    """
    def __init__(self, resnet_type, fine_tune=True):
        super(PretrainedResNetWrapper, self).__init__()
        self.resnet_type = resnet_type
        self.backbone = torch.hub.load('pytorch/vision:v0.10.0', resnet_type, pretrained=True)

        if not fine_tune:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Remove final layers to use as feature extractor
        self.backbone.avgpool = Identity()
        self.backbone.fc = Identity()
        # Keep layer4 by default, user can modify instance if needed
        # self.backbone.layer4 = Identity()

    def forward(self, x):
        # Get features from the modified ResNet
        out = self.backbone(x)

        # Reshape output to (B, C, H, W) - This is heuristic based on original comment.
        # User might need to adjust this based on which layers are kept/removed.
        # Infer C based on ResNet type (example values)
        nc = 256 if ('18' in self.resnet_type or '34' in self.resnet_type) else 512 if '50' in self.resnet_type else 1024 if '101' in self.resnet_type else 2048 # Approx for layer3/4 output channel numbers
        # Infer H, W assuming the output is a flattened C * H * W map
        num_features = out.shape[-1]
        # This assumes nc is correct and the feature map is square; verify before reshaping
        wh = int(round(math.sqrt(num_features / nc)))
        if nc * wh * wh != num_features:
            print(f"Warning: Cannot reliably reshape PretrainedResNetWrapper output. nc={nc}, num_features={num_features}")
            # Return the flattened features if the reshape would fail
            return out

        return out.reshape(x.size(0), nc, wh, wh)

# --- Positional Encoding Modules ---

class LearnableFourierPositionalEncoding(nn.Module):
    """
    Learnable Fourier Feature Positional Encoding.

    Implements Algorithm 1 from "Learnable Fourier Features for Multi-Dimensional
    Spatial Positional Encoding" (https://arxiv.org/pdf/2106.02795.pdf).
    Provides positional information for 2D feature maps.

    Args:
        d_model (int): The output dimension of the positional encoding (D).
        G (int): Positional groups (default 1).
        M (int): Dimensionality of input coordinates (default 2 for H, W).
        F_dim (int): Dimension of the Fourier features.
        H_dim (int): Hidden dimension of the MLP.
        gamma (float): Initialization scale for the Fourier projection weights (Wr).
    """
    def __init__(self, d_model,
                 G=1, M=2,
                 F_dim=256,
                 H_dim=128,
                 gamma=1/2.5,
                 ):
        super().__init__()
        self.G = G
        self.M = M
        self.F_dim = F_dim
        self.H_dim = H_dim
        self.D = d_model
        self.gamma = gamma

        self.Wr = nn.Linear(self.M, self.F_dim // 2, bias=False)
        self.mlp = nn.Sequential(
            nn.Linear(self.F_dim, self.H_dim, bias=True),
            nn.GLU(), # Halves H_dim
            nn.Linear(self.H_dim // 2, self.D // self.G),
            nn.LayerNorm(self.D // self.G)
        )

        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma ** -2)

    def forward(self, x):
        """
        Computes positional encodings for the input feature map x.

        Args:
            x (torch.Tensor): Input feature map, shape (B, C, H, W).

        Returns:
            torch.Tensor: Positional encoding tensor, shape (B, D, H, W).
        """
        B, C, H, W = x.shape
        # Creates coordinates based on (H, W) and repeats for batch B.
        # Takes x[:,0] assuming channel dim isn't needed for coords.
        x_coord = add_coord_dim(x[:,0]) # Expects (B, H, W) -> (B, H, W, 2)

        # Compute Fourier features
        projected = self.Wr(x_coord) # (B, H, W, F_dim // 2)
        cosines = torch.cos(projected)
        sines = torch.sin(projected)
        F = (1.0 / math.sqrt(self.F_dim)) * torch.cat([cosines, sines], dim=-1) # (B, H, W, F_dim)

        # Project features through MLP
        Y = self.mlp(F) # (B, H, W, D // G)

        # Reshape to (B, D, H, W)
        PEx = Y.permute(0, 3, 1, 2) # Assuming G=1
        return PEx

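# Illustrative usage sketch for LearnableFourierPositionalEncoding (not from the
# original repo); d_model=128 and the (2, 3, 8, 8) feature map are assumed sizes.
# It relies on models.utils.add_coord_dim producing (B, H, W, 2) coordinates.
def _example_fourier_pe_usage():
    pe_module = LearnableFourierPositionalEncoding(d_model=128)
    features = torch.randn(2, 3, 8, 8)  # (B, C, H, W) backbone output
    return pe_module(features)          # positional encoding, shape (2, 128, 8, 8)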

class MultiLearnableFourierPositionalEncoding(nn.Module):
    """
    Combines multiple LearnableFourierPositionalEncoding modules with different
    initialization scales (gamma) via a learnable weighted sum.

    Allows the model to learn an optimal combination of positional frequencies.

    Args:
        d_model (int): Output dimension of the encoding.
        G, M, F_dim, H_dim: Parameters passed to underlying LearnableFourierPositionalEncoding.
        gamma_range (list[float]): Min and max gamma values for the linspace.
        N (int): Number of parallel embedding modules to create.
    """
    def __init__(self, d_model,
                 G=1, M=2,
                 F_dim=256,
                 H_dim=128,
                 gamma_range=[1.0, 0.1], # Default range
                 N=10,
                 ):
        super().__init__()
        self.embedders = nn.ModuleList()
        for gamma in np.linspace(gamma_range[0], gamma_range[1], N):
            self.embedders.append(LearnableFourierPositionalEncoding(d_model, G, M, F_dim, H_dim, gamma))

        # Learnable combination weights over the N embedders (softmax-normalised in forward)
        self.register_parameter('combination', torch.nn.Parameter(torch.ones(N), requires_grad=True))
        self.N = N


    def forward(self, x):
        """
        Computes combined positional encoding.

        Args:
            x (torch.Tensor): Input feature map, shape (B, C, H, W).

        Returns:
            torch.Tensor: Combined positional encoding tensor, shape (B, D, H, W).
        """
        # Compute embeddings from all modules and stack: (N, B, D, H, W)
        pos_embs = torch.stack([emb(x) for emb in self.embedders], dim=0)

        # Softmax over the N embedders, reshaped for broadcasting: (N,) -> (N, 1, 1, 1, 1)
        weights = F.softmax(self.combination, dim=-1).view(self.N, 1, 1, 1, 1)

        # Compute weighted sum over the N dimension
        combined_emb = (pos_embs * weights).sum(0) # (B, D, H, W)
        return combined_emb

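# Illustrative usage sketch for MultiLearnableFourierPositionalEncoding (not from
# the original repo); N=4 embedders and the sizes below are assumptions.
def _example_multi_fourier_pe_usage():
    pe_module = MultiLearnableFourierPositionalEncoding(d_model=128, N=4)
    features = torch.randn(2, 3, 8, 8)
    return pe_module(features)  # learned softmax-weighted sum, shape (2, 128, 8, 8)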

class CustomRotationalEmbedding(nn.Module):
    """
    Custom Rotational Positional Embedding.

    Generates 2D positional embeddings based on rotating a fixed start vector.
    The rotation angle for each grid position is determined primarily by its
    horizontal position (width dimension). The resulting rotated vectors are
    concatenated and projected.

    Note: The current implementation derives angles only from the width dimension (`x.size(-1)`).

    Args:
        d_model (int): Dimensionality of the output embeddings.
    """
    def __init__(self, d_model):
        super(CustomRotationalEmbedding, self).__init__()
        # Learnable 2D start vector
        self.register_parameter('start_vector', nn.Parameter(torch.Tensor([0, 1]), requires_grad=True))
        # Projects the 4D concatenated rotated vectors to d_model
        # Input size 4 comes from concatenating two 2D rotated vectors
        self.projection = nn.Sequential(nn.Linear(4, d_model))

    def forward(self, x):
        """
        Computes rotational positional embeddings based on input width.

        Args:
            x (torch.Tensor): Input tensor (used for shape and device),
                              shape (batch_size, channels, height, width).
        Returns:
            Output tensor containing positional embeddings,
            shape (1, d_model, width, width). The batch dim is 1 since the PE is
            shared across the batch, and the grid is square because the angles
            depend only on the width.
        """
        B, C, H, W = x.shape
        device = x.device

        # --- Generate rotations based only on Width ---
        # Angles derived from width dimension
        theta_rad = torch.deg2rad(torch.linspace(0, 180, W, device=device)) # Angle per column
        cos_theta = torch.cos(theta_rad)
        sin_theta = torch.sin(theta_rad)

        # Create rotation matrices by stacking the two rows along dim 1: shape (W, 2, 2)
        rotation_matrices = torch.stack([
            torch.stack([cos_theta, -sin_theta], dim=-1), # Shape (W, 2)
            torch.stack([sin_theta, cos_theta], dim=-1)  # Shape (W, 2)
        ], dim=1) # Stacks along dim 1 -> Shape (W, 2, 2)

        # Rotate the start vector by column angle: Shape (W, 2)
        rotated_vectors = torch.einsum('wij,j->wi', rotation_matrices, self.start_vector)

        # --- Create Grid Key ---
        # Pair every column vector with every column vector; both repeat factors use W,
        # so this creates a (W, W, 4) key tensor (not (H, W, 4)).
        key = torch.cat((
            torch.repeat_interleave(rotated_vectors.unsqueeze(1), W, dim=1), # (W, 1, 2) -> (W, W, 2)
            torch.repeat_interleave(rotated_vectors.unsqueeze(0), W, dim=0)  # (1, W, 2) -> (W, W, 2)
        ), dim=-1) # Shape (W, W, 4)

        # Project the 4D key vector to d_model: Shape (W, W, d_model)
        pe_grid = self.projection(key)

        # Permute to (d_model, W, W) and add a batch dim -> (1, d_model, W, W)
        pe = pe_grid.permute(2, 0, 1).unsqueeze(0)

        # Because the angles are derived from the width only, the grid is (W, W).
        # If H != W, downstream code must handle the mismatch (e.g. by interpolating
        # with F.interpolate or by cropping/padding); this module returns the
        # (1, d_model, W, W) tensor unchanged.

        return pe

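# Illustrative usage sketch for CustomRotationalEmbedding (not from the original
# repo); d_model=64 and the square (2, 32, 8, 8) feature map are assumed sizes.
def _example_rotational_embedding_usage():
    pe_module = CustomRotationalEmbedding(d_model=64)
    features = torch.randn(2, 32, 8, 8)   # (B, C, H, W) with H == W
    return pe_module(features)            # shape (1, 64, 8, 8), broadcast over the batch downstream
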
class CustomRotationalEmbedding1D(nn.Module):
    """
    1D variant of CustomRotationalEmbedding.

    Rotates a fixed start vector by an angle determined by position along dimension 2
    of the input (x.size(2)) and projects the result to d_model, returning an
    embedding of shape (batch_size, d_model, x.size(2)).
    """
    def __init__(self, d_model):
        super(CustomRotationalEmbedding1D, self).__init__()
        self.projection = nn.Linear(2, d_model)

    def forward(self, x):
        start_vector = torch.tensor([0., 1.], device=x.device, dtype=torch.float)
        theta_rad = torch.deg2rad(torch.linspace(0, 180, x.size(2), device=x.device))
        cos_theta = torch.cos(theta_rad)
        sin_theta = torch.sin(theta_rad)
        cos_theta = cos_theta.unsqueeze(1)  # Shape: (height, 1)
        sin_theta = sin_theta.unsqueeze(1)  # Shape: (height, 1)

        # Create rotation matrices
        rotation_matrices = torch.stack([
            torch.cat([cos_theta, -sin_theta], dim=1),
            torch.cat([sin_theta, cos_theta], dim=1)
        ], dim=1)  # Shape: (height, 2, 2)

        # Rotate the start vector
        rotated_vectors = torch.einsum('bij,j->bi', rotation_matrices, start_vector)

        pe = self.projection(rotated_vectors)
        pe = torch.repeat_interleave(pe.unsqueeze(0), x.size(0), 0)
        return pe.transpose(1, 2) # Transpose for compatibility with other backbones