# CAFF / scripts/split_aware_analysis.py
# Uploaded by MrDhifallah via the upload-large-folder tool (commit 1c45044, verified).
#!/usr/bin/env python
"""
split_aware_analysis.py -- Stratify test F1 by seed novelty.
Splits the test set into two groups based on whether each query's
seeds were seen during training:
- seen_seeds: at least one of the query's seeds appears in train
- unseen_seeds: none of the query's seeds appear in train
Then computes precision, recall and F1 separately on each group, both overall and broken down per hop.
If F1_unseen is close to F1_seen, the model generalizes to novel
seed entities; if F1_unseen is much lower, the model is memorizing.
This is the closest "external-like" validation available without
introducing a separate KG or dataset.
Usage:
python scripts/split_aware_analysis.py \
--checkpoint runs/no_dc/seed_42/best.pt \
--train-split data/processed/train.json \
--test-split data/processed/test.json \
--threshold 0.80 \
--mode autoregressive \
--output-json results/split_aware_seed42.json
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
# Repository root: two directories above this script. resolve() makes the path
# absolute, so it stays correct when __file__ is relative (e.g. running
# `python split_aware_analysis.py` from inside scripts/ on Python < 3.9, where
# Path(".").parent is still ".") or when the script is reached via a symlink.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    # Make the local `caff` package importable without installing it.
    sys.path.insert(0, str(ROOT))
from caff import (
AblationFlags,
CAFFConfig,
CAFFEvaluator,
CAFFModel,
CAFFTripleDataset,
CachedBFSExtractor,
FrozenBioEncoder,
KnowledgeGraph,
RelationEmbeddingCache,
load_qa_split,
)
from caff.evaluator import precision_recall_f1
from caff.utils.logging import setup_logging
# Module-level logger; handlers and level are configured by setup_logging() in main().
logger = logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``checkpoint``, ``train_split``, ``test_split``,
        ``cache_dir``, ``mode``, ``threshold``, ``output_json`` and ``device``.
        ``device`` defaults to "cuda" when torch reports CUDA as available,
        otherwise "cpu".
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description="Split-aware F1 stratification.")
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument(
        "--train-split",
        required=True,
        help="Train JSON, used to compute the set of seen seeds.",
    )
    parser.add_argument("--test-split", default=None)
    parser.add_argument("--cache-dir", default="cache")
    parser.add_argument(
        "--mode",
        default="autoregressive",
        choices=["teacher_forced", "autoregressive"],
    )
    parser.add_argument("--threshold", type=float, default=None)
    parser.add_argument("--output-json", default=None)
    parser.add_argument("--device", default=default_device)

    namespace = parser.parse_args()
    return namespace
def _restore_ablation(payload) -> AblationFlags:
    """Return the ablation flags stored in a checkpoint payload, or defaults.

    Accepts an ``"ablation"`` entry holding either an ``AblationFlags`` instance
    or a dict of its constructor kwargs. When the entry is missing, this falls
    back to ``AblationFlags()`` (the previous behaviour) and logs a warning.
    Evaluating an ablated run with default flags can silently produce the
    wrong model.
    """
    saved = payload.get("ablation")
    if isinstance(saved, AblationFlags):
        return saved
    if isinstance(saved, dict):
        return AblationFlags(**saved)
    logger.warning(
        "Checkpoint has no usable 'ablation' entry; using default AblationFlags(). "
        "Results may be wrong if this run was trained with ablations."
    )
    return AblationFlags()


def load_checkpoint(ckpt_path: str, device: str, cache_dir: Path):
    """Rebuild a CAFF model and its dependencies from a saved checkpoint.

    Args:
        ckpt_path: Path to a ``torch.save`` payload containing at least
            ``"config"`` (kwargs for ``CAFFConfig``) and ``"model"``
            (the model state dict).
        device: Torch device string used for ``map_location``, the encoder
            and the model.
        cache_dir: Directory where ``relation_embeddings.pt`` is cached.

    Returns:
        ``(model, config, encoder, kg)``, with the model in eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    payload = torch.load(ckpt_path, map_location=device)
    config = CAFFConfig(**payload["config"])
    ablation = _restore_ablation(payload)
    logger.info(f"Loading KG from {config.kg_path}...")
    # NOTE(review): min_relation_freq is hard-coded; it presumably must match the
    # value used at training time so kg.relations (and thus the relation
    # embedding cache) lines up with the checkpoint -- confirm.
    kg = KnowledgeGraph.from_tsv(config.kg_path, min_relation_freq=50)
    encoder = FrozenBioEncoder(config.encoder_name, device=device)
    rel_cache = RelationEmbeddingCache(
        encoder,
        kg.relations,
        cache_path=cache_dir / "relation_embeddings.pt",
    )
    model = CAFFModel(config, rel_cache, ablation=ablation).to(device)
    model.load_state_dict(payload["model"])
    model.eval()
    logger.info(f"Restored checkpoint from {ckpt_path}")
    return model, config, encoder, kg
def compute_seen_seeds(train_path: str) -> set[str]:
    """Return every seed string that occurs in at least one training query.

    Records from ``load_qa_split`` may be objects exposing a ``seeds``
    attribute or mappings with an optional ``"seeds"`` key (treated as empty
    when absent).
    """
    def _seeds_of(record):
        # Support both dataclass-style records and plain dicts.
        return record.seeds if hasattr(record, "seeds") else record.get("seeds", [])

    return {
        seed
        for record in load_qa_split(train_path)
        for seed in _seeds_of(record)
    }
def classify_query(seeds: list[str], seen_set: set[str]) -> str:
    """Label a test query by seed novelty.

    Returns "seen" when at least one of ``seeds`` is in ``seen_set``.
    Otherwise returns "unseen", which includes an empty ``seeds`` list.
    """
    return "seen" if any(seed in seen_set for seed in seeds) else "unseen"
def _pct(count: int, total: int) -> float:
    """Return ``count`` as a percentage of ``total``; 0.0 when ``total`` is 0."""
    return count / total * 100 if total else 0.0


def _query_id_and_seeds(rec):
    """Return ``(query_id, seeds)`` for a QA record given as an object or a dict.

    Object records are read via their ``query_id``/``seeds`` attributes. Dict
    records are read via ``.get`` (``None`` / ``[]`` when a key is missing).
    """
    if hasattr(rec, "query_id"):
        return rec.query_id, rec.seeds
    return rec.get("query_id"), rec.get("seeds", [])


def _metric_row(scores: list[float], labels: list[int], threshold: float) -> dict:
    """Summarize one stratum: sizes, positive rate and P/R/F1 at ``threshold``.

    Metric values are cast to ``float`` so the row stays JSON-serializable
    even if ``precision_recall_f1`` returns NumPy or torch scalars.
    """
    score_arr = np.asarray(scores)
    label_arr = np.asarray(labels)
    n_total = len(label_arr)
    n_pos = int(label_arr.sum())
    m = precision_recall_f1(score_arr, label_arr, threshold=threshold)
    return {
        "n_total": n_total,
        "n_pos": n_pos,
        "pos_rate": n_pos / n_total if n_total > 0 else 0.0,
        "precision": float(m["precision"]),
        "recall": float(m["recall"]),
        "f1": float(m["f1"]),
    }


def main() -> None:
    """Stratify test-set precision/recall/F1 by whether query seeds occur in train.

    Steps:
      1. Build the seen-seed set from the train split.
      2. Restore the checkpoint.
      3. Label each test query "seen" or "unseen".
      4. Score every test instance with the evaluator.
      5. Report metrics per group and per (group, hop).
      6. Optionally write the results to ``--output-json``.
    """
    small_gap = 0.02  # |F1_seen - F1_unseen| below this is reported as "small"

    args = parse_args()
    setup_logging(level="INFO")
    cache_dir = Path(args.cache_dir)

    # Step 1: build the set of seen seeds from train
    logger.info(f"Building seen-seeds set from {args.train_split}...")
    seen_seeds = compute_seen_seeds(args.train_split)
    logger.info(f" {len(seen_seeds):,} unique training seeds")

    # Step 2: load checkpoint
    model, config, encoder, kg = load_checkpoint(args.checkpoint, args.device, cache_dir)

    # Step 3: load test set and classify each query by seed novelty
    test_path = args.test_split or config.test_path
    test_recs = load_qa_split(test_path)
    n_queries = len(test_recs)
    qid_to_group: dict[str, str] = {}
    n_seen, n_unseen = 0, 0
    for rec in test_recs:
        qid, seeds = _query_id_and_seeds(rec)
        group = classify_query(seeds, seen_seeds)
        qid_to_group[qid] = group
        if group == "seen":
            n_seen += 1
        else:
            n_unseen += 1
    # _pct avoids a ZeroDivisionError when the test split is empty.
    logger.info(f" Test queries: {n_queries:,}")
    logger.info(f" seen-seed group: {n_seen:,} ({_pct(n_seen, n_queries):.1f}%)")
    logger.info(f" unseen-seed group: {n_unseen:,} ({_pct(n_unseen, n_queries):.1f}%)")

    # Step 4: score the test set
    bfs = CachedBFSExtractor(kg, L=config.L, K_r=config.K_r, cache_dir=cache_dir / "bfs")
    test_ds = CAFFTripleDataset(test_recs, bfs, require_gold=True)
    threshold = args.threshold if args.threshold is not None else config.theta
    evaluator = CAFFEvaluator(
        config=config, encoder=encoder, mode=args.mode, threshold=threshold,
    )
    logger.info(f"Scoring test set (mode={args.mode}, theta={threshold})...")
    # NOTE(review): _score_dataset is a private evaluator method; its third
    # return value is not used here.
    scores, instances, _retained = evaluator._score_dataset(model, test_ds)
    instances = list(instances)
    if torch.is_tensor(scores):
        scores_np = scores.detach().cpu().numpy()
    else:
        scores_np = np.asarray(scores)
    # zip() would silently truncate and misalign scores with instances.
    if len(scores_np) != len(instances):
        raise ValueError(
            f"Evaluator returned {len(scores_np)} scores for {len(instances)} instances; "
            "cannot align scores with queries."
        )

    # Step 5: aggregate by group, and by (group, hop) for a finer view
    by_group_scores: dict[str, list[float]] = defaultdict(list)
    by_group_labels: dict[str, list[int]] = defaultdict(list)
    by_group_hop_scores: dict[tuple, list[float]] = defaultdict(list)
    by_group_hop_labels: dict[tuple, list[int]] = defaultdict(list)
    n_unmatched = 0
    for inst, sc in zip(instances, scores_np.tolist()):
        group = qid_to_group.get(inst.query_id, "unknown")
        if group == "unknown":
            n_unmatched += 1
        by_group_scores[group].append(sc)
        by_group_labels[group].append(inst.label)
        by_group_hop_scores[(group, inst.hop)].append(sc)
        by_group_hop_labels[(group, inst.hop)].append(inst.label)
    if n_unmatched:
        logger.warning(
            f"{n_unmatched:,} scored instances have a query_id not present in the "
            "test split; they are excluded from the seen/unseen tables."
        )

    # Step 6: print results
    print()
    print("=" * 88)
    print(f"Split-aware F1 stratification (mode={args.mode}, theta={threshold})")
    print(f"Checkpoint: {args.checkpoint}")
    print("=" * 88)
    print(f"Seen-seed group: {n_seen:>5} queries (seed appears in train)")
    print(f"Unseen-seed group: {n_unseen:>5} queries (no seed appears in train)")
    print()
    print(f"{'group':<10} | {'n_total':>8} | {'n_pos':>6} | {'pos%':>6} | "
          f"{'prec':>6} | {'recall':>6} | {'F1':>6}")
    print("-" * 66)
    group_rows = []
    for group in ("seen", "unseen"):
        if group not in by_group_scores:
            continue
        row = {
            "group": group,
            **_metric_row(by_group_scores[group], by_group_labels[group], threshold),
        }
        group_rows.append(row)
        print(f"{group:<10} | {row['n_total']:>8} | {row['n_pos']:>6} | "
              f"{row['pos_rate'] * 100:>5.1f}% | "
              f"{row['precision']:>6.4f} | {row['recall']:>6.4f} | {row['f1']:>6.4f}")

    # Generalization gap (only meaningful when both groups are non-empty)
    f1_by_group = {row["group"]: row["f1"] for row in group_rows}
    if "seen" in f1_by_group and "unseen" in f1_by_group:
        f1_seen = f1_by_group["seen"]
        f1_unseen = f1_by_group["unseen"]
        gap = f1_seen - f1_unseen
        rel_gap = (gap / f1_seen * 100) if f1_seen > 0 else 0.0
        print()
        print(f"Generalization gap: F1_seen - F1_unseen = {gap:+.4f} ({rel_gap:+.1f}%)")
        if abs(gap) < small_gap:
            print(" ==> Small gap; model generalizes well to novel seeds.")
        elif gap > 0:
            print(" ==> Larger F1 on seen seeds; some memorization effect.")
        else:
            print(" ==> Larger F1 on unseen; unusual.")

    # Per (group, hop) breakdown over every hop actually observed (not just 1-3).
    print()
    print("=" * 88)
    print("Per (group, hop) breakdown")
    print("=" * 88)
    print(f"{'group':<10} | {'hop':>4} | {'n_total':>8} | {'n_pos':>6} | "
          f"{'prec':>6} | {'recall':>6} | {'F1':>6}")
    print("-" * 66)
    grouphop_rows = []
    for group in ("seen", "unseen"):
        # int() normalizes NumPy/torch scalar hops so the rows stay JSON-serializable;
        # numerically equal keys still match in the dict lookups below.
        hops = sorted({int(h) for g, h in by_group_hop_scores if g == group and h is not None})
        for hop in hops:
            key = (group, hop)
            hop_labels = by_group_hop_labels.get(key, [])
            if not hop_labels:
                continue
            row = _metric_row(by_group_hop_scores[key], hop_labels, threshold)
            grouphop_rows.append({
                "group": group, "hop": hop,
                "n_total": row["n_total"], "n_pos": row["n_pos"],
                "precision": row["precision"], "recall": row["recall"], "f1": row["f1"],
            })
            print(f"{group:<10} | {hop:>4} | {row['n_total']:>8} | {row['n_pos']:>6} | "
                  f"{row['precision']:>6.4f} | {row['recall']:>6.4f} | {row['f1']:>6.4f}")
    print("=" * 88)

    # Save JSON
    if args.output_json:
        out = {
            "checkpoint": str(args.checkpoint),
            "mode": args.mode,
            "threshold": threshold,
            "n_seen_seeds_in_train": len(seen_seeds),
            "n_queries_seen_group": n_seen,
            "n_queries_unseen_group": n_unseen,
            "n_instances_unmatched": n_unmatched,
            "by_group": group_rows,
            "by_group_hop": grouphop_rows,
        }
        out_path = Path(args.output_json)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with out_path.open("w", encoding="utf-8") as f:
            json.dump(out, f, indent=2)
        logger.info(f"Results written to {out_path}")


if __name__ == "__main__":
    main()