| """BELIEF span extraction diagram β compact, large fonts, no title.""" |
| from pathlib import Path |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from matplotlib.patches import FancyBboxPatch, Rectangle |
| import numpy as np |
|
|
# Destination directory for belief_span.png / belief_span.pdf.
# NOTE: the directory is created as a side effect of importing this module,
# not when main() runs.
OUT: Path = Path("PROJECT_ROOT/figs/modelarchi")
OUT.mkdir(exist_ok=True, parents=True)
|
|
# Token sequence drawn left to right as (label, token type). The type keys into
# COLORS for the box styling and picks the hidden-state bar colour.
TOKENS = (
    [("...", "normal"), ("<|BELIEF|>", "belief_tag")]
    + [
        (word, "belief_content")
        for word in (
            "lead", "truck", "cut", "in", "from", "right", "lane", ",", "TTC", "narrowing",
        )
    ]
    + [("</|BELIEF|>", "belief_tag"), ("<|OBSERVE|>", "action_tag"), ("...", "normal")]
)
|
|
# Styling per token type: (box face colour, box edge colour, label text colour).
COLORS = dict(
    normal=("#d1d5db", "#9ca3af", "#444444"),
    belief_tag=("#f59e0b", "#b45309", "#78350f"),
    belief_content=("#fef3c7", "#d97706", "#78350f"),
    action_tag=("#fecaca", "#b91c1c", "#7f1d1d"),
)
|
|
# Accent palettes as (fill, edge, text) trios for the two heads.
# Only C_POLICY_EC and C_POLICY_TC are used by main(); the rest are currently unused.
C_DANGER, C_DANGER_EC, C_DANGER_TC = "#d8c7fa", "#7c3aed", "#5b21b6"
C_POLICY, C_POLICY_EC, C_POLICY_TC = "#e4ffc2", "#65a30d", "#3f6212"
|
|
|
|
def _token_box_width(text, is_tag):
    """Return the width (data units) of the box drawn around token *text*.

    Width grows linearly with the label length and has a floor. Tags (drawn
    bold monospace at fontsize 11) use different coefficients from ordinary
    tokens (fontsize 13).
    """
    if is_tag:
        return max(0.9, len(text) * 0.085 + 0.22)
    return max(0.52, len(text) * 0.11 + 0.22)


def _draw_tokens(ax, tokens, tok_y, tok_h, x_start=0.15, gap=0.07):
    """Draw one styled box plus label per token, left to right, on *ax*.

    Args:
        ax: target matplotlib Axes.
        tokens: sequence of ``(label, token_type)``; token_type must be a key
            of COLORS.
        tok_y, tok_h: bottom edge and height of the token row.
        x_start: x of the first box's left edge.
        gap: horizontal spacing between consecutive boxes.

    Returns:
        A list of ``(x_left, x_right, token_type)``, one entry per token.
    """
    positions = []
    x = x_start
    for text, ttype in tokens:
        fc, ec, tc = COLORS[ttype]
        is_tag = ttype in ("belief_tag", "action_tag")
        w = _token_box_width(text, is_tag)
        ax.add_patch(Rectangle((x, tok_y), w, tok_h, fc=fc, ec=ec, lw=1.3, zorder=2))
        ax.text(
            x + w / 2, tok_y + tok_h / 2, text,
            fontsize=11 if is_tag else 13,
            ha="center", va="center", color=tc,
            fontweight="bold" if is_tag else "normal",
            family="monospace" if is_tag else "sans-serif",
            zorder=3,
        )
        positions.append((x, x + w, ttype))
        x += w + gap
    return positions


def _draw_hidden_states(ax, positions, hs_y, hs_h):
    """Shade a translucent bar below each token box, spanning the same x range.

    Each bar's top edge is at *hs_y* and it extends *hs_h* downward.
    """
    # Belief tags and belief content share one amber; any other type (e.g. the
    # action tag) falls back to red.
    bar_colors = {"normal": "#d1d5db", "belief_tag": "#fbbf24", "belief_content": "#fbbf24"}
    for xl, xr, ttype in positions:
        ax.add_patch(Rectangle((xl, hs_y - hs_h), xr - xl, hs_h,
                               fc=bar_colors.get(ttype, "#f87171"),
                               ec="white", lw=0.4, alpha=0.3, zorder=1))


def _draw_bracket(ax, x_left, x_right, y, depth, color, lw=2.2):
    """Draw a connected square bracket as one polyline.

    The two end ticks start at height *y* and the bar runs at ``y + depth``:
    use a negative *depth* for a bracket hanging below an element and a
    positive one for a bracket above it. A single ``plot`` call keeps the
    corners joined (no arrow-shrink gaps).
    """
    ax.plot([x_left, x_left, x_right, x_right], [y, y + depth, y + depth, y],
            color=color, lw=lw, solid_capstyle="round")


def _belief_span_indices(tokens):
    """Locate the BELIEF span inside *tokens*.

    Returns:
        ``(open_idx, last_content_idx, close_idx)``: index of the first
        ``belief_tag`` token, of the last ``belief_content`` token and of the
        last ``belief_tag`` token.

    Raises:
        ValueError: if *tokens* has no ``belief_tag`` or no ``belief_content``.
    """
    types = [ttype for _, ttype in tokens]
    tag_idx = [i for i, t in enumerate(types) if t == "belief_tag"]
    content_idx = [i for i, t in enumerate(types) if t == "belief_content"]
    if not tag_idx or not content_idx:
        raise ValueError("tokens must contain belief_tag and belief_content entries")
    return tag_idx[0], content_idx[-1], tag_idx[-1]


def main():
    """Render the BELIEF-span figure and save it to OUT as PNG (300 dpi) and PDF.

    The figure shows the TOKENS row, a translucent hidden-state bar under every
    token, a bracket under the belief span labelled as the mean-pooled
    DangerHead input, and a bracket over the closing belief tag labelled as
    the PolicyHead hidden state.
    """
    fig, ax = plt.subplots(figsize=(14, 5.2))
    ax.set_xlim(0, 14)
    ax.set_ylim(0, 5.2)
    ax.set_aspect("equal")
    ax.axis("off")

    # Token row.
    tok_y, tok_h = 2.5, 0.55
    positions = _draw_tokens(ax, TOKENS, tok_y, tok_h)

    # Hidden-state bars just below the token row.
    hs_y, hs_h = tok_y - 0.12, 0.45
    _draw_hidden_states(ax, positions, hs_y, hs_h)
    # Right-align the row label left of the first bar; centring it on x=0
    # made it spill onto the bar that starts at x=0.15.
    ax.text(positions[0][0] - 0.1, hs_y - hs_h / 2, "$h^{(\\ell)}$",
            fontsize=15, ha="right", va="center", color="#555", fontstyle="italic")

    # Derive span boundaries from token types instead of hard-coded indices,
    # so editing TOKENS cannot silently misplace the brackets.
    open_idx, last_content_idx, close_idx = _belief_span_indices(TOKENS)

    # DangerHead: bracket under the span, from the opening tag to the last
    # content token (the closing tag is not included).
    bx1 = positions[open_idx][0] - 0.03
    bx2 = positions[last_content_idx][1]
    by = hs_y - hs_h - 0.08
    _draw_bracket(ax, bx1, bx2, by, -0.18, color="#d97706")
    ax.text((bx1 + bx2) / 2, by - 0.45,
            "mean-pool $\\rightarrow$ $z_t^{(f)} \\in \\mathbb{R}^{10240}$ (DangerHead)",
            fontsize=14, ha="center", color="#b45309", fontweight="bold")

    # PolicyHead: bracket above the closing belief tag.
    ct_xl, ct_xr, _ = positions[close_idx]
    ty = tok_y + tok_h + 0.06
    _draw_bracket(ax, ct_xl, ct_xr, ty, 0.12, color=C_POLICY_EC)
    ax.text((ct_xl + ct_xr) / 2, ty + 0.35,
            "hidden state $\\rightarrow$ $r_t^{(f)} \\in \\mathbb{R}^{2560}$ (PolicyHead)",
            fontsize=14, ha="center", color=C_POLICY_TC, fontweight="bold")

    fig.savefig(OUT / "belief_span.png", dpi=300, bbox_inches="tight", facecolor="white")
    fig.savefig(OUT / "belief_span.pdf", bbox_inches="tight", facecolor="white")
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(fig)
    # ASCII arrow: the original non-ASCII character was garbled and can fail
    # on consoles with a narrow encoding.
    print(f"Saved -> {OUT}/belief_span.{{png,pdf}}")


if __name__ == "__main__":
    main()
|
|