File size: 4,332 Bytes
1e05592 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | """BELIEF span extraction diagram — compact, large fonts, no title."""
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Rectangle
import numpy as np
# Directory the rendered figures are written to. It is created at import time
# so that saving into it later does not fail.
# NOTE(review): "PROJECT_ROOT" looks like a placeholder; as written this is a
# path relative to the current working directory.
OUT = Path("PROJECT_ROOT") / "figs" / "modelarchi"
OUT.mkdir(exist_ok=True, parents=True)
# Example token sequence drawn left-to-right as (text, token_type) pairs.
# token_type selects the box styling in COLORS: an opening/closing
# belief_tag pair encloses the belief_content words, followed by an
# action_tag, with "..." placeholders on either side.
TOKENS = [
    ("...", "normal"),
    ("<|BELIEF|>", "belief_tag"),
    *(
        (word, "belief_content")
        for word in ("lead", "truck", "cut", "in", "from",
                     "right", "lane", ",", "TTC", "narrowing")
    ),
    ("</|BELIEF|>", "belief_tag"),
    ("<|OBSERVE|>", "action_tag"),
    ("...", "normal"),
]
# Styling per token type as (face colour, edge colour, text colour); the
# drawing code unpacks each entry as ``fc, ec, tc``.
COLORS = dict(
    normal=("#d1d5db", "#9ca3af", "#444444"),          # "..." placeholder tokens
    belief_tag=("#f59e0b", "#b45309", "#78350f"),      # <|BELIEF|> / </|BELIEF|>
    belief_content=("#fef3c7", "#d97706", "#78350f"),  # words inside the belief span
    action_tag=("#fecaca", "#b91c1c", "#7f1d1d"),      # e.g. <|OBSERVE|>
)
# Head-specific colour triples. By the _EC / _TC suffixes these are presumably
# (fill, edge, text), mirroring the fc/ec/tc order of COLORS — TODO confirm.
# Only C_POLICY_EC (bracket line) and C_POLICY_TC (label text) are used by the
# drawing code in this file; the DangerHead bracket uses hard-coded colours.
C_DANGER, C_DANGER_EC, C_DANGER_TC = "#d8c7fa", "#7c3aed", "#5b21b6"
C_POLICY, C_POLICY_EC, C_POLICY_TC = "#e4ffc2", "#65a30d", "#3f6212"
def main():
fig, ax = plt.subplots(figsize=(14, 5.2))
ax.set_xlim(0, 14)
ax.set_ylim(0, 5.2)
ax.set_aspect("equal")
ax.axis("off")
# ββ Token boxes ββ
tok_y = 2.5
tok_h = 0.55
x = 0.15
gap = 0.07
positions = []
for text, ttype in TOKENS:
fc, ec, tc = COLORS[ttype]
is_tag = ttype in ("belief_tag", "action_tag")
w = max(0.52, len(text) * 0.11 + 0.22) if not is_tag else max(0.9, len(text) * 0.085 + 0.22)
fs = 11 if is_tag else 13
ax.add_patch(Rectangle((x, tok_y), w, tok_h,
fc=fc, ec=ec, lw=1.3, zorder=2))
ax.text(x + w/2, tok_y + tok_h/2, text,
fontsize=fs, ha="center", va="center",
color=tc, fontweight="bold" if is_tag else "normal",
family="monospace" if is_tag else "sans-serif", zorder=3)
positions.append((x, x + w, ttype))
x += w + gap
# ββ Hidden state bars ββ
hs_y = tok_y - 0.12
hs_h = 0.45
for xl, xr, ttype in positions:
if ttype == "normal":
c = "#d1d5db"
elif ttype in ("belief_tag", "belief_content"):
c = "#fbbf24"
else:
c = "#f87171"
ax.add_patch(Rectangle((xl, hs_y - hs_h), xr - xl, hs_h,
fc=c, ec="white", lw=0.4, alpha=0.3, zorder=1))
ax.text(0.0, hs_y - hs_h/2, "$h^{(\\ell)}$",
fontsize=15, ha="center", va="center", color="#555", fontstyle="italic")
# ββ Bottom: span-pool bracket β DangerHead ββ
# Bracket starts just before <|BELIEF|> (index 1), covers content through index 11
bx1 = positions[1][0] - 0.03
bx2 = positions[11][1]
by = hs_y - hs_h - 0.08
# Curly-brace style bracket
ax.annotate("", xy=(bx1, by), xytext=(bx1, by - 0.18),
arrowprops=dict(arrowstyle="-", color="#d97706", lw=2.0))
ax.plot([bx1, bx2], [by - 0.18, by - 0.18], color="#d97706", lw=2.2)
ax.annotate("", xy=(bx2, by), xytext=(bx2, by - 0.18),
arrowprops=dict(arrowstyle="-", color="#d97706", lw=2.0))
ax.text((bx1 + bx2) / 2, by - 0.45,
"mean-pool β $z_t^{(f)} \\in \\mathbb{R}^{10240}$ (DangerHead)",
fontsize=14, ha="center", color="#b45309", fontweight="bold")
# ββ Top: close-tag β PolicyHead ββ
ct_xl = positions[12][0]
ct_xr = positions[12][1]
ct_mid = (ct_xl + ct_xr) / 2
ty = tok_y + tok_h + 0.06
ax.plot([ct_xl, ct_xl, ct_xr, ct_xr], [ty, ty + 0.12, ty + 0.12, ty],
color=C_POLICY_EC, lw=2.2, solid_capstyle="round")
ax.text(ct_mid, ty + 0.35,
"hidden state β $r_t^{(f)} \\in \\mathbb{R}^{2560}$ (PolicyHead)",
fontsize=14, ha="center", color=C_POLICY_TC, fontweight="bold")
# (legend removed)
fig.savefig(OUT / "belief_span.png", dpi=300, bbox_inches="tight", facecolor="white")
fig.savefig(OUT / "belief_span.pdf", bbox_inches="tight", facecolor="white")
plt.close()
print(f"Saved β {OUT}/belief_span.{{png,pdf}}")
if __name__ == "__main__":
main()
|