VLAlert / tools /render_modelarchi_v4.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
8.31 kB
"""VLAlert Architecture v4 β€” clean academic flowchart.
Horizontal pipeline, minimal text, publication-ready.
Bottom: hidden state extraction diagram showing BELIEF span β†’ z_t, close-tag β†’ r_t.
Output: figs/modelarchi/modelarchi_v4.{png,pdf}
"""
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch, Rectangle
import numpy as np
# Output location. NOTE(review): "PROJECT_ROOT" is used as a literal relative
# path, so figures are written under ./PROJECT_ROOT/figs/modelarchi relative to
# the current working directory — presumably a placeholder; confirm.
ROOT = Path("PROJECT_ROOT")
OUT = ROOT.joinpath("figs", "modelarchi")
OUT.mkdir(exist_ok=True, parents=True)  # created at import time

# Fill / stroke colours, one per diagram component.
C_INPUT = "#e2e8f0"   # Video Sampler box
C_VLM = "#fde68a"     # VLM Extractor box
C_BLIEF = "#fed7aa"   # Belief / Register boxes (name kept as spelled)
C_DHEAD = "#bbf7d0"   # DangerHead boxes
C_PHEAD = "#dbeafe"   # PolicyHead boxes
C_FSM = "#e9d5ff"     # FSM Decoder box
C_ACT = "#fecaca"     # Action box
C_FB = "#dc2626"      # feedback loop lines, arrow and caption
C_BSPAN = "#fef3c7"   # content tokens inside the BELIEF span
def box(ax, x, y, w, h, lines, *, fc, ec="#334155", fs=10, lw=1.4, pad=0.08):
    """Draw a rounded, filled box whose outline spans (x, y)-(x + w, y + h).

    ax      -- matplotlib Axes to draw on; all geometry is in data units.
    x, y    -- lower-left corner of the drawn outline.
    w, h    -- width and height of the drawn outline.
    lines   -- a string, or a sequence of strings stacked and centred in the
               box; the first is drawn bold at ``fs`` pt, the rest in normal
               weight at ``fs - 1`` pt.
    fc, ec  -- face and edge colours; ``lw`` is the edge line width.
    pad     -- corner padding passed to the "round" box style.

    The patch is drawn at zorder 2 and its text at zorder 3.
    """
    # The "round,pad=p" box style grows the rectangle it is given by p on every
    # side. Inset the rectangle by the same amount so the visible outline lands
    # exactly on (x, y, w, h); otherwise boxes placed edge-to-edge overlap and
    # arrows drawn below the boxes that end on a nominal edge have their tips
    # buried under the padding. Clamp so tiny boxes never get a negative size.
    inset = min(pad, w / 2, h / 2)
    ax.add_patch(FancyBboxPatch(
        (x + inset, y + inset), w - 2 * inset, h - 2 * inset,
        boxstyle=f"round,pad={inset},rounding_size=0.12",
        lw=lw, ec=ec, fc=fc, zorder=2))
    if isinstance(lines, str):
        lines = [lines]
    n = len(lines)
    for i, line in enumerate(lines):
        # Line pitch scales with the font size; the stack of n lines is
        # centred on the box's vertical midpoint.
        yi = y + h / 2 + (n / 2 - i - 0.5) * fs * 0.015
        first = i == 0
        ax.text(x + w / 2, yi, line, ha="center", va="center",
                fontsize=fs if first else fs - 1,
                fontweight="bold" if first else "normal",
                color="#1e293b", zorder=3)
def arr(ax, x1, y1, x2, y2, *, color="#334155", lw=1.6, label="", lfs=7,
        label_above=True):
    """Draw a straight arrow from (x1, y1) to (x2, y2), optionally captioned.

    The arrow is drawn at zorder 1. When ``label`` is non-empty it is written
    in italics (font size ``lfs``, same colour as the arrow) centred on the
    segment's midpoint, shifted 0.18 data units up when ``label_above`` is
    true and 0.18 down otherwise.
    """
    arrow = FancyArrowPatch(
        (x1, y1), (x2, y2),
        arrowstyle="->,head_length=8,head_width=5",
        color=color, lw=lw, zorder=1,
    )
    ax.add_patch(arrow)
    if not label:
        return
    shift = 0.18 if label_above else -0.18
    ax.text((x1 + x2) / 2, (y1 + y2) / 2 + shift, label,
            fontsize=lfs, ha="center", color=color, fontstyle="italic")
def _draw_pipeline(ax, Y, H, G):
    """Draw the top-row pipeline boxes and the arrows between them.

    Y is the row's vertical centre, H the height of full-height boxes and G
    the horizontal gap between consecutive boxes (data units).

    Returns the x centres of the Video Sampler box and of the Action box,
    which the feedback loop connects.
    """
    # 1. Input: Video Sampler, with a row of frame thumbnails drawn above it
    #    (five icons; the caption reads "8 frames").
    bx1 = 0.3
    box(ax, bx1, Y - H/2, 1.5, H, ["Video Sampler", "$X_t$"],
        fc=C_INPUT, fs=10)
    for i in range(5):
        ax.add_patch(Rectangle((0.45 + i*0.2, Y + H/2 + 0.08), 0.16, 0.12,
                               fc="#94a3b8", ec="#64748b", lw=0.5, zorder=2))
    ax.text(0.95, Y + H/2 + 0.3, "8 frames", fontsize=7, ha="center",
            color="#64748b")

    # 2. VLM extractor.
    bx2 = bx1 + 1.5 + G
    box(ax, bx2, Y - H/2, 2.2, H, ["VLM Extractor", "Qwen3-VL-4B + LoRA"],
        fc=C_VLM, fs=10)
    arr(ax, bx1 + 1.5, Y, bx2, Y)

    # 3. Belief (upper half) and Register (lower half) stacked in one column,
    #    each fed by its own arrow from the VLM.
    bx3 = bx2 + 2.2 + G
    box(ax, bx3, Y + 0.08, 2.0, H/2 - 0.05,
        ["Belief $z_t \\in \\mathbb{R}^{8{\\times}10240}$"],
        fc=C_BLIEF, ec="#c2410c", fs=9)
    box(ax, bx3, Y - H/2, 2.0, H/2 - 0.05,
        ["Register $r_t \\in \\mathbb{R}^{8{\\times}2560}$"],
        fc=C_BLIEF, ec="#c2410c", fs=9)
    arr(ax, bx2 + 2.2, Y + 0.3, bx3, Y + 0.3, label="L{20..32}", lfs=6)
    arr(ax, bx2 + 2.2, Y - 0.2, bx3, Y - 0.2, label="L33", lfs=6)

    # 4. DangerHead, fed by z_t from the Belief box.
    bx4 = bx3 + 2.0 + G
    box(ax, bx4, Y - H/2, 1.6, H, ["DangerHead", "$d_t, \\, \\mathcal{S}_t$"],
        fc=C_DHEAD, ec="#15803d", fs=10)
    arr(ax, bx3 + 2.0, Y + 0.3, bx4, Y + 0.1, label="$z_t$", lfs=8)

    # 5. PolicyHead, fed by (S_t, d_t) from DangerHead and by r_t straight
    #    from the Register box.
    bx5 = bx4 + 1.6 + G
    box(ax, bx5, Y - H/2, 1.6, H, ["PolicyHead", "$\\pi_t$"],
        fc=C_PHEAD, ec="#1d4ed8", fs=10)
    arr(ax, bx4 + 1.6, Y + 0.1, bx5, Y + 0.1, label="$\\mathcal{S}_t, d_t$",
        lfs=7)
    # NOTE(review): this r_t arrow spans the DangerHead box (which is drawn
    # above arrows) and its label lands near DangerHead's centre — check the
    # rendered figure for overlap with DangerHead's text.
    arr(ax, bx3 + 2.0, Y - 0.2, bx5, Y - 0.2, label="$r_t$", lfs=8,
        color="#6366f1")

    # 6. FSM decoder.
    bx6 = bx5 + 1.6 + G
    box(ax, bx6, Y - H/2, 1.2, H, ["FSM", "Decoder"],
        fc=C_FSM, ec="#7c3aed", fs=10)
    arr(ax, bx5 + 1.6, Y, bx6, Y)

    # 7. Action output.
    bx7 = bx6 + 1.2 + G
    box(ax, bx7, Y - H/2, 1.5, H, ["Action $a_t$", "{Sil, Obs, Alrt}"],
        fc=C_ACT, ec="#b91c1c", fs=10)
    arr(ax, bx6 + 1.2, Y, bx7, Y)

    # Both end boxes are 1.5 wide; the feedback loop attaches at their centres.
    return bx1 + 0.75, bx7 + 0.75


def _draw_feedback(ax, sampler_x, action_x, y_bottom, drop=0.6):
    """Draw the feedback loop from the Action box back up into the Sampler.

    sampler_x / action_x are the x positions where the loop meets the two
    boxes, y_bottom is their bottom edge, and the horizontal run sits
    ``drop`` data units below it.
    """
    fb_y = y_bottom - drop
    # Down from the Action box ...
    ax.plot([action_x, action_x], [y_bottom, fb_y], color=C_FB, lw=2.0,
            zorder=1)
    # ... across beneath the pipeline ...
    ax.plot([sampler_x, action_x], [fb_y, fb_y], color=C_FB, lw=2.0, zorder=1)
    # ... and up into the Video Sampler with an arrowhead.
    ax.annotate("", xy=(sampler_x, y_bottom), xytext=(sampler_x, fb_y),
                arrowprops=dict(arrowstyle="-|>", color=C_FB, lw=2.0))
    # Centre the caption on the horizontal run.
    ax.text((sampler_x + action_x) / 2, fb_y - 0.22,
            "$a_{t-1}$ feedback (re-targets sampling window)",
            fontsize=9, ha="center", color=C_FB, fontweight="bold")


def _draw_extraction(ax):
    """Draw the bottom panel showing where z_t and r_t are read from.

    A strip of VLM output tokens is drawn with the BELIEF span highlighted.
    A bracket under the span's content tokens leads to DangerHead (pooled
    z_t), and a tick above the close tag leads to PolicyHead (r_t).
    """
    ax.text(8.0, 3.25, "Hidden State Extraction from BELIEF Span",
            fontsize=12, fontweight="bold", ha="center", color="#334155")

    # Token strip: (text, face colour, edge colour, width in data units).
    tok_y = 2.3
    tok_h = 0.4
    tokens = [
        ("...", "#e5e7eb", "#9ca3af", 0.4),
        ("<|BELIEF|>", "#f59e0b", "#d97706", 1.0),
        ("lead", C_BSPAN, "#f59e0b", 0.5),
        ("truck", C_BSPAN, "#f59e0b", 0.55),
        ("cut-in,", C_BSPAN, "#f59e0b", 0.6),
        ("TTC↓", C_BSPAN, "#f59e0b", 0.5),
        ("</|BELIEF|>", "#f59e0b", "#d97706", 1.1),
        ("<|OBS|>", "#fecaca", "#dc2626", 0.7),
        ("...", "#e5e7eb", "#9ca3af", 0.4),
    ]
    x = 2.5
    spans = []  # (left, right) x-extent of each token, same order as tokens
    for text, fc, ec, w in tokens:
        ax.add_patch(Rectangle((x, tok_y), w, tok_h, fc=fc, ec=ec, lw=1.0,
                               zorder=2))
        # Opening ("<|...") and closing ("</|...") special tags share the
        # smaller bold tag styling.
        is_tag = text.startswith(("<|", "</|"))
        ax.text(x + w/2, tok_y + tok_h/2, text, fontsize=7 if is_tag else 8,
                ha="center", va="center", color="#78350f",
                fontweight="bold" if is_tag else "normal", zorder=3)
        spans.append((x, x + w))
        x += w + 0.06

    # Bracket under the BELIEF content tokens (indices 2-5, strictly between
    # the open tag at index 1 and the close tag at index 6).
    sp_x1 = spans[2][0]
    sp_x2 = spans[5][1]
    sp_mid = (sp_x1 + sp_x2) / 2
    by = tok_y - 0.05
    ax.plot([sp_x1, sp_x1, sp_x2, sp_x2], [by, by - 0.12, by - 0.12, by],
            color="#d97706", lw=1.5)
    ax.text(sp_mid, by - 0.28,
            "mean-pool → $z_t^{(f)} \\in \\mathbb{R}^{10240}$",
            fontsize=9, ha="center", color="#d97706", fontweight="bold")
    ax.text(sp_mid, by - 0.52,
            "layers {20, 24, 28, 32} concat",
            fontsize=7, ha="center", color="#92400e")
    # Arrow down to a DangerHead label box.
    arr(ax, sp_mid, by - 0.65, sp_mid, by - 1.0, color="#d97706", lw=1.2)
    box(ax, sp_mid - 0.8, by - 1.45, 1.6, 0.4,
        ["→ DangerHead"], fc=C_DHEAD, ec="#15803d", fs=9)

    # Tick above the close tag (index 6).
    ct_x = (spans[6][0] + spans[6][1]) / 2
    ct_by = tok_y + tok_h + 0.05
    ax.plot([ct_x, ct_x], [ct_by, ct_by + 0.15], color="#2563eb", lw=1.5)
    ax.text(ct_x, ct_by + 0.3,
            "hidden at close-tag → $r_t^{(f)} \\in \\mathbb{R}^{2560}$",
            fontsize=9, ha="center", color="#2563eb", fontweight="bold")
    ax.text(ct_x, ct_by + 0.55, "layer 33", fontsize=7, ha="center",
            color="#3b82f6")
    # Arrow up to a PolicyHead label box.
    arr(ax, ct_x, ct_by + 0.7, ct_x, ct_by + 1.0, color="#2563eb", lw=1.2)
    box(ax, ct_x - 0.8, ct_by + 1.0, 1.6, 0.4,
        ["→ PolicyHead"], fc=C_PHEAD, ec="#1d4ed8", fs=9)

    # Row label to the left of the token strip.
    ax.text(2.0, tok_y + tok_h/2, "VLM\noutput\ntokens",
            fontsize=7, ha="center", va="center", color="#666")


def main():
    """Render the VLAlert architecture figure and save it to OUT.

    Writes modelarchi_v4.png (250 dpi) and modelarchi_v4.pdf, then closes the
    figure and prints the output location.
    """
    fig, ax = plt.subplots(figsize=(16, 7.5))
    ax.set_xlim(0, 16)
    ax.set_ylim(0, 7.5)
    ax.set_aspect("equal")
    ax.axis("off")

    # Top row: main pipeline centred at y ≈ 5.5, then the feedback loop.
    Y, H, G = 5.5, 1.0, 0.3
    sampler_x, action_x = _draw_pipeline(ax, Y, H, G)
    _draw_feedback(ax, sampler_x, action_x, Y - H/2)

    # Bottom: hidden-state extraction diagram.
    _draw_extraction(ax)

    fig.savefig(OUT / "modelarchi_v4.png", dpi=250, bbox_inches="tight",
                facecolor="white")
    fig.savefig(OUT / "modelarchi_v4.pdf", bbox_inches="tight",
                facecolor="white")
    plt.close(fig)
    print(f"Saved → {OUT}/modelarchi_v4.{{png,pdf}}")


if __name__ == "__main__":
    main()