|
|
import torch |
|
|
import pytest |
|
|
import itertools |
|
|
from models.constants import VALID_NEURON_SELECT_TYPES, VALID_BACKBONE_TYPES, VALID_POSITIONAL_EMBEDDING_TYPES |
|
|
import numpy as np |
|
|
|
|
|
def rep_size(neuron_select_type: str, n_synch: int) -> int:
    """Return the expected synchronisation representation size for ``n_synch`` neurons.

    For ``"random-pairing"`` the size is ``n_synch`` (one value per selected
    pair). For every other selection type it is ``n_synch * (n_synch + 1) // 2``,
    i.e. the number of entries in an ``n_synch x n_synch`` upper triangle
    including the diagonal.
    """
    if neuron_select_type == "random-pairing":
        return n_synch
    return n_synch * (n_synch + 1) // 2
|
|
|
|
|
def grab_synch_tensors(model, s_type: str):
    """Fetch the synchronisation tensors of ``model`` for one synchronisation type.

    Args:
        model: object exposing the ``*_neuron_indices_left``,
            ``*_neuron_indices_right`` and ``decay_params_*`` attributes.
        s_type: either ``"out"`` or ``"action"``.

    Returns:
        Tuple ``(left_indices, right_indices, decay_params)`` for ``s_type``.

    Raises:
        ValueError: if ``s_type`` is neither ``"out"`` nor ``"action"``.
    """
    if s_type == "out":
        attr_names = ("out_neuron_indices_left", "out_neuron_indices_right", "decay_params_out")
    elif s_type == "action":
        attr_names = ("action_neuron_indices_left", "action_neuron_indices_right", "decay_params_action")
    else:
        raise ValueError(s_type)
    return tuple(getattr(model, attr) for attr in attr_names)
|
|
|
|
|
|
|
|
|
|
|
def test_golden_parity(golden_test_model_parity, golden_test_input_parity, golden_test_expected_predictions_parity, golden_test_expected_certainties_parity, golden_test_expected_synchronization_out_tracking_parity, golden_test_expected_synchronization_action_tracking_parity, golden_test_expected_pre_activations_tracking_parity, golden_test_expected_post_activations_tracking_parity, golden_test_expected_attentions_tracking_parity):
    """Golden test the parity CTM model.

    Runs the fixture model in eval mode with ``track=True`` and compares every
    returned value against the recorded golden values. Attention is compared
    with a looser absolute tolerance (``atol_attn``) than the other outputs
    (``atol``).
    """
    atol = 1e-5
    atol_attn = 1e-3

    golden_test_model_parity.eval()

    (
        predictions,
        certainties,
        (synch_out_tracking, synch_action_tracking),
        pre_activations_tracking,
        post_activations_tracking,
        attention_tracking,
    ) = golden_test_model_parity(golden_test_input_parity, track=True)

    # Final outputs are compared as torch tensors.
    assert torch.isclose(predictions, golden_test_expected_predictions_parity, atol=atol).all(), "Predictions do not match expected values."
    assert torch.isclose(certainties, golden_test_expected_certainties_parity, atol=atol).all(), "Certainties do not match expected values."

    # Tracking outputs are compared with NumPy. This assumes the model returns
    # them as arrays (or array-convertible) when track=True. TODO confirm.
    assert np.isclose(synch_out_tracking, golden_test_expected_synchronization_out_tracking_parity, atol=atol).all(), "Synch Out do not match expected values."
    assert np.isclose(synch_action_tracking, golden_test_expected_synchronization_action_tracking_parity, atol=atol).all(), "Synch Action do not match expected values."
    assert np.isclose(pre_activations_tracking, golden_test_expected_pre_activations_tracking_parity, atol=atol).all(), "Pre-activations do not match expected values."
    assert np.isclose(post_activations_tracking, golden_test_expected_post_activations_tracking_parity, atol=atol).all(), "Post-activations do not match expected values."
    assert np.isclose(attention_tracking, golden_test_expected_attentions_tracking_parity, atol=atol_attn).all(), "Attention do not match expected values."
|
|
|
|
|
def test_golden_qamnist(golden_test_model_qamnist, golden_test_input_qamnist, golden_test_expected_predictions_qamnist, golden_test_expected_certainties_qamnist, golden_test_expected_synchronization_out_tracking_qamnist, golden_test_expected_pre_activations_tracking_qamnist, golden_test_expected_post_activations_tracking_qamnist, golden_test_expected_attentions_tracking_qamnist, golden_test_expected_embeddings_tracking_qamnist):
    """Golden test the QAMNIST CTM model.

    The golden input is a pair ``(x, z)``. The model is run in eval mode with
    ``track=True`` and each returned value is compared against recorded
    golden values. Attention uses the looser tolerance ``atol_attn``.
    """
    atol = 1e-4
    atol_attn = 5e-3

    golden_test_model_qamnist.eval()

    x, z = golden_test_input_qamnist

    (
        predictions,
        certainties,
        synch_out_tracking,
        pre_activations_tracking,
        post_activations_tracking,
        attention_tracking,
        embedding_tracking,
    ) = golden_test_model_qamnist(x, z=z, track=True)

    assert torch.isclose(predictions, golden_test_expected_predictions_qamnist, atol=atol).all(), "Predictions do not match expected values."
    assert torch.isclose(certainties, golden_test_expected_certainties_qamnist, atol=atol).all(), "Certainties do not match expected values."
    # The synchronisation output is compared against only the LAST entry of
    # the golden synchronisation tracking data.
    assert torch.isclose(synch_out_tracking, golden_test_expected_synchronization_out_tracking_qamnist[-1], atol=atol).all(), "Synch Out do not match expected values."

    # Tracking outputs are compared with NumPy. This assumes they are arrays
    # (or array-convertible) when track=True. TODO confirm.
    assert np.isclose(pre_activations_tracking, golden_test_expected_pre_activations_tracking_qamnist, atol=atol).all(), "Pre-activations do not match expected values."
    assert np.isclose(post_activations_tracking, golden_test_expected_post_activations_tracking_qamnist, atol=atol).all(), "Post-activations do not match expected values."
    assert np.isclose(attention_tracking, golden_test_expected_attentions_tracking_qamnist, atol=atol_attn).all(), "Attention do not match expected values."
    assert np.isclose(embedding_tracking, golden_test_expected_embeddings_tracking_qamnist, atol=atol).all(), "Embeddings do not match expected values."
|
|
|
|
|
def test_golden_rl(golden_test_model_rl, golden_test_inputs_rl, golden_test_expected_initial_state_trace_rl, golden_test_expected_initial_activated_state_trace_rl, golden_test_expected_action_rl, golden_test_expected_action_log_prob_rl, golden_test_expected_action_entropy_rl, golden_test_expected_value_rl, golden_test_expected_state_trace_rl, golden_test_expected_activated_state_trace_rl, golden_test_expected_action_logits_rl, golden_test_expected_action_probs_rl, golden_test_expected_pre_activations_tracking_rl, golden_test_expected_post_activations_tracking_rl, golden_test_expected_synch_out_tracking_rl):
    """Golden test the RL CTM model.

    1. Build the initial state for one environment and compare both traces
       with their golden values.
    2. Run ``get_action_and_value`` with tracking enabled and compare the
       action, log-prob, entropy, value, updated traces, tracking data,
       logits and probabilities with their golden values.
    """
    atol = 1e-5

    golden_test_model_rl.eval()

    initial_state_trace, initial_activated_state_trace = golden_test_model_rl.get_initial_state(num_envs=1)

    # One `dones` flag for the single environment requested above. All zeros,
    # presumably meaning "no episode has ended". TODO confirm against the model.
    dones = torch.zeros(1, device=initial_state_trace.device)

    assert torch.isclose(initial_state_trace, golden_test_expected_initial_state_trace_rl, atol=atol).all(), "Initial hidden states of the CTM does not match expected values."
    assert torch.isclose(initial_activated_state_trace, golden_test_expected_initial_activated_state_trace_rl, atol=atol).all(), "Initial activated hidden states of the CTM do not match expected values."

    (
        action,
        action_log_probs,
        entropy,
        value,
        (state_trace, activated_state_trace),
        tracking_data,
        action_logits,
        action_probs,
    ) = golden_test_model_rl.get_action_and_value(
        golden_test_inputs_rl,
        (initial_state_trace, initial_activated_state_trace),
        dones,
        track=True,
    )

    pre_activations = tracking_data["pre_activations"]
    post_activations = tracking_data["post_activations"]
    synchronization = tracking_data["synchronisation"]

    # `golden_test_expected_action_rl` is requested as a fixture, so the
    # returned action must be checked too. Exact equality assumes integer
    # action indices. TODO confirm the fixture's dtype/shape.
    assert (action == golden_test_expected_action_rl).all(), "Action does not match expected values."
    assert torch.isclose(action_log_probs, golden_test_expected_action_log_prob_rl, atol=atol).all(), "Action log probs do not match expected values."
    assert torch.isclose(entropy, golden_test_expected_action_entropy_rl, atol=atol).all(), "Entropy does not match expected values."
    assert torch.isclose(value, golden_test_expected_value_rl, atol=atol).all(), "Value does not match expected values."
    assert torch.isclose(state_trace, golden_test_expected_state_trace_rl, atol=atol).all(), "State trace does not match expected values."
    assert torch.isclose(activated_state_trace, golden_test_expected_activated_state_trace_rl, atol=atol).all(), "Activated state trace does not match expected values."

    # Tracking values are compared with NumPy. This assumes array-like entries
    # in `tracking_data`. TODO confirm.
    assert np.isclose(pre_activations, golden_test_expected_pre_activations_tracking_rl, atol=atol).all(), "Pre-activations do not match expected values."
    assert np.isclose(post_activations, golden_test_expected_post_activations_tracking_rl, atol=atol).all(), "Post-activations do not match expected values."
    assert np.isclose(synchronization, golden_test_expected_synch_out_tracking_rl, atol=atol).all(), "Synchronisation do not match expected values."

    assert torch.isclose(action_logits, golden_test_expected_action_logits_rl, atol=atol).all(), "Action logits do not match expected values."
    assert torch.isclose(action_probs, golden_test_expected_action_probs_rl, atol=atol).all(), "Action probs do not match expected values."
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("synch_type", ["out", "action"])
@pytest.mark.parametrize("neuron_select_type", ["first-last", "random", "random-pairing"])
def test_set_synchronisation_parameters(ctm_factory, base_params, device, synch_type, neuron_select_type):
    """Check the neuron indices and decay parameters built for synchronisation.

    For each synchronisation type and neuron selection strategy, verify:

    - the dtype, shape and value range of the left/right index tensors;
    - the length of the decay parameter vector (via ``rep_size``);
    - the index layout specific to "first-last" and "random-pairing".
    """
    # Seed NumPy's global RNG. Presumably the random selection strategies draw
    # from NumPy inside the model. TODO confirm.
    np.random.seed(0)
    n_synch = 8
    num_random_pairing_self = 2

    model = ctm_factory(
        base_params,
        d_model=64,
        n_synch_out=n_synch,
        n_synch_action=n_synch,
        neuron_select_type=neuron_select_type,
        n_random_pairing_self=num_random_pairing_self,
    ).to(device)

    left, right, decay = grab_synch_tensors(model, synch_type)

    # dtype and shape checks (the shape check also pins the length to n_synch).
    assert left.dtype == right.dtype == torch.long
    assert left.shape == right.shape == (n_synch,)
    assert decay.shape == (rep_size(neuron_select_type, n_synch),)

    # Every index must address an existing neuron, i.e. lie in [0, d_model).
    for indices in (left, right):
        assert torch.all(indices >= 0) and torch.all(indices < model.d_model)

    if neuron_select_type == "first-last":
        # "out" uses the first n_synch neurons; "action" uses the last n_synch.
        if synch_type == "out":
            exp = torch.arange(0, n_synch, device=device)
        else:
            exp = torch.arange(model.d_model - n_synch, model.d_model, device=device)
        assert torch.equal(left, exp) and torch.equal(right, exp)
    elif neuron_select_type == "random-pairing":
        # The first `num_random_pairing_self` pairs use the same neuron on both sides.
        assert torch.equal(right[:num_random_pairing_self], left[:num_random_pairing_self])
    # "random" has no layout guarantee beyond the dtype/shape/range checks above.
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("neuron_select_type", VALID_NEURON_SELECT_TYPES)
def test_valid_neuron_select_type(ctm_factory, base_params, neuron_select_type):
    """Each entry of VALID_NEURON_SELECT_TYPES builds a model without error."""
    ctm = ctm_factory(base_params, neuron_select_type=neuron_select_type)
    assert ctm is not None
|
|
|
|
|
def test_none_neuron_select_type(ctm_factory, base_params):
    """The string "none" is rejected as a neuron selection type."""
    rejected_choice = "none"
    with pytest.raises(Exception):
        ctm_factory(base_params, neuron_select_type=rejected_choice)
|
|
|
|
|
def test_invalid_neuron_select_type(ctm_factory, base_params):
    """An unknown neuron selection type string is rejected."""
    rejected_choice = "invalid-option"
    with pytest.raises(Exception):
        ctm_factory(base_params, neuron_select_type=rejected_choice)
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize(
    "backbone_type, positional_embedding_type",
    list(itertools.product(VALID_BACKBONE_TYPES, VALID_POSITIONAL_EMBEDDING_TYPES)),
)
def test_valid_backbone_and_valid_positional_embedding(ctm_factory, base_params, backbone_type, positional_embedding_type):
    """Every (valid backbone, valid positional embedding) combination builds a model."""
    overrides = {
        "backbone_type": backbone_type,
        "positional_embedding_type": positional_embedding_type,
    }
    ctm = ctm_factory(base_params, **overrides)
    assert ctm is not None
|
|
|
|
|
def test_none_backbone_with_none_positional_embeddings(ctm_factory, base_params):
    """Setting both the backbone and the positional embedding to "none" is accepted."""
    overrides = {"backbone_type": "none", "positional_embedding_type": "none"}
    ctm = ctm_factory(base_params, **overrides)
    assert ctm is not None
|
|
|
|
|
@pytest.mark.parametrize("positional_embedding_type", VALID_POSITIONAL_EMBEDDING_TYPES)
def test_none_backbone_with_valid_positional_embeddings(ctm_factory, base_params, positional_embedding_type):
    """Requesting a positional embedding while the backbone is "none" raises."""
    overrides = {
        "backbone_type": "none",
        "positional_embedding_type": positional_embedding_type,
    }
    with pytest.raises(Exception):
        ctm_factory(base_params, **overrides)
|
|
|
|
|
@pytest.mark.parametrize("backbone_type", VALID_BACKBONE_TYPES)
def test_valid_backbone_with_none_positional_embeddings(ctm_factory, base_params, backbone_type):
    """Each valid backbone builds a model when the positional embedding is "none"."""
    overrides = {
        "backbone_type": backbone_type,
        "positional_embedding_type": "none",
    }
    ctm = ctm_factory(base_params, **overrides)
    assert ctm is not None
|
|
|
|
|
|
|
|
|
|
|
def test_parity_prediction_shape(parity_ctm_model, parity_params, parity_input):
    """Parity predictions have shape (batch, 2 * sequence length, iterations)."""
    preds, _certs, _sync = parity_ctm_model(parity_input)

    n_batch, seq_len = parity_input.shape
    want = (n_batch, seq_len * 2, parity_params["iterations"])
    assert preds.shape == want
|
|
|
|
|
def test_parity_certainty_shape(parity_ctm_model, parity_params, parity_input):
    """Parity certainties have shape (batch, 2, iterations)."""
    _preds, certs, _sync = parity_ctm_model(parity_input)

    n_batch = parity_input.shape[0]
    want = (n_batch, 2, parity_params["iterations"])
    assert certs.shape == want
|
|
|
|
|
def test_parity_nans_in_predictions(parity_ctm_model, parity_input):
    """Parity predictions contain no NaN values."""
    preds, _certs, _sync = parity_ctm_model(parity_input)
    found_nan = torch.isnan(preds).any()
    assert not found_nan
|
|
|
|
|
|
|
|
|
|
|
def test_qamnist_prediction_shape(qamnist_model_factory, qamnist_params, qamnist_input, device):
    """QAMNIST predictions have shape (batch, out_dims, T).

    T is inputs.shape[1] + z.shape[1] + qamnist_params["iterations_for_answering"].
    """
    ctm = qamnist_model_factory("first-last").to(device)
    xs, z = qamnist_input

    preds, _certs, _sync = ctm(xs, z)
    n_ticks = xs.shape[1] + z.shape[1] + qamnist_params["iterations_for_answering"]
    expected_shape = (xs.shape[0], qamnist_params["out_dims"], n_ticks)
    assert preds.shape == expected_shape, f"Expected {expected_shape}, got {preds.shape}"
|
|
|
|
|
def test_qamnist_certainty_shape(qamnist_model_factory, qamnist_params, qamnist_input, device):
    """QAMNIST certainties have shape (batch, 2, T).

    T is inputs.shape[1] + z.shape[1] + qamnist_params["iterations_for_answering"].
    """
    ctm = qamnist_model_factory("first-last").to(device)
    xs, z = qamnist_input

    _preds, certs, _sync = ctm(xs, z)
    n_ticks = xs.shape[1] + z.shape[1] + qamnist_params["iterations_for_answering"]
    expected_shape = (xs.shape[0], 2, n_ticks)
    assert certs.shape == expected_shape, f"Expected {expected_shape}, got {certs.shape}"
|
|
|
|
|
def test_qamnist_nans_in_predictions(qamnist_model_factory, qamnist_input, device):
    """QAMNIST predictions from a "first-last" model contain no NaN values."""
    ctm = qamnist_model_factory("first-last").to(device)
    xs, z = qamnist_input

    preds, _certs, _sync = ctm(xs, z)
    found_nan = torch.isnan(preds).any()
    assert not found_nan, "Predictions contain NaNs"
|
|
|
|
|
@pytest.mark.parametrize("neuron_select_type", ["first-last", "random", "random-pairing"])
def test_qamnist_synchronisation_shape(qamnist_model_factory, qamnist_params, qamnist_input, neuron_select_type, device):
    """The QAMNIST synchronisation output has shape (batch, rep_size).

    The expected width comes from the module's shared ``rep_size`` helper, so
    every neuron selection type has a defined expected size and this test
    agrees with the decay-parameter size check elsewhere in this module.
    """
    model = qamnist_model_factory(neuron_select_type).to(device)
    inputs, z = qamnist_input

    _, _, synchronisation = model(inputs, z)

    batch_size = inputs.shape[0]
    expected_size = rep_size(neuron_select_type, qamnist_params["n_synch_out"])

    assert synchronisation.shape == (batch_size, expected_size), \
        f"Expected {(batch_size, expected_size)}, got {synchronisation.shape}"
|
|
|