| |
| """ |
| split_aware_analysis.py -- Stratify test F1 by seed novelty. |
| |
| Splits the test set into two groups based on whether each query's |
| seeds were seen during training: |
| |
| - seen: at least one of the query's seeds appears in train |
| - unseen: none of the query's seeds appear in train (this also |
| includes queries with an empty seed list) |
| |
| Then computes precision, recall and F1 (overall and per hop) |
| separately on each group. If F1_unseen is close to F1_seen, the model |
| generalizes to novel seed entities; if F1_unseen is much lower, that |
| suggests the model relies on memorized seed entities. |
| |
| This is the closest "external-like" validation available without |
| introducing a separate KG or dataset. |
| |
| Usage: |
| python scripts/split_aware_analysis.py \ |
| --checkpoint runs/no_dc/seed_42/best.pt \ |
| --train-split data/processed/train.json \ |
| --test-split data/processed/test.json \ |
| --threshold 0.80 \ |
| --mode autoregressive \ |
| --output-json results/split_aware_seed42.json |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import logging |
| import sys |
| from collections import defaultdict |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
# Repository root: two directory levels above this file.
ROOT = Path(__file__).parent.parent
if str(ROOT) not in sys.path:
    # Prepend (rather than append) so the in-repo `caff` package takes
    # precedence over any other copy on the import path.
    sys.path.insert(0, str(ROOT))
|
|
| from caff import ( |
| AblationFlags, |
| CAFFConfig, |
| CAFFEvaluator, |
| CAFFModel, |
| CAFFTripleDataset, |
| CachedBFSExtractor, |
| FrozenBioEncoder, |
| KnowledgeGraph, |
| RelationEmbeddingCache, |
| load_qa_split, |
| ) |
| from caff.evaluator import precision_recall_f1 |
| from caff.utils.logging import setup_logging |
|
|
# Module-scoped logger; main() calls setup_logging(level="INFO") before
# emitting any messages through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line for the split-aware analysis.

    Returns:
        Namespace with ``checkpoint``, ``train_split``, ``test_split``,
        ``cache_dir``, ``mode``, ``threshold``, ``output_json`` and
        ``device``. ``device`` defaults to ``"cuda"`` when
        ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    # (flag, add_argument keyword options), registered in this order.
    cli_options = (
        ("--checkpoint", {"required": True}),
        ("--train-split", {
            "required": True,
            "help": "Train JSON, used to compute the set of seen seeds.",
        }),
        ("--test-split", {"default": None}),
        ("--cache-dir", {"default": "cache"}),
        ("--mode", {
            "default": "autoregressive",
            "choices": ["teacher_forced", "autoregressive"],
        }),
        ("--threshold", {"type": float, "default": None}),
        ("--output-json", {"default": None}),
        ("--device", {"default": "cuda" if torch.cuda.is_available() else "cpu"}),
    )
    parser = argparse.ArgumentParser(description="Split-aware F1 stratification.")
    for flag, options in cli_options:
        parser.add_argument(flag, **options)
    return parser.parse_args()
|
|
|
|
def load_checkpoint(ckpt_path: str, device: str, cache_dir: Path):
    """Rebuild the trained model and its supporting objects from a checkpoint.

    Args:
        ckpt_path: Path to a file written with ``torch.save`` whose payload
            holds a ``"config"`` mapping (fed to ``CAFFConfig``) and a
            ``"model"`` state dict.
        device: Torch device string used for loading and for the model.
        cache_dir: Directory holding ``relation_embeddings.pt``.

    Returns:
        Tuple ``(model, config, encoder, kg)``; ``model`` is on ``device`` and
        has been switched to eval mode.
    """
    checkpoint = torch.load(ckpt_path, map_location=device)
    config = CAFFConfig(**checkpoint["config"])
    # NOTE(review): default flags are used regardless of how this checkpoint
    # was trained (e.g. an ablated run); confirm the architecture matches.
    flags = AblationFlags()

    logger.info(f"Loading KG from {config.kg_path}...")
    # NOTE(review): min_relation_freq is hard-coded; presumably it must match
    # the training-time value so kg.relations lines up with the cached
    # relation embeddings -- confirm.
    kg = KnowledgeGraph.from_tsv(config.kg_path, min_relation_freq=50)
    encoder = FrozenBioEncoder(config.encoder_name, device=device)
    relation_cache = RelationEmbeddingCache(
        encoder,
        kg.relations,
        cache_path=cache_dir / "relation_embeddings.pt",
    )

    model = CAFFModel(config, relation_cache, ablation=flags).to(device)
    model.load_state_dict(checkpoint["model"])
    model.eval()
    logger.info(f"Restored checkpoint from {ckpt_path}")
    return model, config, encoder, kg
|
|
|
|
def compute_seen_seeds(train_path: str) -> set[str]:
    """Return every seed mentioned by any record of the training split.

    Records may expose seeds as an attribute (``record.seeds``) or be
    mapping-like, in which case ``record.get("seeds", [])`` is used so a
    missing key contributes no seeds.
    """
    seen: set[str] = set()
    for record in load_qa_split(train_path):
        record_seeds = record.seeds if hasattr(record, "seeds") else record.get("seeds", [])
        seen.update(record_seeds)
    return seen
|
|
|
|
def classify_query(seeds: list[str], seen_set: set[str]) -> str:
    """Label a test query by seed novelty.

    Returns ``"seen"`` if at least one seed is a member of ``seen_set``;
    otherwise ``"unseen"``. A query with no seeds is therefore ``"unseen"``.
    """
    return "seen" if any(seed in seen_set for seed in seeds) else "unseen"
|
|
|
|
def _record_field(rec, name: str, default=None):
    """Read field ``name`` from a QA record that is an object or mapping-like.

    Records from ``load_qa_split`` are accessed both as attributes and via
    ``.get`` in this script, so try the attribute first and fall back to
    ``rec.get(name, default)``.
    """
    if hasattr(rec, name):
        return getattr(rec, name)
    return rec.get(name, default)


def _classify_test_queries(test_recs, seen_seeds: set[str]) -> tuple[dict, int, int]:
    """Assign each test query to the ``"seen"`` or ``"unseen"`` seed group.

    Returns:
        ``(qid_to_group, n_seen, n_unseen)``. Counts are per record, so a
        repeated query_id is counted once per occurrence while the mapping
        keeps the classification of its last occurrence.
    """
    qid_to_group: dict = {}
    n_seen = n_unseen = 0
    for rec in test_recs:
        qid = _record_field(rec, "query_id")
        group = classify_query(_record_field(rec, "seeds", []), seen_seeds)
        qid_to_group[qid] = group
        if group == "seen":
            n_seen += 1
        else:
            n_unseen += 1
    return qid_to_group, n_seen, n_unseen


def _slice_metrics(scores, labels, threshold: float) -> dict:
    """Return count, positive count/rate, precision, recall and F1 for one slice."""
    labels_arr = np.asarray(labels)
    n_total = len(labels_arr)
    n_pos = int(labels_arr.sum())
    m = precision_recall_f1(np.asarray(scores), labels_arr, threshold=threshold)
    return {
        "n_total": n_total,
        "n_pos": n_pos,
        "pos_rate": n_pos / n_total if n_total > 0 else 0.0,
        # float() keeps json.dump working if the metric helper returns
        # NumPy scalar types.
        "precision": float(m["precision"]),
        "recall": float(m["recall"]),
        "f1": float(m["f1"]),
    }


def main() -> None:
    """Stratify test-set precision/recall/F1 by seed novelty and report it.

    Steps: collect the training seeds, restore the checkpoint, split test
    queries into seen/unseen seed groups, score the test set once, then print
    overall and per-(group, hop) metrics and optionally write them to
    ``--output-json``.

    Raises:
        SystemExit: if the test split contains no records.
        ValueError: if the evaluator returns a different number of scores
            than scored instances.
    """
    args = parse_args()
    setup_logging(level="INFO")
    cache_dir = Path(args.cache_dir)

    # --- Seeds observed anywhere in the training split -------------------
    logger.info(f"Building seen-seeds set from {args.train_split}...")
    seen_seeds = compute_seen_seeds(args.train_split)
    logger.info(f" {len(seen_seeds):,} unique training seeds")

    # --- Model / encoder / KG ---------------------------------------------
    model, config, encoder, kg = load_checkpoint(args.checkpoint, args.device, cache_dir)

    # --- Classify test queries by seed novelty ----------------------------
    test_path = args.test_split or config.test_path
    test_recs = load_qa_split(test_path)
    n_test = len(test_recs)
    if n_test == 0:
        # Without this guard the percentage logging below divides by zero.
        logger.error(f"No test records loaded from {test_path}; nothing to analyse.")
        raise SystemExit(1)

    qid_to_group, n_seen, n_unseen = _classify_test_queries(test_recs, seen_seeds)
    logger.info(f" Test queries: {n_test:,}")
    logger.info(f" seen-seed group: {n_seen:,} ({n_seen/n_test*100:.1f}%)")
    logger.info(f" unseen-seed group: {n_unseen:,} ({n_unseen/n_test*100:.1f}%)")

    # --- Score the whole test set once ------------------------------------
    bfs = CachedBFSExtractor(kg, L=config.L, K_r=config.K_r,
                             cache_dir=cache_dir / "bfs")
    test_ds = CAFFTripleDataset(test_recs, bfs, require_gold=True)
    threshold = args.threshold if args.threshold is not None else config.theta
    evaluator = CAFFEvaluator(
        config=config, encoder=encoder, mode=args.mode, threshold=threshold,
    )
    logger.info(f"Scoring test set (mode={args.mode}, theta={threshold})...")
    # NOTE(review): _score_dataset is a private evaluator method; this script
    # depends on its (scores, instances, retained) return contract.
    scores, instances, _retained = evaluator._score_dataset(model, test_ds)
    scores_np = scores.cpu().numpy() if torch.is_tensor(scores) else np.asarray(scores)
    score_list = scores_np.tolist()
    instances = list(instances)
    if len(score_list) != len(instances):
        # zip() would silently truncate to the shorter sequence and skew
        # every metric below.
        raise ValueError(
            f"Evaluator returned {len(score_list)} scores for "
            f"{len(instances)} instances"
        )

    # --- Bucket scores/labels by group and by (group, hop) ----------------
    by_group_scores: dict[str, list[float]] = defaultdict(list)
    by_group_labels: dict[str, list[int]] = defaultdict(list)
    by_group_hop_scores: dict[tuple[str, int], list[float]] = defaultdict(list)
    by_group_hop_labels: dict[tuple[str, int], list[int]] = defaultdict(list)
    n_unmapped = 0
    for inst, sc in zip(instances, score_list):
        group = qid_to_group.get(inst.query_id, "unknown")
        if group == "unknown":
            n_unmapped += 1
        by_group_scores[group].append(sc)
        by_group_labels[group].append(inst.label)
        by_group_hop_scores[(group, inst.hop)].append(sc)
        by_group_hop_labels[(group, inst.hop)].append(inst.label)
    if n_unmapped:
        # Only the seen/unseen groups are reported, so surface the drop.
        logger.warning(
            f"{n_unmapped:,} scored instances have a query_id not found in the "
            f"test split; they are excluded from all tables."
        )

    # --- Overall per-group table ------------------------------------------
    print()
    print("=" * 88)
    print(f"Split-aware F1 stratification (mode={args.mode}, theta={threshold})")
    print(f"Checkpoint: {args.checkpoint}")
    print("=" * 88)
    print(f"Seen-seed group: {n_seen:>5} queries (seed appears in train)")
    print(f"Unseen-seed group: {n_unseen:>5} queries (no seed appears in train)")
    print()
    print(f"{'group':<10} | {'n_total':>8} | {'n_pos':>6} | {'pos%':>6} | "
          f"{'prec':>6} | {'recall':>6} | {'F1':>6}")
    print("-" * 66)

    group_rows = []
    for group in ("seen", "unseen"):
        if group not in by_group_scores:
            continue
        row = {"group": group,
               **_slice_metrics(by_group_scores[group], by_group_labels[group], threshold)}
        group_rows.append(row)
        print(f"{group:<10} | {row['n_total']:>8} | {row['n_pos']:>6} | "
              f"{row['pos_rate']*100:>5.1f}% | "
              f"{row['precision']:>6.4f} | {row['recall']:>6.4f} | {row['f1']:>6.4f}")

    # --- Generalization gap (only when both groups are present) -----------
    rows_by_group = {row["group"]: row for row in group_rows}
    if "seen" in rows_by_group and "unseen" in rows_by_group:
        f1_seen = rows_by_group["seen"]["f1"]
        f1_unseen = rows_by_group["unseen"]["f1"]
        gap = f1_seen - f1_unseen
        rel_gap = (gap / f1_seen * 100) if f1_seen > 0 else 0.0
        print()
        print(f"Generalization gap: F1_seen - F1_unseen = {gap:+.4f} ({rel_gap:+.1f}%)")
        if abs(gap) < 0.02:
            print(" ==> Small gap; model generalizes well to novel seeds.")
        elif gap > 0:
            print(" ==> Larger F1 on seen seeds; some memorization effect.")
        else:
            print(" ==> Larger F1 on unseen; unusual.")

    # --- Per (group, hop) table ----------------------------------------------
    print()
    print("=" * 88)
    print("Per (group, hop) breakdown")
    print("=" * 88)
    print(f"{'group':<10} | {'hop':>4} | {'n_total':>8} | {'n_pos':>6} | "
          f"{'prec':>6} | {'recall':>6} | {'F1':>6}")
    print("-" * 66)
    grouphop_rows = []
    for group in ("seen", "unseen"):
        # Iterate the hops actually observed for this group instead of a fixed
        # list, so no hop value is silently left out of the table.
        hops = sorted(hop for (g, hop) in by_group_hop_scores if g == group)
        for hop in hops:
            key = (group, hop)
            row = {"group": group, "hop": hop,
                   **_slice_metrics(by_group_hop_scores[key],
                                    by_group_hop_labels[key], threshold)}
            grouphop_rows.append(row)
            print(f"{group:<10} | {hop:>4} | {row['n_total']:>8} | {row['n_pos']:>6} | "
                  f"{row['precision']:>6.4f} | {row['recall']:>6.4f} | {row['f1']:>6.4f}")
    print("=" * 88)

    # --- Optional JSON dump -------------------------------------------------
    if args.output_json:
        out = {
            "checkpoint": str(args.checkpoint),
            "mode": args.mode,
            "threshold": threshold,
            "n_seen_seeds_in_train": len(seen_seeds),
            "n_queries_seen_group": n_seen,
            "n_queries_unseen_group": n_unseen,
            "n_instances_unmapped": n_unmapped,
            "by_group": group_rows,
            "by_group_hop": grouphop_rows,
        }
        out_path = Path(args.output_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(out, f, indent=2)
        logger.info(f"Results written to {out_path}")
|
|
|
|
if __name__ == "__main__":
    # Run the analysis only when executed as a script, never on import.
    main()
|
|