from pathlib import Path

import evaluate
import pytest
|
|
# Load the metric script that sits next to this test module, so the suite does
# not depend on the directory pytest is launched from. If no such file exists
# beside this module, fall back to the original cwd-relative path.
_NER_EVAL_SCRIPT = Path(__file__).resolve().with_name("ner_eval.py")
ner_eval = evaluate.load(
    str(_NER_EVAL_SCRIPT) if _NER_EVAL_SCRIPT.is_file() else "ner_eval.py"
)
|
|
# Scheme prefixes, in the order their keys appear in every score table.
_SCHEMES = ("strict", "ent_type", "partial", "exact")

# Reusable (precision, recall, f1) triples.
_ALL_ONE = (1.0, 1.0, 1.0)
_ALL_ZERO = (0.0, 0.0, 0)


def _scores(strict, ent_type, partial, exact):
    """Return an expected-score table built from per-scheme triples.

    Each argument is a ``(precision, recall, f1)`` triple for the scheme of
    the same name. The returned dict maps ``"<scheme>_precision"``,
    ``"<scheme>_recall"`` and ``"<scheme>_f1"`` to those values, scheme by
    scheme in ``_SCHEMES`` order. A new dict is built on every call.
    """
    table = {}
    for scheme, (precision, recall, f1) in zip(
        _SCHEMES, (strict, ent_type, partial, exact)
    ):
        table[f"{scheme}_precision"] = precision
        table[f"{scheme}_recall"] = recall
        table[f"{scheme}_f1"] = f1
    return table


def _same_for_all(triple):
    """Return a score table in which every scheme uses the same triple."""
    return _scores(triple, triple, triple, triple)


# Each case is passed to ner_eval.compute as a single prediction/reference
# sequence pair. A case without a "results" key is expected to make compute()
# raise ValueError. Otherwise, "results" holds the expected scores, keyed by
# "overall" and by entity type.
test_cases = [
    # Predictions identical to references (PER, LOC and ORG tags).
    {
        "predictions": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O", "B-ORG"],
        "references": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "O", "B-ORG"],
        "results": {
            "overall": _same_for_all(_ALL_ONE),
            "LOC": _same_for_all(_ALL_ONE),
            "PER": _same_for_all(_ALL_ONE),
            "ORG": _same_for_all(_ALL_ONE),
        },
    },
    # Predictions identical to references (PER and LOC tags only).
    {
        "predictions": [
            "B-LOC", "I-LOC", "O", "B-PER", "I-PER",
            "I-PER", "I-PER", "O", "B-LOC", "O",
        ],
        "references": [
            "B-LOC", "I-LOC", "O", "B-PER", "I-PER",
            "I-PER", "I-PER", "O", "B-LOC", "O",
        ],
        "results": {
            "overall": _same_for_all(_ALL_ONE),
            "LOC": _same_for_all(_ALL_ONE),
            "PER": _same_for_all(_ALL_ONE),
        },
    },
    # 7 predicted tags vs. 8 reference tags: no "results", so ValueError.
    {
        "predictions": ["O", "B-LOC", "I-LOC", "B-PER", "I-PER", "O", "B-ORG"],
        "references": ["O", "B-LOC", "I-LOC", "O", "B-PER", "I-PER", "O", "B-ORG"],
    },
    # Same length, different tags; every expected score is zero.
    {
        "predictions": ["B-PER", "O", "B-LOC", "I-LOC", "O", "B-ORG", "I-ORG"],
        "references": ["B-PER", "I-PER", "O", "B-LOC", "I-LOC", "O", "B-ORG"],
        "results": {
            "overall": _same_for_all(_ALL_ZERO),
            "ORG": _same_for_all(_ALL_ZERO),
            "PER": _same_for_all(_ALL_ZERO),
            "LOC": _same_for_all(_ALL_ZERO),
        },
    },
    # Same length, different tags. Strict and exact are zero everywhere,
    # ORG is zero in every scheme, and PER and LOC share one score table.
    {
        "predictions": [
            "B-LOC", "I-LOC", "I-LOC", "B-ORG", "I-ORG",
            "O", "B-PER", "I-PER", "I-PER", "O",
        ],
        "references": [
            "B-LOC", "I-LOC", "O", "O", "B-ORG",
            "I-ORG", "O", "B-PER", "I-PER", "O",
        ],
        "results": {
            "overall": _scores(
                strict=_ALL_ZERO,
                ent_type=(2 / 3, 2 / 3, 2 / 3),
                partial=(1 / 3, 1 / 3, 1 / 3),
                exact=_ALL_ZERO,
            ),
            "ORG": _same_for_all(_ALL_ZERO),
            "PER": _scores(
                strict=_ALL_ZERO,
                ent_type=(0.5, 1.0, 2 / 3),
                partial=(0.25, 0.5, 1 / 3),
                exact=_ALL_ZERO,
            ),
            "LOC": _scores(
                strict=_ALL_ZERO,
                ent_type=(0.5, 1.0, 2 / 3),
                partial=(0.25, 0.5, 1 / 3),
                exact=_ALL_ZERO,
            ),
        },
    },
]
|
|
|
|
def compare_results(result1, result2, *, rel_tol=1e-9, abs_tol=1e-9):
    """Return True if two nested metric results match.

    The comparison recurses through dicts and lists/tuples:

    * dicts must have exactly the same set of keys; a key missing on either
      side is a mismatch;
    * lists/tuples must have the same length and match element-wise;
    * real numbers are compared with ``math.isclose``, so expected values
      written as fractions (e.g. ``2 / 3``) still match computed floats that
      differ only by rounding;
    * anything else is compared with ``==``.

    Args:
        result1: The computed result (e.g. the output of ``ner_eval.compute``).
        result2: The expected result.
        rel_tol: Relative tolerance forwarded to ``math.isclose``.
        abs_tol: Absolute tolerance forwarded to ``math.isclose``. Needed
            because a relative tolerance alone never accepts a nonzero
            difference from an exact zero.

    Returns:
        True when the two structures match, False otherwise.
    """
    import math
    import numbers

    if isinstance(result1, dict):
        # Checking key sets catches keys missing on either side. Walking only
        # result1's keys would ignore keys that exist only in result2.
        if not isinstance(result2, dict) or result1.keys() != result2.keys():
            return False
        return all(
            compare_results(
                result1[key], result2[key], rel_tol=rel_tol, abs_tol=abs_tol
            )
            for key in result1
        )
    if isinstance(result1, (list, tuple)):
        # zip() silently drops trailing items, so check lengths first.
        if not isinstance(result2, (list, tuple)) or len(result1) != len(result2):
            return False
        return all(
            compare_results(item1, item2, rel_tol=rel_tol, abs_tol=abs_tol)
            for item1, item2 in zip(result1, result2)
        )
    if isinstance(result1, numbers.Real) and isinstance(result2, numbers.Real):
        return math.isclose(result1, result2, rel_tol=rel_tol, abs_tol=abs_tol)
    return result1 == result2
|
|
|
|
@pytest.mark.parametrize("case", test_cases)
def test_metric(case):
    """Run ner_eval on one test case and check the outcome.

    A case without a "results" key must make ``ner_eval.compute`` raise
    ValueError. For every other case, the computed scores must match
    ``case["results"]`` according to ``compare_results``.
    """
    # compute() takes a batch of sequences, so wrap the single pair.
    predictions = [case["predictions"]]
    references = [case["references"]]

    if "results" not in case:
        with pytest.raises(ValueError):
            ner_eval.compute(predictions=predictions, references=references)
        return

    results = ner_eval.compute(predictions=predictions, references=references)
    assert compare_results(results, case["results"]), (
        f"unexpected results:\n{results}\nexpected:\n{case['results']}"
    )
|
|