File size: 10,233 Bytes
08b23ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
# Copyright (c) 2021 Henrique Morimitsu
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
#
# This file has been modified by ByteDance Ltd. and/or its affiliates. on 2025.09.04
#
# Original file was released under Apache License 2.0, with the full license text
# available at https://github.com/hmorimitsu/ptlflow/blob/main/LICENSE.
#
# This modified file is released under the same license.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This module processes PNG frame sequences to generate optical flow using PTLFlow,
with support for visualization and video generation.
"""

import argparse
import os
import subprocess
import shutil
import logging
from pathlib import Path
from typing import List, Tuple, Optional, Union

import cv2 as cv
import torch
import numpy as np
from tqdm import tqdm

from third_partys.ptlflow.ptlflow.utils import flow_utils
from third_partys.ptlflow.ptlflow.utils.io_adapter import IOAdapter
import third_partys.ptlflow.ptlflow as ptlflow

class OpticalFlowProcessor:
    """Compute optical flow for frame sequences and save the results.

    Wraps a PTLFlow model: loads a directory of PNG frames, runs the model on
    every pair of consecutive frames and, optionally, writes each flow as a
    ``.flo`` file together with an RGB visualization PNG.
    """

    def __init__(
        self,
        model_name: str = 'dpflow',
        checkpoint: str = 'sintel',
        device: Optional[str] = None,
        resize_to: Optional[Tuple[int, int]] = None
    ):
        """
        Initialize optical flow processor.

        Args:
            model_name: Name of the flow model to use
            checkpoint: Checkpoint/dataset name for the model
            device: Device to run on (auto-detect if None)
            resize_to: Optional (width, height) to resize frames
        """
        self.model_name = model_name
        self.checkpoint = checkpoint
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.resize_to = resize_to

        # Initialize model
        self.model = ptlflow.get_model(model_name, ckpt_path=checkpoint).to(self.device).eval()
        print(f"Loaded {model_name} model on {self.device}")

        # The IO adapter depends on the frame size, so it is created lazily on
        # the first call to compute_optical_flow_sequence().
        self.io_adapter = None
        # (H, W) the cached adapter was built for; None until one is built.
        self._io_adapter_size = None

    def load_frame_sequence(self, frames_dir: Union[str, Path]) -> Tuple[List[np.ndarray], List[Path]]:
        """
        Load PNG frame sequence from directory.

        Args:
            frames_dir: Directory containing the ``*.png`` frames

        Returns:
            Tuple ``(frames, png_files)``: the frames as RGB arrays in natural
            filename order, and the matching file paths in the same order.

        Raises:
            FileNotFoundError: If ``frames_dir`` does not exist
            ValueError: If fewer than 2 PNG files are found
            OSError: If a PNG file cannot be read/decoded
        """
        frames_dir = Path(frames_dir)

        if not frames_dir.exists():
            raise FileNotFoundError(f"Frames directory not found: {frames_dir}")

        # Natural sorting so that e.g. frame_2.png comes before frame_10.png
        png_files = sorted(
            frames_dir.glob('*.png'),
            key=lambda p: self._natural_sort_key(p.name)
        )
        if len(png_files) < 2:
            raise ValueError(f"Need at least 2 PNG frames, found {len(png_files)} in {frames_dir}")

        frames = []
        for png_path in tqdm(png_files, desc="Loading frames"):
            # Load image in color
            img_bgr = cv.imread(str(png_path), cv.IMREAD_COLOR)
            if img_bgr is None:
                # cv.imread signals failure by returning None instead of
                # raising; fail here with the offending path.
                raise OSError(f"Failed to read image: {png_path}")

            if self.resize_to:
                # The third positional parameter of cv.resize is `dst`, so the
                # interpolation flag must be passed by keyword.
                img_bgr = cv.resize(img_bgr, self.resize_to, interpolation=cv.INTER_LINEAR)

            frames.append(cv.cvtColor(img_bgr, cv.COLOR_BGR2RGB))

        return frames, png_files

    def _natural_sort_key(self, filename: str) -> List[Union[int, str]]:
        """Natural sorting key for filenames with numbers.

        The name is split on runs of ASCII digits; digit runs become ints and
        the text in between is lower-cased, e.g. ``"frame10.png"`` ->
        ``["frame", 10, ".png"]``.
        """
        import re

        # re.split with a capturing group alternates text, digits, text, ...
        # so the captured digit runs are exactly the odd indices. Deciding by
        # position (rather than str.isdigit) keeps every key the same
        # str/int/str/... shape and avoids int() on Unicode "digits" such as
        # '²' that isdigit() accepts but int() rejects.
        parts = re.split(r'([0-9]+)', filename)
        return [int(part) if idx % 2 else part.lower()
                for idx, part in enumerate(parts)]

    def compute_optical_flow_sequence(
        self,
        frames: List[np.ndarray],
        flow_vis_dir: Union[str, Path],
        flow_save_dir: Optional[Union[str, Path]] = None,
        save_visualizations: bool = True
    ) -> List[torch.Tensor]:
        """
        Compute optical flow for entire frame sequence.

        Args:
            frames: Frames in temporal order; all are assumed to share the
                size of ``frames[0]``
            flow_vis_dir: Directory for the visualization PNGs
            flow_save_dir: Directory for the ``.flo`` files (defaults to
                ``flow_vis_dir`` when not given)
            save_visualizations: If True, write both the ``.flo`` file and the
                visualization PNG for every frame pair

        Returns:
            One flow tensor per consecutive frame pair (``len(frames) - 1``
            entries), taken from the model's ``'flows'`` output.

        Raises:
            ValueError: If fewer than 2 frames are given
        """
        if len(frames) < 2:
            raise ValueError("Need at least 2 frames for optical flow")

        flow_vis_dir = Path(flow_vis_dir)
        flow_save_dir = Path(flow_save_dir) if flow_save_dir else flow_vis_dir

        if save_visualizations:
            # Make sure the output locations exist before the first write.
            flow_vis_dir.mkdir(parents=True, exist_ok=True)
            flow_save_dir.mkdir(parents=True, exist_ok=True)

        frame_size = tuple(frames[0].shape[:2])  # (H, W)

        # Build the IO adapter on first use, and rebuild it when this instance
        # previously built one for a different frame size.
        cached_size = getattr(self, '_io_adapter_size', None)
        if self.io_adapter is None or (cached_size is not None and cached_size != frame_size):
            self.io_adapter = IOAdapter(self.model, frame_size)
            self._io_adapter_size = frame_size

        flows = []
        for i in tqdm(range(len(frames) - 1), desc="Computing optical flow"):
            # Prepare frame pair
            raw_inputs = self.io_adapter.prepare_inputs([frames[i], frames[i + 1]])

            # assumes prepare_inputs returns 'images' as (1, 2, 3, H, W), so
            # imgs is (2, 3, H, W) — TODO confirm against IOAdapter
            imgs = raw_inputs['images'][0]

            # Keep the two frames and restore the leading batch dimension.
            pair_tensor = imgs[:2].to(self.device, non_blocking=True).contiguous()

            with torch.no_grad():
                flow_result = self.model({'images': pair_tensor.unsqueeze(0)})
                # presumably (1, 2, H, W) after dropping the batch dimension
                flow = flow_result['flows'][0]

            flows.append(flow)

            if save_visualizations:
                self._save_flow_outputs(flow, i, flow_vis_dir, flow_save_dir)

        return flows

    def _save_flow_outputs(
        self,
        flow_tensor: torch.Tensor,
        frame_idx: int,
        viz_dir: Path,
        flow_dir: Path
    ) -> None:
        """Save flow outputs in both .flo and visualization formats.

        Writes ``flow_{frame_idx:04d}.flo`` into ``flow_dir`` and
        ``flow_viz_{frame_idx:04d}.png`` into ``viz_dir``.
        """
        # Save raw flow (.flo format)
        flow_chw = flow_tensor[0]  # assumed (2, H, W)
        flow_np = flow_chw.permute(1, 2, 0).cpu().numpy()  # (H, W, 2)

        flow_path = flow_dir / f'flow_{frame_idx:04d}.flo'
        flow_utils.flow_write(flow_path, flow_np)

        # Save visualization
        flow_rgb = flow_utils.flow_to_rgb(flow_tensor)[0]  # Remove batch dimension

        if flow_rgb.dim() == 4:  # (Npred, 3, H, W)
            flow_rgb = flow_rgb[0]

        # The scaling by 255 implies values in [0, 1]; clamp so that a value
        # marginally outside that range cannot wrap around in the uint8 cast.
        flow_rgb_np = (flow_rgb * 255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
        viz_bgr = cv.cvtColor(flow_rgb_np, cv.COLOR_RGB2BGR)

        viz_path = viz_dir / f'flow_viz_{frame_idx:04d}.png'
        cv.imwrite(str(viz_path), viz_bgr)

def create_flow_video(
    image_dir: Union[str, Path],
    output_filename: str = 'flow.mp4',
    fps: int = 10,
    pattern: str = 'flow_viz_*.png',
    cleanup_temp: bool = True
) -> bool:
    """
    Create MP4 video from flow visualization images.

    The matching images are copied, in sorted order, into a ``temp_sequence``
    sub-directory under sequential names and then encoded with ``ffmpeg``
    (libx264, yuv420p) into ``image_dir / output_filename``.

    Args:
        image_dir: Directory containing the visualization images
        output_filename: Name of the video file written inside ``image_dir``
        fps: Frame rate of the output video
        pattern: Glob pattern selecting the images to include
        cleanup_temp: Remove the ``temp_sequence`` directory afterwards

    Returns:
        True if the video was written; False if the directory is missing, no
        image matches ``pattern``, or copying/encoding failed.
    """
    image_dir = Path(image_dir)

    if not image_dir.is_dir():
        print(f"Image directory not found: {image_dir}")
        return False

    image_files = sorted(image_dir.glob(pattern))
    if not image_files:
        print(f"No images found matching pattern '{pattern}' in {image_dir}")
        return False

    temp_dir = image_dir / 'temp_sequence'

    try:
        # Start from an empty directory so frames left over from an earlier
        # run (e.g. with cleanup_temp=False) cannot end up in this video.
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
        temp_dir.mkdir()

        # Copy files with sequential naming, as required by ffmpeg's
        # frame_%05d.png input pattern.
        for i, img_file in enumerate(image_files):
            shutil.copy2(img_file, temp_dir / f'frame_{i:05d}.png')

        # Create video using ffmpeg
        output_path = image_dir / output_filename

        cmd = [
            'ffmpeg', '-y',
            '-framerate', str(fps),
            '-i', str(temp_dir / 'frame_%05d.png'),
            '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p',
            str(output_path)
        ]

        subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            check=True
        )
    except subprocess.CalledProcessError as e:
        print(f"Video creation failed: {e}")
        if e.stderr:
            # ffmpeg reports the actual reason on stderr.
            print(e.stderr.strip())
        return False
    except OSError as e:
        # Covers a missing ffmpeg executable as well as copy/mkdir failures.
        print(f"Video creation failed: {e}")
        return False
    else:
        return True
    finally:
        if cleanup_temp and temp_dir.exists():
            shutil.rmtree(temp_dir)

def main(
    frames_dir: Union[str, Path],
    flow_vis_dir: Union[str, Path] = 'flow_out',
    flow_save_dir: Optional[Union[str, Path]] = None,
    resize_to: Optional[Tuple[int, int]] = None,
    model_name: str = 'dpflow',
    checkpoint: str = 'sintel',
    fps: int = 10
) -> bool:
    """
    Run the full pipeline: load frames, compute flow, write a preview video.

    Args:
        frames_dir: Directory containing the input PNG frames
        flow_vis_dir: Directory for the flow visualization PNGs and the video
        flow_save_dir: Directory for the ``.flo`` files (defaults to
            ``flow_vis_dir`` when not given)
        resize_to: Optional (width, height) to resize frames
        model_name: Name of the flow model to use
        checkpoint: Checkpoint/dataset name for the model
        fps: Frame rate of the output video

    Returns:
        The result of ``create_flow_video``: True if the video was created.
        Flow files and visualizations are written before the video step, so a
        False result does not mean those outputs are missing.
    """
    # Initialize processor
    processor = OpticalFlowProcessor(
        model_name=model_name,
        checkpoint=checkpoint,
        resize_to=resize_to
    )

    # Load frames
    frames, _ = processor.load_frame_sequence(frames_dir)

    # Compute optical flow (outputs are written to disk as a side effect)
    processor.compute_optical_flow_sequence(
        frames=frames,
        flow_vis_dir=flow_vis_dir,
        flow_save_dir=flow_save_dir,
        save_visualizations=True
    )

    # Create video
    return create_flow_video(flow_vis_dir, fps=fps)

def get_parser():
    """Build the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description="Optical flow inference on frame sequences")

    # Both are required: they are joined into <input_path>/<seq_name>/imgs.
    parser.add_argument('--input_path', type=str, required=True, help="base input path")
    parser.add_argument('--seq_name', type=str, required=True, help="sequence name")
    parser.add_argument('--model_name', type=str, default='dpflow', help="Optical flow model to use")
    parser.add_argument('--checkpoint', type=str, default='sintel', help="Model checkpoint/dataset name")
    parser.add_argument('--resize_width', type=int, default=None, help="Resize frame width (must specify both width and height)")
    parser.add_argument('--resize_height', type=int, default=None, help="Resize frame height (must specify both width and height)")
    parser.add_argument('--fps', type=int, default=10, help="Frame rate for output video")

    return parser

if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    # Prepare resize parameter; reject a half-specified or non-positive size
    # instead of silently processing at the original resolution.
    resize_to = None
    if (args.resize_width is None) != (args.resize_height is None):
        parser.error("--resize_width and --resize_height must be specified together")
    if args.resize_width is not None:
        if args.resize_width <= 0 or args.resize_height <= 0:
            parser.error("--resize_width and --resize_height must be positive")
        resize_to = (args.resize_width, args.resize_height)

    # Paths: build the sibling directories from the sequence directory rather
    # than string-replacing "imgs", which would also rewrite any other
    # occurrence of "imgs" in the base path or sequence name.
    seq_dir = Path(args.input_path) / args.seq_name
    frames_dir = seq_dir / 'imgs'
    flow_vis_dir = seq_dir / 'flow_vis'
    flow_save_dir = seq_dir / 'flow'

    flow_vis_dir.mkdir(parents=True, exist_ok=True)
    flow_save_dir.mkdir(parents=True, exist_ok=True)

    # Process optical flow
    success = main(
        frames_dir=frames_dir,
        flow_vis_dir=flow_vis_dir,
        flow_save_dir=flow_save_dir,
        resize_to=resize_to,
        model_name=args.model_name,
        checkpoint=args.checkpoint,
        fps=args.fps
    )

    if success:
        print("Optical flow processing completed successfully")
    else:
        print("Optical flow was computed, but video creation failed")
        raise SystemExit(1)