VLAlert / tools /make_cache_x_v2_fast.py
AsianPlayer's picture
Add VLAlert code
1e05592 verified
Raw
History Blame Contribute Delete
6.46 kB
"""Fast version of make_cache_x_v2.py — reuses decoded video frames across ticks.
Bottleneck of the original: for rolling per-tick manifests (e.g. 17 ticks per
video on CARLA), every tick calls `sample_frames_from_mp4_by_indices`, which
opens the video fresh and reads from frame 0 sequentially until it reaches the
wanted indices. For 17 ticks this decodes the same video ~17 times.
Fix: monkey-patch `sample_frames_from_mp4_by_indices` to keep an LRU-1 cache of
the most recently decoded video's full frame list (already resized). Sort the
manifest by video_path so consecutive ticks of the same video hit the cache.
Expected speed-up on CARLA rolling (17 ticks/clip avg): 5-10x for the decode
portion, bringing aggregate throughput close to the GPU-forward-bound limit.
"""
from __future__ import annotations
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))
# Order matters: apply PR fast_patch BEFORE importing model code.
from tools import run_train_cot_belief_fast # noqa: F401, E402
import argparse # noqa: E402
import json # noqa: E402
import logging # noqa: E402
import time # noqa: E402
from typing import Dict, List # noqa: E402
import cv2 # noqa: E402
import numpy as np # noqa: E402
import torch # noqa: E402
from PIL import Image # noqa: E402
from tqdm import tqdm # noqa: E402
# ── monkey-patch sample_frames with video-level cache ────────────────────
from training.VLA import frame_utils as _fu # noqa: E402
# Module-level state for the single-slot frame cache used by
# `_patched_sample_by_indices`; the sampler clears the dict before every
# insert, so it holds at most one decoded video at a time.
_cache_path: str = str()
_video_cache: Dict[str, List[Image.Image]] = dict()
def _resize_bgr(frame: np.ndarray, resize_short: int) -> Image.Image:
    """Rescale an OpenCV BGR frame and return it as an RGB PIL image.

    Both sides are multiplied by ``resize_short / min(h, w)`` (rounded to the
    nearest pixel), so the shorter side ends up at ``resize_short`` and the
    aspect ratio is kept. Uses ``INTER_AREA`` interpolation.
    """
    height, width = frame.shape[:2]
    ratio = resize_short / min(height, width)
    # cv2.resize takes dsize as (width, height).
    target_size = (int(round(width * ratio)), int(round(height * ratio)))
    resized = cv2.resize(frame, target_size, interpolation=cv2.INTER_AREA)
    return Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
def _decode_full_video(video_path: str, resize_short: int) -> List[Image.Image]:
    """Decode every frame of a video once and return them resized, as PIL RGB.

    Args:
        video_path: path to the video file.
        resize_short: target length of the shorter frame side (see
            ``_resize_bgr``).

    Returns:
        All frames in decode order. Decoding stops at the first failed
        ``cap.read()``. OpenCV reports end-of-stream and a mid-stream decode
        error the same way, so a corrupt video yields a truncated list
        rather than an error.

    Raises:
        RuntimeError: if OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"could not open: {video_path}")
        frames: List[Image.Image] = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(_resize_bgr(frame, resize_short))
        return frames
    finally:
        # Release the decoder on every path, including when reading or
        # resizing raises part-way through the video.
        cap.release()
def _patched_sample_by_indices(video_path, indices: List[int],
                               resize_short: int = 336,
                               return_times: bool = False):
    """Cached drop-in for ``frame_utils.sample_frames_from_mp4_by_indices``.

    The whole video is decoded once (via ``_decode_full_video``) and kept in a
    single-entry cache. Repeated calls for the same video and
    ``resize_short`` then only index into memory.

    Args:
        video_path: path to the video (anything ``str()`` accepts, e.g. a
            ``Path``).
        indices: frame indices to return. Out-of-range values are clamped to
            the first/last decoded frame.
        resize_short: target length of the shorter frame side.
        return_times: if True, also return per-frame timestamps in seconds.

    Returns:
        A list of PIL RGB images, or ``(frames, times)`` when ``return_times``
        is set.

    Raises:
        RuntimeError: if the video cannot be opened or yields no frames.
    """
    global _cache_path, _video_cache
    vp = str(video_path)
    # Key on both path and target size. Keying on the path alone would hand
    # back frames resized for a different size if a caller changed it.
    # resize_short is always the final "|"-separated segment, so keys from
    # different (path, size) pairs cannot collide.
    cache_key = f"{vp}|{resize_short}"
    if cache_key not in _video_cache:
        # Evict the previous video *before* decoding so at most one video's
        # frames are resident. The hit test uses the dict itself, so a failed
        # decode leaves an empty cache instead of a stale marker that would
        # make the next call for the old video raise KeyError.
        _video_cache.clear()
        _cache_path = ""
        _video_cache[cache_key] = _decode_full_video(vp, resize_short)
        _cache_path = vp
    all_frames = _video_cache[cache_key]
    n_total = len(all_frames)
    if n_total <= 0:
        raise RuntimeError(f"bad video (0 frames decoded): {video_path}")
    clipped = [max(0, min(n_total - 1, int(i))) for i in indices]
    frames = [all_frames[i] for i in clipped]
    if return_times:
        # fps is not cached, so probe it from the container metadata (no
        # frames are read). A reported fps of 0 falls back to 30.0.
        cap = cv2.VideoCapture(vp)
        try:
            fps = float(cap.get(cv2.CAP_PROP_FPS)) or 30.0
        finally:
            cap.release()
        return frames, [i / fps for i in clipped]
    return frames
# Apply monkey-patch: rebind the attribute on the frame_utils module object.
# Callers that look the function up through the module at call time, and any
# module that does `from training.VLA.frame_utils import
# sample_frames_from_mp4_by_indices` *after* this line (such as the
# tools.make_cache_x_v2 import below), get the cached version.
# NOTE(review): a module that imported the function by name *before* this
# line (possibly via the `tools.run_train_cot_belief_fast` import at the top)
# still holds the original, uncached reference. Confirm the extraction path
# resolves it through the module.
_fu.sample_frames_from_mp4_by_indices = _patched_sample_by_indices
# ── now import the original extraction code with patches active ──────────
from tools.make_cache_x_v2 import ( # noqa: E402
build_extraction_assistant,
extract_split,
)
# Script-wide logging setup: INFO level with timestamped lines on the root
# logger, plus a named logger for this script.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("make_cache_x_v2_fast")
def _manifest_sort_key(record: dict) -> tuple:
    """Return ``(video_path, tick_index)`` so a video's ticks sort together, in order.

    A missing/None ``video_path`` sorts as ``""``. A missing/None ``meta`` or
    ``tick_index`` sorts as tick 0. This keeps incomplete records from
    aborting the sort with a ``TypeError``/``AttributeError``.
    """
    meta = record.get("meta") or {}
    tick = meta.get("tick_index")
    return (str(record.get("video_path") or ""), 0 if tick is None else int(tick))


def _read_manifest_records(manifest_path: Path) -> List[dict]:
    """Load the JSON-object records of a JSONL manifest, in file order.

    Blank lines are ignored. Lines that are not valid JSON, or that decode to
    something other than a JSON object, are skipped and reported with a
    single warning rather than being dropped silently.
    """
    records: List[dict] = []
    skipped: List[int] = []
    with open(manifest_path, encoding="utf-8") as fin:
        for lineno, line in enumerate(fin, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped.append(lineno)
                continue
            if not isinstance(record, dict):
                skipped.append(lineno)
                continue
            records.append(record)
    if skipped:
        logger.warning(
            "[sort] skipped %d malformed manifest line(s) in %s (first: line %d)",
            len(skipped), manifest_path, skipped[0],
        )
    return records


def main():
    """CLI entry point: sort the manifest by video, then extract the cache.

    The manifest is rewritten into ``--out_dir`` sorted by
    ``(video_path, tick_index)``, and ``extract_split`` runs on that sorted
    copy. Consecutive records of the same video can then reuse the
    single-video frame cache instead of re-decoding the file.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", type=Path, required=True)
    ap.add_argument("--tag", default="sft_x_v2")
    ap.add_argument("--split", required=True)
    ap.add_argument("--out_dir", type=Path,
                    default=ROOT / "data/belief_cache_v2")
    ap.add_argument("--ckpt", type=Path,
                    default=ROOT / "checkpoints/sft_x_v2/best")
    ap.add_argument("--base_model", type=Path,
                    default=ROOT / "models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--belief_layers", nargs="+", type=int,
                    default=[20, 24, 28, 32])
    ap.add_argument("--policy_layer", type=int, default=33)
    ap.add_argument("--batch_size", type=int, default=4)
    ap.add_argument("--limit", type=int, default=0)
    ap.add_argument("--pool_mode",
                    choices=["range", "open", "token_mean", "random_span"],
                    default="range")
    ap.add_argument("--random_span_seed", type=int, default=0)
    ap.add_argument("--n_frames", type=int, default=8,
                    help="frames sampled per record (passed to extract_split)")
    args = ap.parse_args()

    # ── Pre-sort manifest by (video_path, tick_index) so consecutive records
    #    hit the single-video frame cache ──
    args.out_dir.mkdir(parents=True, exist_ok=True)
    sorted_manifest = args.out_dir / f"_sorted__{args.manifest.name}"
    records = _read_manifest_records(args.manifest)
    # list.sort is stable: records sharing a key keep their manifest order.
    records.sort(key=_manifest_sort_key)
    logger.info("[sort] %d records sorted by (video_path, tick_index)",
                len(records))
    with open(sorted_manifest, "w", encoding="utf-8") as fout:
        fout.writelines(json.dumps(r) + "\n" for r in records)
    logger.info("[save] sorted manifest → %s", sorted_manifest)

    # NOTE(review): --limit is forwarded unchanged. If extract_split takes the
    # first N manifest records (assumed, not visible here), it selects from
    # the *sorted* order, not the original manifest order.
    out_path = args.out_dir / f"{args.tag}__{args.split}.pt"
    extract_split(
        ckpt_dir=args.ckpt,
        base_model=args.base_model,
        manifest_path=sorted_manifest,
        out_path=out_path,
        belief_layers=tuple(args.belief_layers),
        policy_layer=args.policy_layer,
        batch_size=args.batch_size,
        limit=args.limit,
        n_frames=args.n_frames,
        pool_mode=args.pool_mode,
        random_span_seed=args.random_span_seed,
    )


if __name__ == "__main__":
    main()