"""BELIEF span extraction diagram — compact, large fonts, no title.""" from pathlib import Path import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt from matplotlib.patches import FancyBboxPatch, Rectangle import numpy as np OUT = Path("PROJECT_ROOT/figs/modelarchi") OUT.mkdir(parents=True, exist_ok=True) TOKENS = [ ("...", "normal"), ("<|BELIEF|>", "belief_tag"), ("lead", "belief_content"), ("truck", "belief_content"), ("cut", "belief_content"), ("in", "belief_content"), ("from", "belief_content"), ("right", "belief_content"), ("lane", "belief_content"), (",", "belief_content"), ("TTC", "belief_content"), ("narrowing", "belief_content"), ("", "belief_tag"), ("<|OBSERVE|>", "action_tag"), ("...", "normal"), ] COLORS = { "normal": ("#d1d5db", "#9ca3af", "#444444"), "belief_tag": ("#f59e0b", "#b45309", "#78350f"), "belief_content": ("#fef3c7", "#d97706", "#78350f"), "action_tag": ("#fecaca", "#b91c1c", "#7f1d1d"), } C_DANGER = "#d8c7fa" C_DANGER_EC = "#7c3aed" C_DANGER_TC = "#5b21b6" C_POLICY = "#e4ffc2" C_POLICY_EC = "#65a30d" C_POLICY_TC = "#3f6212" def main(): fig, ax = plt.subplots(figsize=(14, 5.2)) ax.set_xlim(0, 14) ax.set_ylim(0, 5.2) ax.set_aspect("equal") ax.axis("off") # ── Token boxes ── tok_y = 2.5 tok_h = 0.55 x = 0.15 gap = 0.07 positions = [] for text, ttype in TOKENS: fc, ec, tc = COLORS[ttype] is_tag = ttype in ("belief_tag", "action_tag") w = max(0.52, len(text) * 0.11 + 0.22) if not is_tag else max(0.9, len(text) * 0.085 + 0.22) fs = 11 if is_tag else 13 ax.add_patch(Rectangle((x, tok_y), w, tok_h, fc=fc, ec=ec, lw=1.3, zorder=2)) ax.text(x + w/2, tok_y + tok_h/2, text, fontsize=fs, ha="center", va="center", color=tc, fontweight="bold" if is_tag else "normal", family="monospace" if is_tag else "sans-serif", zorder=3) positions.append((x, x + w, ttype)) x += w + gap # ── Hidden state bars ── hs_y = tok_y - 0.12 hs_h = 0.45 for xl, xr, ttype in positions: if ttype == "normal": c = "#d1d5db" elif ttype in ("belief_tag", "belief_content"): c = "#fbbf24" else: c = "#f87171" ax.add_patch(Rectangle((xl, hs_y - hs_h), xr - xl, hs_h, fc=c, ec="white", lw=0.4, alpha=0.3, zorder=1)) ax.text(0.0, hs_y - hs_h/2, "$h^{(\\ell)}$", fontsize=15, ha="center", va="center", color="#555", fontstyle="italic") # ── Bottom: span-pool bracket → DangerHead ── # Bracket starts just before <|BELIEF|> (index 1), covers content through index 11 bx1 = positions[1][0] - 0.03 bx2 = positions[11][1] by = hs_y - hs_h - 0.08 # Curly-brace style bracket ax.annotate("", xy=(bx1, by), xytext=(bx1, by - 0.18), arrowprops=dict(arrowstyle="-", color="#d97706", lw=2.0)) ax.plot([bx1, bx2], [by - 0.18, by - 0.18], color="#d97706", lw=2.2) ax.annotate("", xy=(bx2, by), xytext=(bx2, by - 0.18), arrowprops=dict(arrowstyle="-", color="#d97706", lw=2.0)) ax.text((bx1 + bx2) / 2, by - 0.45, "mean-pool → $z_t^{(f)} \\in \\mathbb{R}^{10240}$ (DangerHead)", fontsize=14, ha="center", color="#b45309", fontweight="bold") # ── Top: close-tag → PolicyHead ── ct_xl = positions[12][0] ct_xr = positions[12][1] ct_mid = (ct_xl + ct_xr) / 2 ty = tok_y + tok_h + 0.06 ax.plot([ct_xl, ct_xl, ct_xr, ct_xr], [ty, ty + 0.12, ty + 0.12, ty], color=C_POLICY_EC, lw=2.2, solid_capstyle="round") ax.text(ct_mid, ty + 0.35, "hidden state → $r_t^{(f)} \\in \\mathbb{R}^{2560}$ (PolicyHead)", fontsize=14, ha="center", color=C_POLICY_TC, fontweight="bold") # (legend removed) fig.savefig(OUT / "belief_span.png", dpi=300, bbox_inches="tight", facecolor="white") fig.savefig(OUT / "belief_span.pdf", bbox_inches="tight", facecolor="white") plt.close() print(f"Saved → {OUT}/belief_span.{{png,pdf}}") if __name__ == "__main__": main()