LukeDarlow's picture
Welcome to the CTM. This is the first commit of the public repo. Enjoy!
68b32f4
import torch.nn as nn
# Local imports (Assuming these contain necessary custom modules)
from models.modules import *
from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
class FFBaseline(nn.Module):
    """
    Feed-forward (FF) baseline.

    Wrapper that lets us use the same backbone as the CTM and LSTM baselines, with a
    small MLP head on top: the backbone features are average-pooled over their last two
    dimensions, projected to ``d_model``, passed through ReLU and dropout, and finally
    projected to ``out_dims``.

    Args:
        d_model (int): workaround that projects final layer to this space so that parameter-matching is plausible.
        backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
        out_dims (int): Dimensionality of the final output projection.
        dropout (float): dropout in last layer

    Raises:
        AssertionError: if ``backbone_type`` is not one of the accepted names.

    NOTE: 'none', 'shallow-wide' and 'parity_backbone' pass validation, but this class
    has no dedicated path for them: they fall back to a ResNet-18 with all four blocks.
    """

    # Accepted ResNet depths and the number of leading ResNet blocks that may be kept.
    _RESNET_DEPTHS = (18, 34, 50, 101, 152)
    _RESNET_BLOCK_COUNTS = (1, 2, 3, 4)
    # Accepted names that do not follow the 'resnet<depth>-<blocks>' pattern.
    _OTHER_BACKBONE_TYPES = ('none', 'shallow-wide', 'parity_backbone')

    def __init__(self,
                 d_model,
                 backbone_type,
                 out_dims,
                 dropout=0,
                 ):
        super().__init__()

        # --- Core Parameters ---
        self.d_model = d_model
        self.backbone_type = backbone_type
        self.out_dims = out_dims

        # --- Input Assertions ---
        valid_backbone_types = [
            f'resnet{depth}-{blocks}'
            for depth in self._RESNET_DEPTHS
            for blocks in self._RESNET_BLOCK_COUNTS
        ] + list(self._OTHER_BACKBONE_TYPES)
        assert backbone_type in valid_backbone_types, f"Invalid backbone_type: {backbone_type}"

        # --- Backbone / Feature Extraction ---
        # 1x1 convolution to 3 channels; being lazy, it infers the number of input
        # channels on the first forward pass.
        self.initial_rgb = nn.LazyConv2d(3, 1, 1)

        # Pick the ResNet constructor from the part of the name before the first '-'.
        # Names without a ResNet prefix fall back to resnet18.
        resnet_families = {
            'resnet18': resnet18,
            'resnet34': resnet34,
            'resnet50': resnet50,
            'resnet101': resnet101,
            'resnet152': resnet152,
        }
        resnet_family = resnet_families.get(backbone_type.split('-')[0], resnet18)

        # Determine which ResNet blocks to keep: 'resnetXX-k' keeps blocks 1..k;
        # a name without a numeric suffix keeps all four.
        block_num_str = backbone_type.split('-')[-1]
        if block_num_str.isdigit():
            hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1))
        else:
            hyper_blocks_to_keep = [1, 2, 3, 4]

        self.backbone = resnet_family(
            3,  # initial_rgb handles input channels now
            hyper_blocks_to_keep,
            stride=2,
            pretrained=False,
            progress=True,
            device="cpu",  # Initialise on CPU, move later via .to(device)
            do_initial_max_pool=True,
        )

        # At this point we will have a 4D tensor of features: [B, C, H, W]
        # The following lets us scale up the resnet with d_model until it matches the CTM.
        # The module order is kept fixed so state-dict keys (output_projector.<idx>) are stable.
        self.output_projector = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),  # pool each channel down to a single value
            Squeeze(-1),                   # drop the two size-1 spatial dimensions
            Squeeze(-1),
            nn.LazyLinear(d_model),        # input width inferred on first forward pass
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, out_dims),
        )

    def forward(self, x):
        """
        Run the input through the channel adapter, the backbone and the output head.

        Args:
            x: input tensor; presumably image-like with shape [B, C, H, W] (TODO confirm
               against callers). Any channel count is accepted by the lazy 1x1 conv.

        Returns:
            The output of the final linear layer, whose last dimension is ``out_dims``.
        """
        features = self.backbone(self.initial_rgb(x))
        return self.output_projector(features)