import torch
import pytest
import itertools
from models.constants import VALID_NEURON_SELECT_TYPES, VALID_BACKBONE_TYPES, VALID_POSITIONAL_EMBEDDING_TYPES
import numpy as np

def rep_size(neuron_select_type: str, n_synch: int) -> int:
    """Expected synchronisation representation size: random-pairing keeps one
    value per sampled pair; the other selection types use the upper triangle
    (including the diagonal) of the n_synch x n_synch matrix, n*(n+1)/2 values."""
    return n_synch if neuron_select_type == "random-pairing" else n_synch * (n_synch + 1) // 2

def grab_synch_tensors(model, s_type: str):
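    """Return the (left indices, right indices, decay params) tensors for s_type 'out' or 'action'."""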
    if s_type == "out":
        return (
            model.out_neuron_indices_left,
            model.out_neuron_indices_right,
            model.decay_params_out,
        )
    if s_type == "action":
        return (
            model.action_neuron_indices_left,
            model.action_neuron_indices_right,
            model.decay_params_action,
        )
    raise ValueError(s_type)

# --- Golden Tests --- 
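# Golden tests run fixed fixture inputs through reference models and compare
# every tracked output against stored expected values, guarding against
# behavioural regressions.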

def test_golden_parity(
    golden_test_model_parity,
    golden_test_input_parity,
    golden_test_expected_predictions_parity,
    golden_test_expected_certainties_parity,
    golden_test_expected_synchronization_out_tracking_parity,
    golden_test_expected_synchronization_action_tracking_parity,
    golden_test_expected_pre_activations_tracking_parity,
    golden_test_expected_post_activations_tracking_parity,
    golden_test_expected_attentions_tracking_parity,
):
    """Golden test the parity CTM model."""

    atol = 1e-5
    atol_attn = 1e-3
    golden_test_model_parity.eval()
    (
        predictions,
        certainties,
        (synch_out_tracking, synch_action_tracking),
        pre_activations_tracking,
        post_activations_tracking,
        attention_tracking,
    ) = golden_test_model_parity(golden_test_input_parity, track=True)

    assert torch.isclose(predictions, golden_test_expected_predictions_parity, atol=atol).all(), "Predictions do not match expected values."
    assert torch.isclose(certainties, golden_test_expected_certainties_parity, atol=atol).all(), "Certainties do not match expected values."
    assert np.isclose(synch_out_tracking, golden_test_expected_synchronization_out_tracking_parity, atol=atol).all(), "Synch out tracking does not match expected values."
    assert np.isclose(synch_action_tracking, golden_test_expected_synchronization_action_tracking_parity, atol=atol).all(), "Synch action tracking does not match expected values."
    assert np.isclose(pre_activations_tracking, golden_test_expected_pre_activations_tracking_parity, atol=atol).all(), "Pre-activations do not match expected values."
    assert np.isclose(post_activations_tracking, golden_test_expected_post_activations_tracking_parity, atol=atol).all(), "Post-activations do not match expected values."
    assert np.isclose(attention_tracking, golden_test_expected_attentions_tracking_parity, atol=atol_attn).all(), "Attention tracking does not match expected values."


def test_golden_qamnist(
    golden_test_model_qamnist,
    golden_test_input_qamnist,
    golden_test_expected_predictions_qamnist,
    golden_test_expected_certainties_qamnist,
    golden_test_expected_synchronization_out_tracking_qamnist,
    golden_test_expected_pre_activations_tracking_qamnist,
    golden_test_expected_post_activations_tracking_qamnist,
    golden_test_expected_attentions_tracking_qamnist,
    golden_test_expected_embeddings_tracking_qamnist,
):
    """Golden test the QAMNIST CTM model."""

    atol = 1e-4
    atol_attn = 5e-3
    golden_test_model_qamnist.eval()
    x, z = golden_test_input_qamnist

    (
        predictions,
        certainties,
        synch_out_tracking,
        pre_activations_tracking,
        post_activations_tracking,
        attention_tracking,
        embedding_tracking,
    ) = golden_test_model_qamnist(x, z=z, track=True)

    assert torch.isclose(predictions, golden_test_expected_predictions_qamnist, atol=atol).all(), "Predictions do not match expected values."
    assert torch.isclose(certainties, golden_test_expected_certainties_qamnist, atol=atol).all(), "Certainties do not match expected values."
    assert torch.isclose(synch_out_tracking, golden_test_expected_synchronization_out_tracking_qamnist[-1], atol=atol).all(), "Synch out tracking does not match expected values."
    assert np.isclose(pre_activations_tracking, golden_test_expected_pre_activations_tracking_qamnist, atol=atol).all(), "Pre-activations do not match expected values."
    assert np.isclose(post_activations_tracking, golden_test_expected_post_activations_tracking_qamnist, atol=atol).all(), "Post-activations do not match expected values."
    assert np.isclose(attention_tracking, golden_test_expected_attentions_tracking_qamnist, atol=atol_attn).all(), "Attention tracking does not match expected values."
    assert np.isclose(embedding_tracking, golden_test_expected_embeddings_tracking_qamnist, atol=atol).all(), "Embeddings do not match expected values."


def test_golden_rl(
    golden_test_model_rl,
    golden_test_inputs_rl,
    golden_test_expected_initial_state_trace_rl,
    golden_test_expected_initial_activated_state_trace_rl,
    golden_test_expected_action_rl,
    golden_test_expected_action_log_prob_rl,
    golden_test_expected_action_entropy_rl,
    golden_test_expected_value_rl,
    golden_test_expected_state_trace_rl,
    golden_test_expected_activated_state_trace_rl,
    golden_test_expected_action_logits_rl,
    golden_test_expected_action_probs_rl,
    golden_test_expected_pre_activations_tracking_rl,
    golden_test_expected_post_activations_tracking_rl,
    golden_test_expected_synch_out_tracking_rl,
):
    """Golden test the RL CTM model."""

    atol = 1e-5
    golden_test_model_rl.eval()

    initial_state_trace, initial_activated_state_trace = golden_test_model_rl.get_initial_state(num_envs=1)

    dones = torch.zeros(1).to(initial_state_trace.device)

    assert torch.isclose(initial_state_trace, golden_test_expected_initial_state_trace_rl, atol=atol).all(), "Initial state trace does not match expected values."
    assert torch.isclose(initial_activated_state_trace, golden_test_expected_initial_activated_state_trace_rl, atol=atol).all(), "Initial activated state trace does not match expected values."

    (
        _,
        action_log_probs,
        entropy,
        value,
        (state_trace, activated_state_trace),
        tracking_data,
        action_logits,
        action_probs,
    ) = golden_test_model_rl.get_action_and_value(
        golden_test_inputs_rl, (initial_state_trace, initial_activated_state_trace), dones, track=True
    )

    pre_activations = tracking_data["pre_activations"]
    post_activations = tracking_data["post_activations"]
    synchronization = tracking_data["synchronisation"]

    assert torch.isclose(action_log_probs, golden_test_expected_action_log_prob_rl, atol=atol).all(), "Action log probs do not match expected values."
    assert torch.isclose(entropy, golden_test_expected_action_entropy_rl, atol=atol).all(), "Entropy does not match expected values."
    assert torch.isclose(value, golden_test_expected_value_rl, atol=atol).all(), "Value does not match expected values."
    assert torch.isclose(state_trace, golden_test_expected_state_trace_rl, atol=atol).all(), "State trace does not match expected values."
    assert torch.isclose(activated_state_trace, golden_test_expected_activated_state_trace_rl, atol=atol).all(), "Activated state trace does not match expected values."
    assert np.isclose(pre_activations, golden_test_expected_pre_activations_tracking_rl, atol=atol).all(), "Pre-activations do not match expected values."
    assert np.isclose(post_activations, golden_test_expected_post_activations_tracking_rl, atol=atol).all(), "Post-activations do not match expected values."
    assert np.isclose(synchronization, golden_test_expected_synch_out_tracking_rl, atol=atol).all(), "Synchronisation does not match expected values."
    assert torch.isclose(action_logits, golden_test_expected_action_logits_rl, atol=atol).all(), "Action logits do not match expected values."
    assert torch.isclose(action_probs, golden_test_expected_action_probs_rl, atol=atol).all(), "Action probs do not match expected values."


# --- General CTM Tests ---

@pytest.mark.parametrize("synch_type", ["out", "action"])
@pytest.mark.parametrize("neuron_select_type", ["first-last", "random", "random-pairing"])
def test_set_synchronisation_parameters(ctm_factory, base_params, device, synch_type, neuron_select_type):
    np.random.seed(0)
    n_synch = 8
    num_random_pairing_self = 2

    model = ctm_factory(
        base_params,
        d_model=64,
        n_synch_out=n_synch,
        n_synch_action=n_synch,
        neuron_select_type=neuron_select_type,
        n_random_pairing_self=num_random_pairing_self,
    ).to(device)

    left, right, decay = grab_synch_tensors(model, synch_type)

    # Check shapes
    assert left.dtype == right.dtype == torch.long
    assert left.shape == right.shape == (n_synch,)
    assert decay.shape == (rep_size(neuron_select_type, n_synch),)

    # Check equal number of neurons on left and right
    assert left.size(0) == right.size(0) == n_synch
    # Check that the left and right indices are within the model's d_model
    assert torch.all(left < model.d_model) and torch.all(right < model.d_model)

    # Test neuron pairing selection
    if neuron_select_type == "first-last":
        if synch_type == "out":
            exp = torch.arange(0, n_synch, device=device)
        else:
            exp = torch.arange(model.d_model - n_synch, model.d_model, device=device)
        assert torch.equal(left, exp) and torch.equal(right, exp)
    elif neuron_select_type == "random":
        pass
    elif neuron_select_type == "random-pairing":
        assert torch.equal(right[:num_random_pairing_self], left[:num_random_pairing_self])

# --- Neuron Select Type Tests ---

@pytest.mark.parametrize("neuron_select_type", VALID_NEURON_SELECT_TYPES)
def test_valid_neuron_select_type(ctm_factory, base_params, neuron_select_type):
    model = ctm_factory(base_params, neuron_select_type=neuron_select_type)
    assert model is not None

def test_none_neuron_select_type(ctm_factory, base_params):
    with pytest.raises(Exception):
        ctm_factory(base_params, neuron_select_type="none")

def test_invalid_neuron_select_type(ctm_factory, base_params):
    with pytest.raises(Exception):
        ctm_factory(base_params, neuron_select_type="invalid-option")

# --- Backbone and Positional Embedding Type Tests ---

@pytest.mark.parametrize("backbone_type, positional_embedding_type", list(itertools.product(VALID_BACKBONE_TYPES, VALID_POSITIONAL_EMBEDDING_TYPES)))
def test_valid_backbone_and_valid_positional_embedding(ctm_factory, base_params, backbone_type, positional_embedding_type):
    model = ctm_factory(
        base_params,
        backbone_type=backbone_type,
        positional_embedding_type=positional_embedding_type,
    )
    assert model is not None

def test_none_backbone_with_none_positional_embeddings(ctm_factory, base_params):
    model = ctm_factory(
        base_params,
        backbone_type="none",
        positional_embedding_type="none",
    )
    assert model is not None

@pytest.mark.parametrize("positional_embedding_type", VALID_POSITIONAL_EMBEDDING_TYPES)
def test_none_backbone_with_valid_positional_embeddings(ctm_factory, base_params, positional_embedding_type):
    with pytest.raises(Exception):
        ctm_factory(
            base_params,
            backbone_type="none",
            positional_embedding_type=positional_embedding_type,
        )

@pytest.mark.parametrize("backbone_type", VALID_BACKBONE_TYPES)
def test_valid_backbone_with_none_positional_embeddings(ctm_factory, base_params, backbone_type):
    model = ctm_factory(
        base_params,
        backbone_type=backbone_type,
        positional_embedding_type="none",
    )
    assert model is not None

# --- Parity Tests ---

def test_parity_prediction_shape(parity_ctm_model, parity_params, parity_input):
    predictions, _, _ = parity_ctm_model(parity_input)

    batch_size, parity_length = parity_input.shape
    expected_shape = (batch_size, parity_length * 2, parity_params["iterations"])
    assert predictions.shape == expected_shape

def test_parity_certainty_shape(parity_ctm_model, parity_params, parity_input):
    _, certainties, _ = parity_ctm_model(parity_input)

    batch_size = parity_input.shape[0]
    expected_shape = (batch_size, 2, parity_params["iterations"])
    assert certainties.shape == expected_shape

def test_parity_nans_in_predictions(parity_ctm_model, parity_input):
    predictions, _, _ = parity_ctm_model(parity_input)
    assert not torch.isnan(predictions).any()
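
# Added sketch: a companion NaN check for the certainty head, assuming the same
# (predictions, certainties, synchronisation) forward contract exercised by the
# shape tests above.
def test_parity_nans_in_certainties(parity_ctm_model, parity_input):
    _, certainties, _ = parity_ctm_model(parity_input)
    assert not torch.isnan(certainties).any()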

# --- QAMNIST Tests ---

def test_qamnist_prediction_shape(qamnist_model_factory, qamnist_params, qamnist_input, device):
    model = qamnist_model_factory("first-last").to(device)
    inputs, z = qamnist_input

    predictions, _, _ = model(inputs, z)
    B = inputs.shape[0]
    out_dims = qamnist_params["out_dims"]
    T = inputs.shape[1] + z.shape[1] + qamnist_params["iterations_for_answering"]
    expected_shape = (B, out_dims, T)
    assert predictions.shape == expected_shape, f"Expected {expected_shape}, got {predictions.shape}"

def test_qamnist_certainty_shape(qamnist_model_factory, qamnist_params, qamnist_input, device):
    model = qamnist_model_factory("first-last").to(device)
    inputs, z = qamnist_input

    _, certainties, _ = model(inputs, z)
    B = inputs.shape[0]
    T = inputs.shape[1] + z.shape[1] + qamnist_params["iterations_for_answering"]
    expected_shape = (B, 2, T)
    assert certainties.shape == expected_shape, f"Expected {expected_shape}, got {certainties.shape}"

def test_qamnist_nans_in_predictions(qamnist_model_factory, qamnist_input, device):
    model = qamnist_model_factory("first-last").to(device)
    inputs, z = qamnist_input

    predictions, _, _ = model(inputs, z)
    assert not torch.isnan(predictions).any(), "Predictions contain NaNs"

@pytest.mark.parametrize("neuron_select_type", ["first-last", "random", "random-pairing"])
def test_qamnist_synchronisation_shape(qamnist_model_factory, qamnist_params, qamnist_input, neuron_select_type, device):
    model = qamnist_model_factory(neuron_select_type).to(device)
    inputs, z = qamnist_input

    _, _, synchronisation = model(inputs, z)

    batch_size = inputs.shape[0]
    n_synch_out = qamnist_params["n_synch_out"]

    if neuron_select_type in ("first-last", "random"):
        expected_size = (n_synch_out * (n_synch_out + 1)) // 2
    elif neuron_select_type == "random-pairing":
        expected_size = n_synch_out

    assert synchronisation.shape == (batch_size, expected_size), \
        f"Expected {(batch_size, expected_size)}, got {synchronisation.shape}"