| """Generate score-distribution plot for the v1 AMR probe applied to 100 simulated |
| short reads. Loads from the cached probe_on_reads_test.json (no Modal needed). |
| |
| Output mirrors the AMR distribution plot: PNG + scores.jsonl + summary.json. |
| |
| Usage: |
| python probes/make_reads_score_plot.py |
| """ |
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import matplotlib |
|
|
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from sklearn.metrics import roc_auc_score, precision_recall_curve |
|
|
|
|
# Cached probe-on-reads results (input) and the score-distribution figure (output).
# Both are resolved relative to this script, which lives in probes/ per the Usage
# line above, instead of a hard-coded home directory. This keeps the script
# runnable from any checkout. The .scores.jsonl / .summary.json outputs are
# derived from OUT_PNG by swapping its suffix.
_RESULTS_DIR = Path(__file__).resolve().parent / "results"
SRC = _RESULTS_DIR / "probe_on_reads_test.json"
OUT_PNG = _RESULTS_DIR / "probe_on_reads_score_distributions.png"
|
|
|
|
def stats(scores, labels):
    """Summarise how well ``scores`` separate the two classes in ``labels``.

    Args:
        scores: per-item real-valued scores (higher = more positive).
        labels: binary ground-truth labels, parallel to ``scores``.

    Returns:
        ``(auc, best_threshold, best_f1)``: the ROC AUC, the score threshold on
        the precision-recall curve that maximises F1, and that maximal F1.
        All three are plain Python floats.
    """
    auc = roc_auc_score(labels, scores)
    precision, recall, thresholds = precision_recall_curve(labels, scores)

    # Floor the denominator so precision == recall == 0 yields F1 = 0 rather
    # than a division by zero.
    f1_curve = 2 * precision * recall / np.maximum(precision + recall, 1e-9)
    best = int(np.argmax(f1_curve))

    # precision/recall have one more entry than thresholds (the terminal
    # precision=1, recall=0 point has no threshold), so fall back to the last
    # threshold if the arg-max lands on that terminal point.
    if best < len(thresholds):
        chosen = thresholds[best]
    else:
        chosen = thresholds[-1]

    return float(auc), float(chosen), float(f1_curve[best])
|
|
|
|
def _plot_panel(ax, scores, labels, name, auc, best_thr, f1):
    """Draw one distribution panel: per-class histograms, rug marks, and boundaries.

    ``scores`` and ``labels`` are parallel arrays; reads with label 1 are drawn
    as "AMR positive" and label 0 as "matched-neg".
    """
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    # Pad the range by half a logit so the extreme reads are not on a bin edge.
    bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 30)
    ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg (n={len(neg)})", color="tab:red", density=True)
    ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive (n={len(pos)})", color="tab:green", density=True)

    # Rug rows sit just below the histogram floor, as explicit offsets (2% and
    # 3% of the y-span) from the lower limit. This keeps them below the bars and
    # in a fixed order whatever the sign of the lower limit. np.full gives a
    # float array even if the scores happen to be integer-typed.
    y_lo, y_hi = ax.get_ylim()
    span = y_hi - y_lo
    ax.scatter(neg, np.full(neg.shape, y_lo - 0.02 * span), marker="|", s=80, color="tab:red", alpha=0.6)
    ax.scatter(pos, np.full(pos.shape, y_lo - 0.03 * span), marker="|", s=80, color="tab:green", alpha=0.6)

    ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
    ax.axvline(best_thr, color="black", lw=1, ls=":", label=f"best-F1 boundary ({best_thr:.2f})")
    ax.set_xlabel("per-read probe logit (= w · h_pooled + b)", fontsize=10)
    ax.set_ylabel("density", fontsize=10)
    ax.set_title(f"{name} • AUC={auc:.4f} • best-F1={f1:.3f}", fontsize=11)
    ax.legend(fontsize=8, loc="upper left")
    ax.grid(True, alpha=0.2)


def _write_scores_jsonl(records, path):
    """Write one JSON line per read: id, label, length, and its pooled logits.

    ``median_logit`` is optional in the input records and is written as null
    when absent.
    """
    with open(path, "w", encoding="utf-8") as f:
        for r in records:
            f.write(json.dumps({
                "read_id": r["read_id"],
                "label": r["label"],
                "seq_len": r["seq_len"],
                "max_logit": r["max_logit"],
                "mean_logit": r["mean_logit"],
                "median_logit": r.get("median_logit"),
            }) + "\n")


def main():
    """Plot per-read probe-logit distributions and export scores and a summary.

    Reads ``SRC``, a JSON object whose ``"results"`` list holds one record per
    read. Prints the AUC and best-F1 stats for max- and mean-pooled logits. Then
    writes three files derived from ``OUT_PNG``: the two-panel PNG, a
    ``.scores.jsonl`` of per-read logits, and a ``.summary.json`` of aggregate
    metrics.
    """
    raw = json.loads(SRC.read_text(encoding="utf-8"))
    records = raw["results"]
    labels = np.array([r["label"] for r in records])
    max_logits = np.array([r["max_logit"] for r in records])
    mean_logits = np.array([r["mean_logit"] for r in records])
    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())

    auc_max, t_max, f1_max = stats(max_logits, labels)
    auc_mean, t_mean, f1_mean = stats(mean_logits, labels)
    print(f"max-pool: AUC={auc_max:.4f} best-F1 thr={t_max:.3f} F1={f1_max:.3f}")
    print(f"mean-pool: AUC={auc_mean:.4f} best-F1 thr={t_mean:.3f} F1={f1_mean:.3f}")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = [
        (axes[0], max_logits, "max-pool", auc_max, t_max, f1_max),
        (axes[1], mean_logits, "mean-pool", auc_mean, t_mean, f1_mean),
    ]
    for ax, scores, name, auc, best_thr, f1 in panels:
        _plot_panel(ax, scores, labels, name, auc, best_thr, f1)

    # Read counts come from the data rather than being hard-coded, so the title
    # stays correct if the cached results are regenerated with a different mix.
    fig.suptitle(
        "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n"
        f"{n_pos} pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref) • "
        f"{n_neg} neg reads from matched MGYG000307615_00395\n"
        "MAG was in v1 val split — never trained on. Pooling over all ~301 tokens of each read.",
        fontsize=10,
    )
    fig.tight_layout()
    OUT_PNG.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_PNG, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"saved {OUT_PNG.stat().st_size/1024:.1f} KB to {OUT_PNG}")

    out_jsonl = OUT_PNG.with_suffix(".scores.jsonl")
    _write_scores_jsonl(records, out_jsonl)
    print(f"saved {len(records)} per-read scores to {out_jsonl}")

    out_summary = OUT_PNG.with_suffix(".summary.json")
    out_summary.write_text(json.dumps({
        "n_pos": n_pos,
        "n_neg": n_neg,
        "max_pool": {"auc": auc_max, "best_f1_threshold": t_max, "best_f1": f1_max},
        "mean_pool": {"auc": auc_mean, "best_f1_threshold": t_mean, "best_f1": f1_mean},
    }, indent=2), encoding="utf-8")
    print(f"saved aggregate summary to {out_summary}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point only; importing this module does not regenerate outputs.
    main()
|
|