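"""
LSTM baseline for the CTM (Continuous Thought Machine) RL experiments.

Runs an nn.LSTMCell for a fixed number of internal 'thought' steps over
backbone features, mirroring the CTM's iteration interface for the
navigation and classic-control tasks.
"""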
import math

import torch
import torch.nn as nn
import numpy as np

# Local imports (custom backbones such as MiniGridBackbone and ClassicControlBackbone)
from models.modules import *
from models.utils import *
class LSTMBaseline(nn.Module):
"""
LSTM Baseline
Args:
iterations (int): Number of internal 'thought' steps (T, in paper).
d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
d_input (int): Dimensionality of projected attention outputs or direct input features.
backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
"""
    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 backbone_type,
                 ):
        super(LSTMBaseline, self).__init__()
        # --- Core Parameters ---
        self.iterations = iterations
        self.d_model = d_model
        self.backbone_type = backbone_type

        # --- Input Assertions ---
        assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), \
            f"Invalid backbone_type: {backbone_type}"
        # --- Backbone / Feature Extraction ---
        if self.backbone_type == 'navigation-backbone':
            grid_size = 7
            self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size)
            # Grid features are flattened to a single vector per example.
            lstm_cell_input_dim = grid_size * grid_size * d_input
        elif self.backbone_type == 'classic-control-backbone':
            self.backbone = ClassicControlBackbone(d_input=d_input)
            lstm_cell_input_dim = d_input
        else:
            raise NotImplementedError('The only backbones supported for RL are navigation '
                                      '(symbolic C x H x W inputs) and classic control '
                                      '(vectors of length D).')
        # --- Core LSTM Modules ---
        self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model)

        # Learned initial hidden and cell states, initialised uniformly in
        # [-sqrt(1/d_model), sqrt(1/d_model)] (matching nn.LSTMCell's default).
        self.register_parameter('start_hidden_state', nn.Parameter(
            torch.zeros(d_model).uniform_(-math.sqrt(1 / d_model), math.sqrt(1 / d_model))))
        self.register_parameter('start_cell_state', nn.Parameter(
            torch.zeros(d_model).uniform_(-math.sqrt(1 / d_model), math.sqrt(1 / d_model))))
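        # Note (added comment): the start states above are learned 1-D vectors;
        # forward() expects hidden_states already expanded to the batch, e.g.
        # via start_hidden_state.unsqueeze(0).expand(B, -1), which is presumably
        # done by the surrounding training loop (see the usage sketch at the
        # bottom of this file).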
    def compute_features(self, x):
        """Applies the feature-extraction backbone to the input."""
        return self.backbone(x)
    def forward(self, x, hidden_states, track=False):
        """
        Forward pass: executes T=iterations recurrent steps over fixed input features.

        Args:
            x: Raw input observation.
            hidden_states: Tuple of (hidden_state, cell_state), each of shape (B, d_model).
            track (bool): If True, also return per-step hidden-state activations.
        """
        # --- Tracking Initialization ---
        activations_tracking = []

        # --- Featurise Input Data ---
        features = self.compute_features(x)
        hidden_state, cell_state = hidden_states
        # --- Recurrent Loop ---
        # The features are static across iterations; only the recurrent state
        # evolves over the T internal 'thought' steps.
        lstm_input = features.reshape(x.size(0), -1)
        for stepi in range(self.iterations):
            hidden_state, cell_state = self.lstm_cell(lstm_input, (hidden_state, cell_state))

            # --- Tracking ---
            if track:
                activations_tracking.append(hidden_state.detach().cpu().numpy())
        hidden_states = (
            hidden_state,
            cell_state
        )

        # --- Return Values ---
        if track:
            # The activation history is returned twice, presumably to match the
            # CTM's tracking interface (which returns two tracked quantities).
            return hidden_state, hidden_states, np.array(activations_tracking), np.array(activations_tracking)
        return hidden_state, hidden_states
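

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal smoke test, assuming ClassicControlBackbone maps a (B, obs_dim)
# observation to (B, d_input) features; obs_dim=4 here is a hypothetical
# CartPole-style observation. The learned start states are expanded to the
# batch dimension by the caller, as the training loop presumably does.
if __name__ == '__main__':
    model = LSTMBaseline(iterations=5, d_model=64, d_input=32,
                         backbone_type='classic-control-backbone')
    batch_size = 8
    x = torch.randn(batch_size, 4)

    # Expand the learned 1-D start states to (batch_size, d_model).
    h0 = model.start_hidden_state.unsqueeze(0).expand(batch_size, -1)
    c0 = model.start_cell_state.unsqueeze(0).expand(batch_size, -1)

    out, (h, c), activations, _ = model(x, (h0, c0), track=True)
    print(out.shape)          # torch.Size([8, 64])
    print(activations.shape)  # (5, 8, 64): one hidden state per internal step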