import math

import numpy as np
import torch
import torch.nn as nn

# Local imports (expected to provide MiniGridBackbone and ClassicControlBackbone)
from models.modules import *
from models.utils import *


class LSTMBaseline(nn.Module):
    """

    LSTM Baseline

    Args:
        iterations (int): Number of internal 'thought' steps (T, in paper).
        d_model (int): Core dimensionality of the CTM's latent space (D, in paper).
        d_input (int): Dimensionality of projected attention outputs or direct input features.
        backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none').
    """

    def __init__(self,
                 iterations,
                 d_model,
                 d_input,
                 backbone_type,
                 ):
        super().__init__()

        # --- Core Parameters ---
        self.iterations = iterations
        self.d_model = d_model
        self.backbone_type = backbone_type

        # --- Input Assertions ---
        assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), f"Invalid backbone_type: {backbone_type}"

        # --- Backbone / Feature Extraction ---
        if self.backbone_type == 'navigation-backbone':
            grid_size = 7
            self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size)
            lstm_cell_input_dim = grid_size * grid_size * d_input

        elif self.backbone_type == 'classic-control-backbone':
            self.backbone = ClassicControlBackbone(d_input=d_input)
            lstm_cell_input_dim = d_input

        else:
            raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).')

        # --- Core LSTM Modules ---
        self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model)
        # Learnable initial hidden and cell states, initialised uniformly in
        # (-1/sqrt(d_model), 1/sqrt(d_model)) to match the nn.LSTMCell weight initialisation.
        self.register_parameter('start_hidden_state', nn.Parameter(
            torch.zeros(d_model).uniform_(-math.sqrt(1 / d_model), math.sqrt(1 / d_model))))
        self.register_parameter('start_cell_state', nn.Parameter(
            torch.zeros(d_model).uniform_(-math.sqrt(1 / d_model), math.sqrt(1 / d_model))))
    
    def compute_features(self, x):
        """Applies backbone and positional embedding to input."""
        return self.backbone(x)


    def forward(self, x, hidden_states, track=False):
        """
        Forward pass - Reverted to structure closer to user's working version.
        Executes T=iterations steps.
        """

        # --- Tracking Initialization ---
        activations_tracking = []

        # --- Featurise Input Data ---
        features = self.compute_features(x)

        hidden_state = hidden_states[0]
        cell_state = hidden_states[1]

        # --- Recurrent Loop ---
        # The same flattened features are fed to the LSTM cell at every internal step.
        lstm_input = features.reshape(x.size(0), -1)
        for stepi in range(self.iterations):

            hidden_state, cell_state = self.lstm_cell(lstm_input, (hidden_state, cell_state))

            # --- Tracking ---
            if track:
                activations_tracking.append(hidden_state.detach().cpu().numpy())

        hidden_states = (
            hidden_state,
            cell_state
        )

        # --- Return Values ---
        if track:
            tracked_activations = np.array(activations_tracking)
            # Returned twice, presumably to keep the tracking signature aligned with
            # models that report both pre- and post-activations.
            return hidden_state, hidden_states, tracked_activations, tracked_activations
        return hidden_state, hidden_states
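

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). It assumes that
    # ClassicControlBackbone, pulled in via the wildcard import from models.modules,
    # maps a flat observation vector to `d_input` features; the observation size (4)
    # and the other hyperparameters below are arbitrary illustrative choices.
    batch_size, obs_dim = 2, 4
    iterations, d_model, d_input = 3, 64, 32

    model = LSTMBaseline(iterations=iterations,
                         d_model=d_model,
                         d_input=d_input,
                         backbone_type='classic-control-backbone')

    # Build the initial recurrent state from the learnable start parameters,
    # broadcast across the batch dimension.
    hidden_states = (model.start_hidden_state.unsqueeze(0).repeat(batch_size, 1),
                     model.start_cell_state.unsqueeze(0).repeat(batch_size, 1))

    x = torch.randn(batch_size, obs_dim)
    output, hidden_states = model(x, hidden_states)
    print(output.shape)  # expected: torch.Size([batch_size, d_model])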