| import pytest |
| import torch |
| from torch.utils.data import DataLoader |
| from unittest.mock import patch |
| from mentioned.data import ( |
| mentions_by_sentence, |
| flatten_to_sentences, |
| LitBankMentionDataset, |
| collate_fn, |
| make_litbank, |
| extract_spans_from_bio, |
| flatten_entities, |
| LitBankEntityDataset, |
| entity_collate_fn, |
| make_litbank_entity, |
| ) |
|
|
| |
|
|
|
|
@pytest.fixture
def mock_raw_example():
    """A minimal raw LitBank-style document, as seen before flattening.

    The single coref chain holds ``[sentence_idx, start, end]`` triples —
    presumably token-inclusive; verify against mentions_by_sentence.
    """
    first_sentence = ["The", "cat", "sat", "."]
    second_sentence = ["It", "was", "happy", "."]
    chain = [[0, 0, 1], [1, 0, 0]]
    return {
        "sentences": [first_sentence, second_sentence],
        "coref_chains": [chain],
    }
|
|
|
|
@pytest.fixture
def mock_flattened_data():
    """Sentence-level records mimicking the HF map-function output.

    Covers a multi-token mention, a single-token mention, and a sentence
    with no mentions at all (the empty-case guard).
    """
    records = [
        (["The", "cat", "sat", "."], [[0, 1]]),
        (["It", "was", "happy", "."], [[0, 0]]),
        (["No", "mentions"], []),
    ]
    return [{"sentence": s, "mentions": m} for s, m in records]
|
|
|
|
| |
|
|
def test_extract_spans_from_bio_simple():
    """A multi-token PER followed by a single-token ORG is split correctly."""
    tokens = ["John", "Smith", "works", "at", "Google"]
    tags = ["B-PER", "I-PER", "O", "O", "B-ORG"]
    sentence = [
        {"token": tok, "bio_tags": [tag]} for tok, tag in zip(tokens, tags)
    ]

    spans, labels = extract_spans_from_bio(sentence)

    # Spans are (start, end) inclusive token indices.
    assert spans == [(0, 1), (4, 4)]
    assert labels == ["PER", "ORG"]
|
|
|
|
def test_extract_spans_handles_single_token_entity():
    """A lone B- tag with no I- continuation yields a length-one span."""
    sentence = [
        {"token": tok, "bio_tags": [tag]}
        for tok, tag in [("Paris", "B-LOC"), ("is", "O")]
    ]

    spans, labels = extract_spans_from_bio(sentence)

    assert spans == [(0, 0)]
    assert labels == ["LOC"]
|
|
|
|
def test_litbank_entity_dataset_getitem():
    """__getitem__ builds start indicators and passes spans/labels through."""
    record = {
        "sentence": ["John", "works"],
        "entity_spans": [(0, 1)],
        "entity_labels": ["PER"],
    }

    item = LitBankEntityDataset([record])[0]

    assert item["sentence"] == ["John", "works"]
    # Token 0 opens a span; token 1 does not start one.
    assert torch.equal(item["starts"], torch.tensor([1, 0]))
    assert item["entity_spans"] == [(0, 1)]
    assert item["entity_labels"] == ["PER"]
    # task_id 1 presumably marks the entity task (vs. the mention task).
    assert item["task_id"] == 1
|
|
|
|
|
|
def test_flatten_entities():
    """flatten_entities turns doc-nested BIO tokens into per-sentence columns."""
    john = {"token": "John", "bio_tags": ["B-PER"]}
    smith = {"token": "Smith", "bio_tags": ["I-PER"]}
    # One batch entry -> one document -> one sentence of two tokens.
    batch = {"entities": [[[john, smith]]]}

    output = flatten_entities(batch)

    assert output["sentence"] == [["John", "Smith"]]
    assert output["entity_spans"] == [[(0, 1)]]
    assert output["entity_labels"] == [["PER"]]
|
|
|
|
def test_entity_collate_fn_basic():
    """A single-example batch collates into batched tensors plus gold dicts."""
    example = {
        "sentence": ["John", "works"],
        "starts": torch.tensor([1, 0]),
        "entity_spans": [(0, 1)],
        "entity_labels": ["PER"],
        "task_id": 1,
    }

    output = entity_collate_fn([example])

    assert output["starts"].shape == (1, 2)
    assert output["spans"].shape == (1, 2, 2)
    # The gold span (0, 1) is switched on in the span grid.
    assert output["spans"][0, 0, 1] == 1
    assert output["gold_labels"][0] == {(0, 1): "PER"}
    assert output["task_id"].shape == (1,)
|
|
|
|
def test_mentions_by_sentence_grouping(mock_raw_example):
    """Coref chains are regrouped per sentence, keyed by stringified index."""
    result = mentions_by_sentence(mock_raw_example)

    assert "mentions" in result
    # Chain entry [0, 0, 1] lands in sentence "0" as span (0, 1)...
    assert (0, 1) in result["mentions"]["0"]
    # ...and [1, 0, 0] lands in sentence "1" as span (0, 0).
    assert (0, 0) in result["mentions"]["1"]
|
|
|
|
def test_flatten_to_sentences_alignment(mock_raw_example):
    """Flattening one document yields N sentence rows with aligned mentions."""
    processed = mentions_by_sentence(mock_raw_example)
    # flatten_to_sentences expects batched input: column name -> list of values.
    batch = {key: [value] for key, value in processed.items()}

    flattened = flatten_to_sentences(batch)

    assert len(flattened["sentence"]) == 2
    assert flattened["mentions"][0] == [(0, 1)]
    assert flattened["mentions"][1] == [(0, 0)]
|
|
|
|
def test_dataset_tensor_logic(mock_flattened_data):
    """span_labels is a 2D grid indexed inclusively by (start, end)."""
    ds = LitBankMentionDataset(mock_flattened_data)

    # Sentence 0 has exactly one mention covering tokens 0..1.
    item = ds[0]
    assert item["starts"][0] == 1
    assert item["span_labels"][0, 1] == 1
    assert item["span_labels"].sum() == 1

    # Sentence 2 has no mentions, so both tensors stay all-zero.
    empty_item = ds[2]
    assert empty_item["starts"].sum() == 0
    assert empty_item["span_labels"].sum() == 0
|
|
|
|
def test_collate_masking_and_shapes(mock_flattened_data):
    """Verify the 2D mask logic: upper triangle + is_start.

    The span-loss mask should only cover candidate spans (end >= start)
    whose start token actually begins a mention.
    """
    ds = LitBankMentionDataset(mock_flattened_data)
    batch = [ds[0], ds[1], ds[2]]
    collated = collate_fn(batch)

    # Batch of 3; the longest sentence has 4 tokens.
    B, N = 3, 4
    assert collated["starts"].shape == (B, N)
    assert collated["spans"].shape == (B, N, N)

    mask = collated["span_loss_mask"]

    # Token 0 of sentence 0 starts a mention, so spans (0,0) and (0,1) are
    # valid candidates.  Truthiness asserts instead of `== True` (E712).
    assert mask[0, 0, 0]
    assert mask[0, 0, 1]

    # Token 2 never starts a mention: its entire row must be masked out.
    assert not mask[0, 2, :].any()
|
|
|
|
def test_out_of_bounds_guard():
    """A mention end index past the sentence length is ignored, not a crash."""
    bad_data = [{"sentence": ["Short"], "mentions": [[0, 10]]}]

    item = LitBankMentionDataset(bad_data)[0]

    # The out-of-range span must simply not be recorded.
    assert item["span_labels"].sum() == 0
|
|
|
|
| |
def test_make_litbank_integration():
    """Check that the real pipeline loads and yields a valid batch.

    Exceptions are allowed to propagate: the previous
    ``try/except Exception: pytest.fail(...)`` pattern discarded the
    original traceback, making failures much harder to diagnose.
    """
    data = make_litbank(tag="split_0")
    batch = next(iter(data.train_loader))

    assert "sentences" in batch
    assert "span_loss_mask" in batch
    assert isinstance(batch["spans"], torch.Tensor)
|
|
|
|
@patch("mentioned.data.load_dataset")
def test_make_litbank_entity(mock_load_dataset):
    """End-to-end entity pipeline over a faked HF dataset.

    ``load_dataset`` is patched so make_litbank_entity runs against an
    in-memory stand-in implementing just enough of the ``datasets`` API:
    per-split ``column_names`` and ``map`` (batched and unbatched).
    Also removes a leftover debug ``print(batch)``.
    """

    class FakeSplit(list):
        """A list of row dicts exposing `column_names` like an HF split."""

        @property
        def column_names(self):
            return list(self[0].keys()) if self else []

    class DummyDataset(dict):
        """Split-name -> FakeSplit mapping with a minimal `map`."""

        def map(self, fn, batched=False, remove_columns=None):
            # remove_columns is accepted for API compatibility but ignored.
            mapped = {}
            for split_name, split_data in self.items():
                if not split_data:
                    mapped[split_name] = FakeSplit([])
                    continue

                if batched:
                    # Re-batch rows into a single column dict, apply fn,
                    # then explode the result back into per-row dicts.
                    batch = {
                        "entities": [item["entities"] for item in split_data]
                    }
                    result = fn(batch)

                    new_split = []
                    for i in range(len(result["sentence"])):
                        new_split.append({
                            "sentence": result["sentence"][i],
                            "entity_spans": result["entity_spans"][i],
                            "entity_labels": result["entity_labels"][i],
                        })
                    mapped[split_name] = FakeSplit(new_split)
                else:
                    mapped[split_name] = FakeSplit(split_data)

            return DummyDataset(mapped)

    # One train document containing a single two-token PER mention.
    fake_data = DummyDataset({
        "train": FakeSplit([
            {
                "entities": [
                    [
                        {"token": "John", "bio_tags": ["B-PER"]},
                        {"token": "Smith", "bio_tags": ["I-PER"]},
                    ]
                ]
            }
        ]),
        "validation": FakeSplit([]),
        "test": FakeSplit([]),
    })

    mock_load_dataset.return_value = fake_data

    data = make_litbank_entity()
    batch = next(iter(data.train_loader))

    assert "starts" in batch
    assert "spans" in batch
    assert "gold_labels" in batch

    # The John Smith mention must show up as a positive span with its
    # gold label attached.
    assert batch["spans"].sum() > 0
    assert batch["gold_labels"][0] == {(0, 1): "PER"}
|
|