mgnify-evo2-probes / code /probes /make_reads_score_plot.py
JG1310's picture
Probe artifacts: code, manifests, plots, scores, summaries, checkpoints, docs
eb69de4 verified
"""Generate score-distribution plots for the v1 AMR probe applied to 100 simulated
short reads. Reads the cached per-read scores from probe_on_reads_test.json, so no
Modal run is needed. Outputs mirror the AMR distribution plot: a PNG plus
.scores.jsonl and .summary.json side files.

Usage:
    python probes/make_reads_score_plot.py
"""
from __future__ import annotations
import json
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, precision_recall_curve
# Input: cached per-read probe scores. Output: the figure path; the
# .scores.jsonl and .summary.json side files are derived from OUT_PNG.
_RESULTS_DIR = Path("/home/ror25cal/MGnify/probes/results")
SRC = _RESULTS_DIR / "probe_on_reads_test.json"
OUT_PNG = _RESULTS_DIR / "probe_on_reads_score_distributions.png"
def stats(scores, labels):
    """Summarise how well ``scores`` separate the binary ``labels``.

    Returns a tuple ``(auc, best_threshold, best_f1)``: the ROC AUC, plus the
    threshold and F1 value at the precision-recall point with the highest F1.
    """
    auc = float(roc_auc_score(labels, scores))
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # Floor the denominator so points with P = R = 0 get F1 = 0 instead of NaN.
    denom = np.maximum(precision + recall, 1e-9)
    f1_curve = 2 * precision * recall / denom
    best = int(np.argmax(f1_curve))
    # precision/recall carry one more entry than thresholds (the trailing
    # P=1, R=0 point has no threshold), so clamp the index into range.
    last = len(thresholds) - 1
    best_threshold = float(thresholds[best if best < last else last])
    return auc, best_threshold, float(f1_curve[best])
def _plot_panel(ax, scores, labels, name, auc, best_thr, f1):
    """Draw one panel: per-class score histograms, a rug of individual reads,
    and the default (logit=0) and best-F1 decision boundaries."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    # Pad the bin range by 0.5 logit on each side so edge reads are not clipped.
    bins = np.linspace(scores.min() - 0.5, scores.max() + 0.5, 30)
    ax.hist(neg, bins=bins, alpha=0.55, label=f"matched-neg (n={len(neg)})", color="tab:red", density=True)
    ax.hist(pos, bins=bins, alpha=0.55, label=f"AMR positive (n={len(pos)})", color="tab:green", density=True)

    # Rug plot below the histograms. Offsets are measured downward from the
    # current y-axis bottom as fractions of its span, so the negative rug always
    # sits just above the positive rug whatever the sign of the axis limits.
    y_lo, y_hi = ax.get_ylim()
    span = y_hi - y_lo
    for values, offset, color in ((neg, 0.02, "tab:red"), (pos, 0.03, "tab:green")):
        # Explicit float array: np.full_like(values, ...) would inherit an
        # integer dtype if the logits were loaded as ints, truncating y.
        rug_y = np.full(len(values), y_lo - offset * span, dtype=float)
        ax.scatter(values, rug_y, marker="|", s=80, color=color, alpha=0.6)

    ax.axvline(0, color="grey", lw=1, ls="--", label="default boundary (logit=0)")
    ax.axvline(best_thr, color="black", lw=1, ls=":", label=f"best-F1 boundary ({best_thr:.2f})")
    ax.set_xlabel("per-read probe logit (= w · h_pooled + b)", fontsize=10)
    ax.set_ylabel("density", fontsize=10)
    ax.set_title(f"{name} • AUC={auc:.4f} • best-F1={f1:.3f}", fontsize=11)
    ax.legend(fontsize=8, loc="upper left")
    ax.grid(True, alpha=0.2)


def _write_scores(records, path):
    """Write one JSON object per read (id, label, length, pooled logits) to *path*."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps({
                "read_id": r["read_id"],
                "label": r["label"],
                "seq_len": r["seq_len"],
                "max_logit": r["max_logit"],
                "mean_logit": r["mean_logit"],
                # Optional in the cached results; written as null when absent.
                "median_logit": r.get("median_logit"),
            }) + "\n"
            for r in records
        )


def main():
    """Plot max-/mean-pooled probe score distributions for the cached reads.

    Reads ``SRC`` (a JSON object whose ``"results"`` list holds per-read
    records) and writes, next to ``OUT_PNG``:
      * the two-panel PNG figure,
      * ``<stem>.scores.jsonl`` with the per-read scores,
      * ``<stem>.summary.json`` with class counts and AUC / best-F1 metrics.

    Raises:
        ValueError: if the cached results do not contain both positive
            (label 1) and negative (label 0) reads, since AUC and F1 are
            undefined in that case.
    """
    raw = json.loads(SRC.read_text(encoding="utf-8"))
    records = raw["results"]
    labels = np.array([r["label"] for r in records])
    max_logits = np.array([r["max_logit"] for r in records])
    mean_logits = np.array([r["mean_logit"] for r in records])

    n_pos = int((labels == 1).sum())
    n_neg = int((labels == 0).sum())
    if n_pos == 0 or n_neg == 0:
        raise ValueError(
            f"{SRC} must contain both positive and negative reads "
            f"(got n_pos={n_pos}, n_neg={n_neg})"
        )

    auc_max, t_max, f1_max = stats(max_logits, labels)
    auc_mean, t_mean, f1_mean = stats(mean_logits, labels)
    print(f"max-pool: AUC={auc_max:.4f} best-F1 thr={t_max:.3f} F1={f1_max:.3f}")
    print(f"mean-pool: AUC={auc_mean:.4f} best-F1 thr={t_mean:.3f} F1={f1_mean:.3f}")

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    panels = (
        ("max-pool", max_logits, auc_max, t_max, f1_max),
        ("mean-pool", mean_logits, auc_mean, t_mean, f1_mean),
    )
    for ax, (name, scores, auc, best_thr, f1) in zip(axes, panels):
        _plot_panel(ax, scores, labels, name, auc, best_thr, f1)
    # Read counts come from the data so the caption cannot drift from it.
    fig.suptitle(
        "v1 AMR linear probe applied to simulated MiSeq short reads (301 bp)\n"
        f"{n_pos} pos reads from MGYG000307615_01006 (abc-f MACROLIDE, 57% ID to ref) • "
        f"{n_neg} neg reads from matched MGYG000307615_00395\n"
        "MAG was in v1 val split — never trained on. Pooling over all ~301 tokens of each read.",
        fontsize=10,
    )
    fig.tight_layout()
    OUT_PNG.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_PNG, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"saved {OUT_PNG.stat().st_size/1024:.1f} KB to {OUT_PNG}")

    # Raw per-read scores.
    out_jsonl = OUT_PNG.with_suffix(".scores.jsonl")
    _write_scores(records, out_jsonl)
    print(f"saved {len(records)} per-read scores to {out_jsonl}")

    # Aggregate metrics.
    out_summary = OUT_PNG.with_suffix(".summary.json")
    out_summary.write_text(json.dumps({
        "n_pos": n_pos,
        "n_neg": n_neg,
        "max_pool": {"auc": auc_max, "best_f1_threshold": t_max, "best_f1": f1_max},
        "mean_pool": {"auc": auc_mean, "best_f1_threshold": t_mean, "best_f1": f1_mean},
    }, indent=2), encoding="utf-8")
    print(f"saved aggregate summary to {out_summary}")
if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()