import torch.nn as nn

# Local imports (Assuming these contain necessary custom modules)
from models.modules import *
from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152


class FFBaseline(nn.Module):
    """
    Feed-forward baseline.

    Wrapper that lets us use the same backbone as the CTM and LSTM baselines, with a
    simple feed-forward projection head on top.

    Args:
        d_model (int): Width of the hidden projection layer; scaling this is a workaround
            that makes parameter-matching against the CTM plausible.
        backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
        out_dims (int): Dimensionality of the final output projection.
        dropout (float): Dropout probability applied before the final output layer.
    """

    def __init__(self,
                 d_model,
                 backbone_type,
                 out_dims,
                 dropout=0,
                 ):
        super(FFBaseline, self).__init__()

        # --- Core Parameters ---
        self.d_model = d_model
        self.backbone_type = backbone_type
        self.out_dims = out_dims

        # --- Input Assertions ---
        assert backbone_type in ['resnet18-1', 'resnet18-2', 'resnet18-3', 'resnet18-4',
                                 'resnet34-1', 'resnet34-2', 'resnet34-3', 'resnet34-4',
                                 'resnet50-1', 'resnet50-2', 'resnet50-3', 'resnet50-4',
                                 'resnet101-1', 'resnet101-2', 'resnet101-3', 'resnet101-4',
                                 'resnet152-1', 'resnet152-2', 'resnet152-3', 'resnet152-4',
                                 'none', 'shallow-wide', 'parity_backbone'], f"Invalid backbone_type: {backbone_type}"

        # --- Backbone / Feature Extraction ---
        # Lazily adapt arbitrary input channel counts to the 3 channels the ResNet stem expects.
        self.initial_rgb = nn.LazyConv2d(3, 1, 1)

        # Select the ResNet depth from backbone_type; non-matching (including non-ResNet)
        # values fall back to the ResNet-18 default.
        resnet_family = resnet18
        if '34' in self.backbone_type: resnet_family = resnet34
        if '50' in self.backbone_type: resnet_family = resnet50
        if '101' in self.backbone_type: resnet_family = resnet101
        if '152' in self.backbone_type: resnet_family = resnet152

        # Determine which ResNet blocks to keep
        block_num_str = self.backbone_type.split('-')[-1]
        hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4]

        self.backbone = resnet_family(
            3, # initial_rgb handles input channels now
            hyper_blocks_to_keep,
            stride=2,
            pretrained=False,
            progress=True,
            device="cpu", # Initialise on CPU, move later via .to(device)
            do_initial_max_pool=True,
        )


        # At this point the backbone yields a 4D feature tensor: [B, C, H, W].
        # The projection head below lets us scale the model via d_model so that it can be
        # parameter-matched against the CTM.
        self.output_projector = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),   # [B, C, H, W] -> [B, C, 1, 1]
            Squeeze(-1),
            Squeeze(-1),                    # [B, C, 1, 1] -> [B, C]
            nn.LazyLinear(d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, out_dims),
        )


    def forward(self, x):
        # Channel-adapt, extract backbone features, then project to the output dimensionality.
        return self.output_projector(self.backbone(self.initial_rgb(x)))
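

# --- Usage sketch (illustrative, not part of the original training pipeline) ---
# A minimal, hedged example of how this baseline might be instantiated and run.
# The input shape (a batch of 3x64x64 images), out_dims=10, and d_model=256 are
# assumed values for illustration, not settings taken from the repository's configs.
# Note that the Lazy* layers only materialise their weights on the first forward pass.
if __name__ == "__main__":
    import torch

    model = FFBaseline(
        d_model=256,                   # hidden width used for parameter matching (assumed)
        backbone_type="resnet18-2",    # keep the first two ResNet-18 blocks
        out_dims=10,                   # e.g. number of classes (assumed)
        dropout=0.1,
    )

    x = torch.randn(4, 3, 64, 64)      # dummy batch: [B, C, H, W]
    logits = model(x)                  # first call initialises LazyConv2d / LazyLinear
    print(logits.shape)                # expected: torch.Size([4, 10])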