import gc
import os
import spaces
import gradio as gr
import random
import tempfile
import time

from easydict import EasyDict
import numpy as np
import torch

from dav.pipelines import DAVPipeline
from dav.models import UNetSpatioTemporalRopeConditionModel
from diffusers import AutoencoderKLTemporalDecoder, FlowMatchEulerDiscreteScheduler
from dav.utils import img_utils


def seed_all(seed: int = 0):
    """
    Set random seeds for reproducibility.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


examples = [
    ["demos/wooly_mammoth.mp4", 3, 32, 8, 16, 6, 768],
]
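# Each example row maps positionally to the gr.Examples inputs below:
# [input_video, denoise_steps, num_frames, decode_chunk_size,
#  num_interp_frames, num_overlap_frames, max_resolution].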


def load_models(model_base, device):
    """
    Load the VAE, scheduler, and the two denoising UNets (key-frame and
    interpolation) from the model repository and assemble the pipeline.
    """
    vae = AutoencoderKLTemporalDecoder.from_pretrained(model_base, subfolder="vae")
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        model_base, subfolder="scheduler"
    )
    unet = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet"
    )
    unet_interp = UNetSpatioTemporalRopeConditionModel.from_pretrained(
        model_base, subfolder="unet_interp"
    )
    pipe = DAVPipeline(
        vae=vae,
        unet=unet,
        unet_interp=unet_interp,
        scheduler=scheduler,
    )
    pipe = pipe.to(device)
    return pipe


model_base = "hhyangcs/depth-any-video"
device_type = "cuda"
device = torch.device(device_type)
pipe = load_models(model_base, device)
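# The pipeline is created once at import time and kept on the GPU, so every
# request handled below reuses the same loaded weights.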


# `spaces` is imported above, so the handler is assumed to run on Hugging Face
# Spaces GPU (ZeroGPU) hardware; the decorator below reflects that assumption.
@spaces.GPU
def infer_depth(
    file: str,
    denoise_steps: int = 3,
    num_frames: int = 32,
    decode_chunk_size: int = 16,
    num_interp_frames: int = 16,
    num_overlap_frames: int = 6,
    max_resolution: int = 1024,
    seed: int = 66,
    output_dir: str = "./outputs",
):
    """
    Run the depth pipeline on a video and write a side-by-side video of the
    input frames and the colored disparity; returns the output file path.
    """
    seed_all(seed)

    # Upper bound on how many frames to read, derived from the key-frame and
    # interpolation window sizes.
    max_frames = (num_interp_frames + 2 - num_overlap_frames) * (num_frames // 2)
    image, fps = img_utils.read_video(file, max_frames=max_frames)

    # Resize to the resolution cap, crop to sizes the pipeline accepts, and
    # scale pixels to [0, 1] in (N, C, H, W) layout.
    image = img_utils.imresize_max(image, max_resolution)
    image = img_utils.imcrop_multi(image)
    image_tensor = np.ascontiguousarray(
        [_img.transpose(2, 0, 1) / 255.0 for _img in image]
    )
    image_tensor = torch.from_numpy(image_tensor).to(device)

    print(f"==> video name: {file}, frames shape: {image_tensor.shape}")

    # Run inference under fp16 autocast with gradients disabled.
    with torch.no_grad(), torch.autocast(device_type=device_type, dtype=torch.float16):
        pipe_out = pipe(
            image_tensor,
            num_frames=num_frames,
            num_overlap_frames=num_overlap_frames,
            num_interp_frames=num_interp_frames,
            decode_chunk_size=decode_chunk_size,
            num_inference_steps=denoise_steps,
        )
    disparity = pipe_out.disparity
    disparity_colored = pipe_out.disparity_colored
    image = pipe_out.image

    # Place the RGB frames and the colored disparity side by side: (N, H, 2 * W, 3).
    merged = np.concatenate(
        [
            image,
            disparity_colored,
        ],
        axis=2,
    )

    file_name = os.path.splitext(os.path.basename(file))[0]
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{file_name}_depth.mp4")
    img_utils.write_video(
        output_path,
        merged,
        fps,
    )

    # Clear the cache for the next video.
    gc.collect()
    torch.cuda.empty_cache()

    return output_path
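
# Illustrative usage only (not part of the Gradio flow): infer_depth can be
# called directly with the bundled demo clip and the example settings, e.g.
#   infer_depth("demos/wooly_mammoth.mp4", denoise_steps=3, num_frames=32,
#               decode_chunk_size=8, num_interp_frames=16,
#               num_overlap_frames=6, max_resolution=768)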


def construct_demo():
    with gr.Blocks(analytics_enabled=False) as depthanyvideo_iface:
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_video = gr.Video(label="Input Video")
            with gr.Column(scale=2):
                with gr.Row(equal_height=True):
                    output_video = gr.Video(
                        label="Output Video & Depth",
                        interactive=False,
                        autoplay=True,
                        loop=True,
                        show_share_button=True,
                        scale=2,
                    )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                with gr.Row(equal_height=False):
                    with gr.Accordion("Advanced Settings", open=False):
                        denoise_steps = gr.Slider(
                            label="Denoise Steps",
                            minimum=1,
                            maximum=10,
                            value=3,
                            step=1,
                        )
                        num_frames = gr.Slider(
                            label="Number of Key Frames",
                            minimum=16,
                            maximum=32,
                            value=24,
                            step=2,
                        )
                        decode_chunk_size = gr.Slider(
                            label="Decode Chunk Size",
                            minimum=8,
                            maximum=32,
                            value=8,
                            step=1,
                        )
                        num_interp_frames = gr.Slider(
                            label="Number of Interpolation Frames",
                            minimum=8,
                            maximum=32,
                            value=16,
                            step=1,
                        )
                        num_overlap_frames = gr.Slider(
                            label="Number of Overlap Frames",
                            minimum=2,
                            maximum=10,
                            value=6,
                            step=1,
                        )
                        max_resolution = gr.Slider(
                            label="Maximum Resolution",
                            minimum=512,
                            maximum=2048,
                            value=768,
                            step=32,
                        )
                    generate_btn = gr.Button("Generate")
            with gr.Column(scale=2):
                pass

        gr.Examples(
            examples=examples,
            inputs=[
                input_video,
                denoise_steps,
                num_frames,
                decode_chunk_size,
                num_interp_frames,
                num_overlap_frames,
                max_resolution,
            ],
            outputs=output_video,
            fn=infer_depth,
            cache_examples="lazy",
        )

        generate_btn.click(
            fn=infer_depth,
            inputs=[
                input_video,
                denoise_steps,
                num_frames,
                decode_chunk_size,
                num_interp_frames,
                num_overlap_frames,
                max_resolution,
            ],
            outputs=output_video,
        )
    return depthanyvideo_iface


demo = construct_demo()

if __name__ == "__main__":
    demo.queue()
    demo.launch(share=True)