"""Generate score-distribution plot for the v1 AMR probe applied to 100 simulated short reads. Loads from the cached probe_on_reads_test.json (no Modal needed). Output mirrors the AMR distribution plot: PNG + scores.jsonl + summary.json. Usage: python probes/make_reads_score_plot.py """ from __future__ import annotations import json from pathlib import Path import numpy as np import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from sklearn.metrics import roc_auc_score, precision_recall_curve SRC = Path("/home/ror25cal/MGnify/probes/results/probe_on_reads_test.json") OUT_PNG = Path("/home/ror25cal/MGnify/probes/results/probe_on_reads_score_distributions.png") def stats(scores, labels): auc = roc_auc_score(labels, scores) prec, rec, thr = precision_recall_curve(labels, scores) f1 = 2 * prec * rec / np.maximum(prec + rec, 1e-9) idx = int(np.argmax(f1)) best_thr = float(thr[min(idx, len(thr) - 1)]) return float(auc), best_thr, float(f1[idx]) def main(): raw = json.loads(SRC.read_text()) records = raw["results"] labels = np.array([r["label"] for r in records]) max_logits = np.array([r["max_logit"] for r in records]) mean_logits = np.array([r["mean_logit"] for r in records]) auc_max, t_max, f1_max = stats(max_logits, labels) auc_mean, t_mean, f1_mean = stats(mean_logits, labels) print(f"max-pool: AUC={auc_max:.4f} best-F1 thr={t_max:.3f} F1={f1_max:.3f}") print(f"mean-pool: AUC={auc_mean:.4f} best-F1 thr={t_mean:.3f} F1={f1_mean:.3f}") fig, axes = plt.subplots(1, 2, figsize=(14, 5)) for ax, scores, name, auc, best_thr, f1 in [ (axes[0], max_logits, "max-pool", auc_max, t_max, f1_max), (axes[1], mean_logits, "mean-pool", auc_mean, t_mean, f1_mean), ]: pos = scores[labels == 1] neg = scores[labels == 0] bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 30) ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg (n={len(neg)})", color="tab:red", density=True) ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive (n={len(pos)})", color="tab:green", density=True) rug_y = ax.get_ylim()[0] - 0.02 * (ax.get_ylim()[1] - ax.get_ylim()[0]) ax.scatter(neg, np.full_like(neg, rug_y), marker="|", s=80, color="tab:red", alpha=0.6) ax.scatter(pos, np.full_like(pos, rug_y * 1.5), marker="|", s=80, color="tab:green", alpha=0.6) ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)") ax.axvline(best_thr, color="black", lw=1, ls=":", label=f"best-F1 boundary ({best_thr:.2f})") ax.set_xlabel("per-read probe logit (= w · h_pooled + b)", fontsize=10) ax.set_ylabel("density", fontsize=10) ax.set_title(f"{name} • AUC={auc:.4f} • best-F1={f1:.3f}", fontsize=11) ax.legend(fontsize=8, loc="upper left") ax.grid(True, alpha=0.2) fig.suptitle( "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n" "50 pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref) • " "50 neg reads from matched MGYG000307615_00395\n" "MAG was in v1 val split — never trained on. Pooling over all ~301 tokens of each read.", fontsize=10, ) fig.tight_layout() OUT_PNG.parent.mkdir(parents=True, exist_ok=True) fig.savefig(OUT_PNG, dpi=120, bbox_inches="tight") plt.close(fig) print(f"saved {OUT_PNG.stat().st_size/1024:.1f} KB to {OUT_PNG}") # Raw scores out_jsonl = OUT_PNG.with_suffix(".scores.jsonl") with open(out_jsonl, "w") as f: for r in records: f.write(json.dumps({ "read_id": r["read_id"], "label": r["label"], "seq_len": r["seq_len"], "max_logit": r["max_logit"], "mean_logit": r["mean_logit"], "median_logit": r.get("median_logit"), }) + "\n") print(f"saved {len(records)} per-read scores to {out_jsonl}") out_summary = OUT_PNG.with_suffix(".summary.json") out_summary.write_text(json.dumps({ "n_pos": int((labels == 1).sum()), "n_neg": int((labels == 0).sum()), "max_pool": {"auc": auc_max, "best_f1_threshold": t_max, "best_f1": f1_max}, "mean_pool": {"auc": auc_mean, "best_f1_threshold": t_mean, "best_f1": f1_mean}, }, indent=2)) print(f"saved aggregate summary to {out_summary}") if __name__ == "__main__": main()