File size: 6,455 Bytes
1e05592
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Fast version of make_cache_x_v2.py — reuses video frame buffer across ticks.

Bottleneck of the original: for rolling per-tick manifests (e.g. 17 ticks per
video on CARLA), every tick calls `sample_frames_from_mp4_by_indices`, which
opens the video fresh and reads from frame 0 sequentially until it reaches the
wanted indices. For 17 ticks this decodes the same video ~17 times.

Fix: monkey-patch `sample_frames_from_mp4_by_indices` to keep an LRU-1 cache of
the most recently decoded video's full frame list (already resized). Sort the
manifest by video_path so consecutive ticks of the same video hit the cache.

Expected speed-up on CARLA rolling (17 ticks/clip avg): 5-10x for the decode
portion, bringing aggregate throughput close to the GPU-forward-bound limit.
"""
from __future__ import annotations

import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(ROOT))

# Order matters: apply PR fast_patch BEFORE importing model code.
from tools import run_train_cot_belief_fast  # noqa: F401, E402

import argparse  # noqa: E402
import json  # noqa: E402
import logging  # noqa: E402
import time  # noqa: E402
from typing import Dict, List  # noqa: E402

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from PIL import Image  # noqa: E402
from tqdm import tqdm  # noqa: E402

# ── monkey-patch sample_frames with video-level cache ────────────────────
from training.VLA import frame_utils as _fu  # noqa: E402

# Single-slot frame cache used by `_patched_sample_by_indices`: the dict is
# cleared before a different video is decoded, so it holds the resized frames
# of at most one video at a time.
_video_cache: Dict[str, List[Image.Image]] = {}
# Path of the video whose frames currently sit in `_video_cache`
# (empty string until the first video has been decoded).
_cache_path: str = ""


def _resize_bgr(frame: np.ndarray, resize_short: int) -> Image.Image:
    """Rescale a BGR frame by ``resize_short / min(h, w)`` and return it as RGB PIL.

    Both output dimensions are the input dimensions multiplied by that single
    factor and rounded to the nearest integer, so the aspect ratio is kept
    (up to rounding). Resampling uses ``cv2.INTER_AREA``.
    """
    height, width = frame.shape[:2]
    factor = resize_short / min(height, width)
    target_h = int(round(height * factor))
    target_w = int(round(width * factor))
    # cv2.resize takes the destination size as (width, height).
    resized = cv2.resize(frame, (target_w, target_h),
                         interpolation=cv2.INTER_AREA)
    # OpenCV channel order is BGR; PIL expects RGB.
    return Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))


def _decode_full_video(video_path: str, resize_short: int) -> List[Image.Image]:
    """Decode every frame of a video once, return resized PIL RGB list.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        resize_short: Forwarded to ``_resize_bgr`` for every decoded frame.

    Returns:
        All frames that could be read, in decode order. The list is empty when
        the container opens but yields no frames.

    Raises:
        RuntimeError: If OpenCV cannot open ``video_path``.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally so the capture handle is released even when opening fails or
    # a frame conversion raises part-way through the video.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"could not open: {video_path}")
        frames: List[Image.Image] = []
        while True:
            ok, frame = cap.read()
            if not ok:
                # read() reports failure at end of stream (or on a decode error).
                break
            frames.append(_resize_bgr(frame, resize_short))
        return frames
    finally:
        cap.release()


def _patched_sample_by_indices(video_path, indices: List[int],
                                resize_short: int = 336,
                                return_times: bool = False):
    """LRU-1 cached version: consecutive calls for one video decode it once.

    Args:
        video_path: Video to sample from; converted with ``str()``.
        indices: Frame indices to return. Out-of-range values are clamped to
            ``[0, n_frames - 1]``.
        resize_short: Resize parameter forwarded to ``_decode_full_video``.
            It is part of the cache key, so asking for the same video at a
            different size triggers a fresh decode instead of returning
            frames of the wrong size.
        return_times: When true, also return one timestamp (seconds) per
            returned frame, computed as ``index / fps``.

    Returns:
        The list of frames, or ``(frames, times)`` when ``return_times`` is set.

    Raises:
        RuntimeError: If the video cannot be opened or decodes to zero frames.
    """
    global _cache_path
    vp = str(video_path)
    cache_key = f"{resize_short}:{vp}"
    if cache_key not in _video_cache:
        # Evict the previous video *before* decoding to keep peak memory at one
        # video. Reset `_cache_path` too, so a failed decode leaves the cache
        # in a consistent "empty" state rather than pointing at evicted data.
        _video_cache.clear()
        _cache_path = ""
        _video_cache[cache_key] = _decode_full_video(vp, resize_short)
        _cache_path = vp
    all_frames = _video_cache[cache_key]
    n_total = len(all_frames)
    if n_total <= 0:
        raise RuntimeError(f"bad video (0 frames decoded): {video_path}")
    clipped = [max(0, min(n_total - 1, int(i))) for i in indices]
    frames = [all_frames[i] for i in clipped]
    if not return_times:
        return frames
    # The cached frame list carries no timing information, so query the
    # container for its FPS; fall back to 30.0 when OpenCV reports 0.
    cap = cv2.VideoCapture(vp)
    try:
        fps = float(cap.get(cv2.CAP_PROP_FPS)) or 30.0
    finally:
        cap.release()
    return frames, [i / fps for i in clipped]


# Apply monkey-patch: rebind the attribute on the frame_utils module object so
# that lookups of `sample_frames_from_mp4_by_indices` through that module made
# after this point resolve to the cached implementation. This is done before
# the `tools.make_cache_x_v2` import further down.
# NOTE(review): any module imported earlier that bound the original function by
# name would keep the unpatched version — confirm none do.
_fu.sample_frames_from_mp4_by_indices = _patched_sample_by_indices


# ── now import the original extraction code with patches active ──────────
from tools.make_cache_x_v2 import (  # noqa: E402
    build_extraction_assistant,
    extract_split,
)


# Configure root logging at INFO with a "timestamp level message" layout, and
# create the named logger this script uses for its own messages.
logging.basicConfig(level=logging.INFO,
                    format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("make_cache_x_v2_fast")


def _manifest_sort_key(record):
    """Return ``(video_path, tick_index)`` so ticks of one video sort adjacently."""
    # `or` fallbacks tolerate explicit JSON nulls as well as missing keys.
    meta = record.get("meta") or {}
    return (record.get("video_path") or "", int(meta.get("tick_index") or 0))


def _write_sorted_manifest(manifest, sorted_manifest):
    """Copy a JSONL manifest to ``sorted_manifest`` ordered by video, then tick.

    Blank lines are ignored. Lines that are not valid JSON are skipped, and the
    number skipped is logged as a warning. Returns the number of records written.
    """
    records = []
    n_skipped = 0
    with open(manifest, encoding="utf-8") as fin:
        for raw in fin:
            line = raw.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                # Best-effort: keep going, but report the loss below.
                n_skipped += 1
    if n_skipped:
        logger.warning("[sort] skipped %d malformed JSON line(s) in %s",
                       n_skipped, manifest)

    records.sort(key=_manifest_sort_key)
    logger.info("[sort] %d records sorted by (video_path, tick_index)",
                len(records))

    with open(sorted_manifest, "w", encoding="utf-8") as fout:
        fout.writelines(json.dumps(r) + "\n" for r in records)
    logger.info("[save] sorted manifest → %s", sorted_manifest)
    return len(records)


def main():
    """Parse CLI arguments, write a video-sorted manifest, and run extraction.

    The input manifest is rewritten into ``<out_dir>/_sorted__<manifest name>``
    ordered by ``(video_path, tick_index)`` so that consecutive records refer
    to the same video and hit the single-video frame cache. ``extract_split``
    is then called on the sorted copy and told to write
    ``<out_dir>/<tag>__<split>.pt``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", type=Path, required=True)
    ap.add_argument("--tag",      default="sft_x_v2")
    ap.add_argument("--split",    required=True)
    ap.add_argument("--out_dir",  type=Path,
                    default=ROOT / "data/belief_cache_v2")
    ap.add_argument("--ckpt",     type=Path,
                    default=ROOT / "checkpoints/sft_x_v2/best")
    ap.add_argument("--base_model", type=Path,
                    default=ROOT / "models/Qwen3-VL-4B-Instruct")
    ap.add_argument("--belief_layers", nargs="+", type=int,
                    default=[20, 24, 28, 32])
    ap.add_argument("--policy_layer", type=int, default=33)
    ap.add_argument("--batch_size",   type=int, default=4)
    ap.add_argument("--limit",        type=int, default=0)
    ap.add_argument("--pool_mode",
                    choices=["range", "open", "token_mean", "random_span"],
                    default="range")
    ap.add_argument("--random_span_seed", type=int, default=0)
    # Previously hard-coded to 8 in the extract_split call; default unchanged.
    ap.add_argument("--n_frames",     type=int, default=8)
    args = ap.parse_args()

    # ── Pre-sort manifest by video_path so consecutive batches hit the cache ──
    args.out_dir.mkdir(parents=True, exist_ok=True)
    sorted_manifest = args.out_dir / f"_sorted__{args.manifest.name}"
    _write_sorted_manifest(args.manifest, sorted_manifest)

    out_path = args.out_dir / f"{args.tag}__{args.split}.pt"
    extract_split(
        ckpt_dir=args.ckpt,
        base_model=args.base_model,
        manifest_path=sorted_manifest,
        out_path=out_path,
        belief_layers=tuple(args.belief_layers),
        policy_layer=args.policy_layer,
        batch_size=args.batch_size,
        limit=args.limit,
        n_frames=args.n_frames,
        pool_mode=args.pool_mode,
        random_span_seed=args.random_span_seed,
    )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()