VLAlert / tools /render_belief_span.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
4.33 kB
"""BELIEF span extraction diagram β€” compact, large fonts, no title."""
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Rectangle
import numpy as np
# Directory the figure files are written to; it is created as soon as the
# module is imported.
OUT = Path("PROJECT_ROOT/figs/modelarchi")
OUT.mkdir(parents=True, exist_ok=True)

# Words shown between the opening and closing BELIEF tags.
_BELIEF_WORDS = (
    "lead",
    "truck",
    "cut",
    "in",
    "from",
    "right",
    "lane",
    ",",
    "TTC",
    "narrowing",
)

# Token sequence drawn left to right, as (text, token_type) pairs.
# Every token_type must be a key of COLORS.
TOKENS = (
    [("...", "normal"), ("<|BELIEF|>", "belief_tag")]
    + [(word, "belief_content") for word in _BELIEF_WORDS]
    + [
        ("</|BELIEF|>", "belief_tag"),
        ("<|OBSERVE|>", "action_tag"),
        ("...", "normal"),
    ]
)

# token_type -> (face colour, edge colour, text colour) of the token box.
COLORS = dict(
    normal=("#d1d5db", "#9ca3af", "#444444"),
    belief_tag=("#f59e0b", "#b45309", "#78350f"),
    belief_content=("#fef3c7", "#d97706", "#78350f"),
    action_tag=("#fecaca", "#b91c1c", "#7f1d1d"),
)

# Accent palettes for the two heads; the suffixes presumably mean
# EC = edge colour and TC = text colour (confirm against usage).
C_DANGER, C_DANGER_EC, C_DANGER_TC = "#d8c7fa", "#7c3aed", "#5b21b6"
C_POLICY, C_POLICY_EC, C_POLICY_TC = "#e4ffc2", "#65a30d", "#3f6212"
def _belief_tag_indices(tokens):
    """Return the indices of the first and last ``belief_tag`` tokens.

    Args:
        tokens: Sequence of ``(text, token_type)`` pairs.

    Returns:
        ``(open_idx, close_idx)`` positions of the opening and closing tag.

    Raises:
        ValueError: If *tokens* contains fewer than two ``belief_tag`` entries.
    """
    tag_idx = [i for i, (_, ttype) in enumerate(tokens) if ttype == "belief_tag"]
    if len(tag_idx) < 2:
        raise ValueError(
            "TOKENS must contain an opening and a closing 'belief_tag' entry"
        )
    return tag_idx[0], tag_idx[-1]


def main():
    """Render the BELIEF span diagram and save it as PNG and PDF under ``OUT``.

    The figure shows, from top to bottom:

    * a bracket over the closing BELIEF tag, labelled as the PolicyHead input;
    * one box per entry of ``TOKENS``, coloured via ``COLORS``;
    * a row of translucent bars under the boxes, labelled ``h^(l)``;
    * a bracket under the span from the opening tag through the last token
      before the closing tag, labelled as the mean-pooled DangerHead input.

    Raises:
        ValueError: If ``TOKENS`` lacks an opening/closing ``belief_tag`` pair.
    """
    fig, ax = plt.subplots(figsize=(14, 5.2))
    ax.set_xlim(0, 14)
    ax.set_ylim(0, 5.2)
    ax.set_aspect("equal")
    ax.axis("off")

    # -- Token boxes --
    tok_y = 2.5
    tok_h = 0.55
    x = 0.15
    gap = 0.07
    positions = []  # (left edge, right edge, token type) for every box
    for text, ttype in TOKENS:
        fc, ec, tc = COLORS[ttype]
        is_tag = ttype in ("belief_tag", "action_tag")
        # Tags use a smaller monospace font, so they get a narrower per-character
        # width but a larger minimum box width.
        if is_tag:
            w = max(0.9, len(text) * 0.085 + 0.22)
        else:
            w = max(0.52, len(text) * 0.11 + 0.22)
        fs = 11 if is_tag else 13
        ax.add_patch(Rectangle((x, tok_y), w, tok_h,
                               fc=fc, ec=ec, lw=1.3, zorder=2))
        ax.text(x + w / 2, tok_y + tok_h / 2, text,
                fontsize=fs, ha="center", va="center",
                color=tc, fontweight="bold" if is_tag else "normal",
                family="monospace" if is_tag else "sans-serif", zorder=3)
        positions.append((x, x + w, ttype))
        x += w + gap

    # -- Hidden state bars --
    hs_y = tok_y - 0.12
    hs_h = 0.45
    for xl, xr, ttype in positions:
        if ttype == "normal":
            c = "#d1d5db"
        elif ttype in ("belief_tag", "belief_content"):
            c = "#fbbf24"
        else:
            c = "#f87171"
        ax.add_patch(Rectangle((xl, hs_y - hs_h), xr - xl, hs_h,
                               fc=c, ec="white", lw=0.4, alpha=0.3, zorder=1))
    ax.text(0.0, hs_y - hs_h / 2, "$h^{(\\ell)}$",
            fontsize=15, ha="center", va="center", color="#555", fontstyle="italic")

    # Locate the BELIEF tags from the token types instead of hard-coding list
    # indices, so the brackets stay aligned when TOKENS is edited.
    open_idx, close_idx = _belief_tag_indices(TOKENS)

    # -- Bottom: span-pool bracket -> DangerHead --
    # The bracket starts just before the opening tag and runs through the last
    # token before the closing tag.
    bx1 = positions[open_idx][0] - 0.03
    bx2 = positions[close_idx - 1][1]
    by = hs_y - hs_h - 0.08
    # Square bracket: two vertical ticks joined by a horizontal line.
    ax.annotate("", xy=(bx1, by), xytext=(bx1, by - 0.18),
                arrowprops=dict(arrowstyle="-", color="#d97706", lw=2.0))
    ax.plot([bx1, bx2], [by - 0.18, by - 0.18], color="#d97706", lw=2.2)
    ax.annotate("", xy=(bx2, by), xytext=(bx2, by - 0.18),
                arrowprops=dict(arrowstyle="-", color="#d97706", lw=2.0))
    # "\u2192" is a right arrow; the escape keeps the source ASCII-only so the
    # label cannot be corrupted by a mis-decoded file.
    ax.text((bx1 + bx2) / 2, by - 0.45,
            "mean-pool \u2192 $z_t^{(f)} \\in \\mathbb{R}^{10240}$ (DangerHead)",
            fontsize=14, ha="center", color="#b45309", fontweight="bold")

    # -- Top: close-tag -> PolicyHead --
    ct_xl, ct_xr, _ = positions[close_idx]
    ct_mid = (ct_xl + ct_xr) / 2
    ty = tok_y + tok_h + 0.06
    ax.plot([ct_xl, ct_xl, ct_xr, ct_xr], [ty, ty + 0.12, ty + 0.12, ty],
            color=C_POLICY_EC, lw=2.2, solid_capstyle="round")
    ax.text(ct_mid, ty + 0.35,
            "hidden state \u2192 $r_t^{(f)} \\in \\mathbb{R}^{2560}$ (PolicyHead)",
            fontsize=14, ha="center", color=C_POLICY_TC, fontweight="bold")

    # (legend removed)
    fig.savefig(OUT / "belief_span.png", dpi=300, bbox_inches="tight", facecolor="white")
    fig.savefig(OUT / "belief_span.pdf", bbox_inches="tight", facecolor="white")
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
    print(f"Saved \u2192 {OUT}/belief_span.{{png,pdf}}")
# Render the figure only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()