import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.modules import *
from models.utils import *
|
|
class LSTMBaseline(nn.Module):
    """
    LSTM Baseline.

    Unrolls a single nn.LSTMCell for a fixed number of internal 'thought'
    steps on each observation, mirroring the CTM's iteration loop.

    Args:
        iterations (int): Number of internal 'thought' steps (T, in paper).
        d_model (int): Dimensionality of the LSTM hidden and cell states (D, in paper).
        d_input (int): Dimensionality of the features produced by the backbone.
        backbone_type (str): Feature extraction backbone; one of
            'navigation-backbone' or 'classic-control-backbone'.
    """
|
    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 backbone_type,
                 ):
        super().__init__()
|
        self.iterations = iterations
        self.d_model = d_model
        self.backbone_type = backbone_type

        assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), f"Invalid backbone_type: {backbone_type}"
|
        if self.backbone_type == 'navigation-backbone':
            # Symbolic C x H x W grid observations; per-cell features are
            # flattened into a single vector for the LSTM cell.
            grid_size = 7
            self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size)
            lstm_cell_input_dim = grid_size * grid_size * d_input
        elif self.backbone_type == 'classic-control-backbone':
            # Flat state vectors are projected to d_input features.
            self.backbone = ClassicControlBackbone(d_input=d_input)
            lstm_cell_input_dim = d_input
        else:
            # Unreachable after the assert above; kept as a safeguard.
            raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')
|
        self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model)

        # Learned initial hidden/cell states, sampled uniformly from
        # [-1/sqrt(d_model), 1/sqrt(d_model)] (matching nn.LSTMCell's default
        # initialisation). They are not used inside forward(); callers
        # presumably expand them over the batch dimension to build the initial
        # hidden_states pair (see the sketch at the bottom of this file).
        self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros(d_model).uniform_(-math.sqrt(1 / d_model), math.sqrt(1 / d_model)), requires_grad=True))
        self.register_parameter('start_cell_state', nn.Parameter(torch.zeros(d_model).uniform_(-math.sqrt(1 / d_model), math.sqrt(1 / d_model)), requires_grad=True))
|
    def compute_features(self, x):
        """Applies the feature extraction backbone (which handles any positional embedding) to the input."""
        return self.backbone(x)
|
    def forward(self, x, hidden_states, track=False):
        """
        Forward pass.

        Runs the LSTM cell for T=iterations internal steps on the features
        extracted from x, starting from the given (hidden, cell) state pair.
        """
        activations_tracking = []

        features = self.compute_features(x)

        hidden_state, cell_state = hidden_states

        # The extracted features do not change across steps, so flatten them once.
        lstm_input = features.reshape(x.size(0), -1)

        for stepi in range(self.iterations):

            hidden_state, cell_state = self.lstm_cell(lstm_input, (hidden_state, cell_state))

            if track:
                activations_tracking.append(hidden_state.detach().cpu().numpy())
|
        hidden_states = (hidden_state, cell_state)

        if track:
            # The activation history is returned twice, presumably to keep the
            # return signature aligned with the CTM's tracking interface.
            return hidden_state, hidden_states, np.array(activations_tracking), np.array(activations_tracking)
        return hidden_state, hidden_states
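

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # ClassicControlBackbone accepts a flat observation batch of shape
    # (batch, obs_dim) and returns features of shape (batch, d_input); check
    # models.modules for the actual contract before relying on this.
    batch_size, obs_dim = 4, 8
    model = LSTMBaseline(iterations=3,
                         d_model=64,
                         d_input=32,
                         backbone_type='classic-control-backbone')

    # Expand the learned per-feature start states over the batch dimension to
    # form the initial (hidden, cell) pair expected by forward().
    hidden_states = (
        model.start_hidden_state.unsqueeze(0).expand(batch_size, -1),
        model.start_cell_state.unsqueeze(0).expand(batch_size, -1),
    )

    obs = torch.randn(batch_size, obs_dim)
    hidden_state, hidden_states = model(obs, hidden_states)
    print(hidden_state.shape)  # expected: torch.Size([4, 64])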