diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..bb1fe1d1870bb36911f3ec7a7a9c455db9149166 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ + + + +models/ + +# Packages +*.egg +*.egg-info +dist +build +eggs +parts +bin +var +sdist +develop-eggs +.installed.cfg +lib64 +__pycache__ \ No newline at end of file diff --git a/README.md b/README.md index 4044e3ae5bca896259d9b5b38938eed559211593..aceb57db0be71541d91d4621b59b4712e6df0c38 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,17 @@ --- title: ZIT Controlnet -emoji: 📊 -colorFrom: yellow -colorTo: purple +emoji: 🖼 +colorFrom: purple +colorTo: red sdk: gradio -sdk_version: 6.0.2 +sdk_version: 5.44.0 app_file: app.py pinned: false license: apache-2.0 short_description: Supports Canny, HED, Depth, Pose and MLSD +models: + - Tongyi-MAI/Z-Image-Turbo + - alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference \ No newline at end of file diff --git a/app.py b/app.py index 04cc31aa8d0e06aeaac3b59bb361ed71d831e43f..682679b53ed8f58e2e76732d7ee4f317564d9cb0 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,227 @@ import gradio as gr +import numpy as np +import random +import json +import spaces +import torch +from diffusers import DiffusionPipeline +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler +from videox_fun.pipeline import ZImageControlPipeline +from videox_fun.models import ZImageControlTransformer2DModel +from transformers import AutoTokenizer, Qwen3ForCausalLM +from diffusers import AutoencoderKL +from image_utils import get_image_latent, scale_image +# from videox_fun.utils.utils import get_image_latent -def greet(name): - return "Hello " + name + "!!" 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text") -demo.launch() +MODEL_REPO = "Tongyi-MAI/Z-Image-Turbo" +MAX_SEED = np.iinfo(np.int32).max +MAX_IMAGE_SIZE = 1280 + +MODEL_LOCAL = "models/Z-Image-Turbo/" +TRANSFORMER_LOCAL = "models/Z-Image-Turbo-Fun-Controlnet-Union.safetensors" + + +weight_dtype = torch.bfloat16 + +# load transformer +transformer = ZImageControlTransformer2DModel.from_pretrained( + MODEL_LOCAL, + subfolder="transformer", + low_cpu_mem_usage=True, + torch_dtype=torch.bfloat16, + transformer_additional_kwargs={ + "control_layers_places": [0, 5, 10, 15, 20, 25], + "control_in_dim": 16 + }, +).to(torch.bfloat16) + +if TRANSFORMER_LOCAL is not None: + print(f"From checkpoint: {TRANSFORMER_LOCAL}") + if TRANSFORMER_LOCAL.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(TRANSFORMER_LOCAL) + else: + state_dict = torch.load(TRANSFORMER_LOCAL, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Load MODEL_REPO +# Get Vae +vae = AutoencoderKL.from_pretrained( + MODEL_LOCAL, + subfolder="vae" +).to(weight_dtype) + +tokenizer = AutoTokenizer.from_pretrained( + MODEL_LOCAL, subfolder="tokenizer" +) +text_encoder = Qwen3ForCausalLM.from_pretrained( + MODEL_LOCAL, subfolder="text_encoder", torch_dtype=weight_dtype, + low_cpu_mem_usage=True, +) +scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3) +pipe = ZImageControlPipeline( + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, +) +pipe.transformer = transformer +pipe.to("cuda") + +# ======== AoTI compilation + FA3 ======== +pipe.transformer.layers._repeated_blocks = ["ZImageTransformerBlock"] +spaces.aoti_blocks_load(pipe.transformer.layers, + "zerogpu-aoti/Z-Image", variant="fa3") + + +@spaces.GPU +def inference( + prompt, + input_image, + image_scale=1.0, + control_context_scale = 0.75, + seed=42, + randomize_seed=True, + guidance_scale=1.5, + num_inference_steps=8, + progress=gr.Progress(track_tqdm=True), +): + # process image + if input_image is None: + print("Error: input_image is empty.") + return None + + input_image, width, height = scale_image(input_image, image_scale) + + control_image = get_image_latent(input_image, sample_size=[height, width])[:, :, 0] + + # generation + if randomize_seed: + seed = random.randint(0, MAX_SEED) + + generator = torch.Generator().manual_seed(seed) + + image = pipe( + prompt=prompt, + height=height, + width=width, + generator=generator, + guidance_scale=guidance_scale, + control_image=control_image, + num_inference_steps=num_inference_steps, + control_context_scale=control_context_scale, + ).images[0] + + return image, seed + + +def read_file(path: str) -> str: + with open(path, 'r', encoding='utf-8') as f: + content = f.read() + return content + + +css = """ +#col-container { + margin: 0 auto; + max-width: 960px; +} +""" + +with open('static/data.json', 'r') as file: + data = json.load(file) +examples = data['examples'] + +with gr.Blocks() as demo: + with gr.Column(elem_id="col-container"): + with gr.Column(): + gr.HTML(read_file("static/header.html")) + with gr.Row(equal_height=True): + with gr.Column(): + input_image = gr.Image( + height=290, sources=['upload', 'clipboard'], + image_mode='RGB', + # elem_id="image_upload", + type="pil", label="Upload") + + 
prompt = gr.Textbox( + label="Prompt", + show_label=False, + lines=2, + placeholder="Enter your prompt", + container=False, + ) + + run_button = gr.Button("Run", variant="primary") + with gr.Column(): + output_image = gr.Image(label="Result", show_label=False) + + with gr.Accordion("Advanced Settings", open=False): + seed = gr.Slider( + label="Seed", + minimum=0, + maximum=MAX_SEED, + step=1, + value=0, + ) + + randomize_seed = gr.Checkbox(label="Randomize seed", value=True) + + with gr.Row(): + image_scale = gr.Slider( + label="Image scale", + minimum=0.5, + maximum=2.0, + step=0.1, + value=1.0, + ) + control_context_scale = gr.Slider( + label="Control context scale", + minimum=0.0, + maximum=1.0, + step=0.1, + value=0.75, + ) + + with gr.Row(): + guidance_scale = gr.Slider( + label="Guidance scale", + minimum=0.0, + maximum=10.0, + step=0.1, + value=2.5, + ) + + num_inference_steps = gr.Slider( + label="Number of inference steps", + minimum=1, + maximum=30, + step=1, + value=8, + ) + gr.Examples(examples=examples, inputs=[input_image, prompt]) + + gr.HTML(read_file("static/footer.html")) + gr.on( + triggers=[run_button.click, prompt.submit], + fn=inference, + inputs=[ + prompt, + input_image, + image_scale, + control_context_scale, + seed, + randomize_seed, + guidance_scale, + num_inference_steps, + ], + outputs=[output_image, seed], + ) + +if __name__ == "__main__": + demo.launch(mcp_server=True) diff --git a/examples/depth.jpg b/examples/depth.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3bdea7fb076fcd9792274561d0db4d01966a3cf2 Binary files /dev/null and b/examples/depth.jpg differ diff --git a/examples/hed.jpg b/examples/hed.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9e2a4e8445f9300ab4574e2adf85780b95a74b9a Binary files /dev/null and b/examples/hed.jpg differ diff --git a/examples/pose.jpg b/examples/pose.jpg new file mode 100644 index 0000000000000000000000000000000000000000..8e18542696c64cdbdcfbd1d08e7cae8ba1d666c9 Binary files /dev/null and b/examples/pose.jpg differ diff --git a/examples/pose2.jpg b/examples/pose2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..1987adba0dd5d3bc8ec706f245b51028385ab7c2 Binary files /dev/null and b/examples/pose2.jpg differ diff --git a/image_utils.py b/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..34a26879b32c8a4953e57a170fdddfd9e643a3c2 --- /dev/null +++ b/image_utils.py @@ -0,0 +1,70 @@ +import torch +from PIL import Image +import numpy as np + +def scale_image(img, scale): + w, h = img.size + new_w = int(w * scale) + new_h = int(h * scale) + + # Adjust to nearest multiple of 32 + new_w = (new_w // 32) * 32 + new_h = (new_h // 32) * 32 + + return img.resize((new_w, new_h), Image.LANCZOS), new_w, new_h + +def padding_image(images, new_width, new_height): + new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255)) + + aspect_ratio = images.width / images.height + if new_width / new_height > 1: + if aspect_ratio > new_width / new_height: + new_img_width = new_width + new_img_height = int(new_img_width / aspect_ratio) + else: + new_img_height = new_height + new_img_width = int(new_img_height * aspect_ratio) + else: + if aspect_ratio > new_width / new_height: + new_img_width = new_width + new_img_height = int(new_img_width / aspect_ratio) + else: + new_img_height = new_height + new_img_width = int(new_img_height * aspect_ratio) + + resized_img = images.resize((new_img_width, new_img_height)) + + paste_x = (new_width - 
new_img_width) // 2 + paste_y = (new_height - new_img_height) // 2 + + new_image.paste(resized_img, (paste_x, paste_y)) + + return new_image + + +def get_image_latent(ref_image=None, sample_size=None, padding=False): + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + if padding: + ref_image = padding_image( + ref_image, sample_size[1], sample_size[0]) + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute( + [3, 0, 1, 2]).unsqueeze(0) / 255 + elif isinstance(ref_image, Image.Image): + ref_image = ref_image.convert("RGB") + if padding: + ref_image = padding_image( + ref_image, sample_size[1], sample_size[0]) + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute( + [3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute( + [3, 0, 1, 2]).unsqueeze(0) / 255 + + return ref_image \ No newline at end of file diff --git a/predict_t2i_control.py b/predict_t2i_control.py new file mode 100644 index 0000000000000000000000000000000000000000..788bd2e45cc89405237de91cbde0ab7e4e43a6b1 --- /dev/null +++ b/predict_t2i_control.py @@ -0,0 +1,228 @@ +import os +import sys + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from omegaconf import OmegaConf +from PIL import Image + +current_file_path = os.path.abspath(__file__) +project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] +for project_root in project_roots: + sys.path.insert(0, project_root) if project_root not in sys.path else None + +from videox_fun.dist import set_multi_gpus_devices, shard_model +from videox_fun.models import (AutoencoderKL, AutoTokenizer, + Qwen3ForCausalLM, ZImageControlTransformer2DModel) +from videox_fun.models.cache_utils import get_teacache_coefficients +from videox_fun.pipeline import ZImageControlPipeline +from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler +from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper) +from videox_fun.utils.lora_utils import merge_lora, unmerge_lora +from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, get_image, + get_video_to_video_latent, + save_videos_grid) + +# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8, model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload]. +# model_full_load means that the entire model will be moved to the GPU. +# +# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU, +# and the transformer model has been quantized to float8, which can save more GPU memory. +# +# model_cpu_offload means that the entire model will be moved to the CPU after use, which can save some GPU memory. +# +# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use, +# and the transformer model has been quantized to float8, which can save more GPU memory. 
+# +# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use, +# resulting in slower speeds but saving a large amount of GPU memory. +GPU_memory_mode = "model_cpu_offload" +# Multi GPUs config +# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used. +# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4. +# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1. +ulysses_degree = 1 +ring_degree = 1 +# Use FSDP to save more GPU memory in multi gpus. +fsdp_dit = False +fsdp_text_encoder = False +# Compile will give a speedup in fixed resolution and need a little GPU memory. +# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload. +compile_dit = False + +# Config and model path +config_path = "config/z_image/z_image_control.yaml" +# model path +model_name = "models/Diffusion_Transformer/Z-Image-Turbo/" + +# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++" +sampler_name = "Flow" + +# Load pretrained model if need +transformer_path = "models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors" +vae_path = None +lora_path = None + +# Other params +sample_size = [1728, 992] + +# Use torch.float16 if GPU does not support torch.bfloat16 +# ome graphics cards, such as v100, 2080ti, do not support torch.bfloat16 +weight_dtype = torch.bfloat16 +control_image = "asset/pose.jpg" +control_context_scale = 0.75 + +# 使用更长的neg prompt如"模糊,突变,变形,失真,画面暗,文本字幕,画面固定,连环画,漫画,线稿,没有主体。",可以增加稳定性 +# 在neg prompt中添加"安静,固定"等词语可以增加动态性。 +prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。" +negative_prompt = " " +guidance_scale = 0.00 +seed = 43 +num_inference_steps = 9 +lora_weight = 0.55 +save_path = "samples/z-image-t2i-control" + +device = set_multi_gpus_devices(ulysses_degree, ring_degree) +config = OmegaConf.load(config_path) + +transformer = ZImageControlTransformer2DModel.from_pretrained( + model_name, + subfolder="transformer", + low_cpu_mem_usage=True, + torch_dtype=weight_dtype, + transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']), +).to(weight_dtype) + +if transformer_path is not None: + print(f"From checkpoint: {transformer_path}") + if transformer_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(transformer_path) + else: + state_dict = torch.load(transformer_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = transformer.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Get Vae +vae = AutoencoderKL.from_pretrained( + model_name, + subfolder="vae" +).to(weight_dtype) + +if vae_path is not None: + print(f"From checkpoint: {vae_path}") + if vae_path.endswith("safetensors"): + from safetensors.torch import load_file, safe_open + state_dict = load_file(vae_path) + else: + state_dict = torch.load(vae_path, map_location="cpu") + state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict + + m, u = vae.load_state_dict(state_dict, strict=False) + print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + +# Get tokenizer and text_encoder +tokenizer = AutoTokenizer.from_pretrained( + model_name, subfolder="tokenizer" +) +text_encoder = 
Qwen3ForCausalLM.from_pretrained( + model_name, subfolder="text_encoder", torch_dtype=weight_dtype, + low_cpu_mem_usage=True, +) + +# Get Scheduler +Chosen_Scheduler = scheduler_dict = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +}[sampler_name] +scheduler = Chosen_Scheduler.from_pretrained( + model_name, + subfolder="scheduler" +) + +pipeline = ZImageControlPipeline( + vae=vae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, +) + +if ulysses_degree > 1 or ring_degree > 1: + from functools import partial + transformer.enable_multi_gpus_inference() + if fsdp_dit: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks)) + pipeline.transformer = shard_fn(pipeline.transformer) + print("Add FSDP DIT") + if fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.layers, ignored_modules=[text_encoder.language_model.embed_tokens], transformer_layer_cls_to_wrap=["MistralDecoderLayer", "PixtralTransformer"]) + text_encoder = shard_fn(text_encoder) + print("Add FSDP TEXT ENCODER") + +if compile_dit: + for i in range(len(pipeline.transformer.transformer_blocks)): + pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i]) + print("Add Compile") + +if GPU_memory_mode == "sequential_cpu_offload": + pipeline.enable_sequential_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_cpu_offload": + pipeline.enable_model_cpu_offload(device=device) +elif GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device) + convert_weight_dtype_wrapper(transformer, weight_dtype) + pipeline.to(device=device) +else: + pipeline.to(device=device) + +generator = torch.Generator(device=device).manual_seed(seed) + +if lora_path is not None: + pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +with torch.no_grad(): + if control_image is not None: + control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0] + + sample = pipeline( + prompt = prompt, + negative_prompt = negative_prompt, + height = sample_size[0], + width = sample_size[1], + generator = generator, + guidance_scale = guidance_scale, + control_image = control_image, + num_inference_steps = num_inference_steps, + control_context_scale = control_context_scale, + ).images + +if lora_path is not None: + pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype) + +def save_results(): + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + + index = len([path for path in os.listdir(save_path)]) + 1 + prefix = str(index).zfill(8) + video_path = os.path.join(save_path, prefix + ".png") + image = sample[0] + image.save(video_path) + +if ulysses_degree * ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_results() +else: + save_results() \ 
No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..fec320bb3832954b9f5c3067d5d6ecd1e06a543a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,7 @@ +gradio +torch +transformers +accelerate +spaces +git+https://github.com/huggingface/diffusers.git +kernels \ No newline at end of file diff --git a/static/data.json b/static/data.json new file mode 100644 index 0000000000000000000000000000000000000000..cfe685696e61de79c70b54e951a9d87d80e3481a --- /dev/null +++ b/static/data.json @@ -0,0 +1,8 @@ +{ + "examples": [ + ["examples/hed.jpg", "A middle-aged man with a short beard, wearing a casual button-down shirt, sitting at a polished dark wooden table, holding a tumbler of whiskey with ice and taking a thoughtful sip. The background is a softly lit."], + ["examples/depth.jpg", "Modern minimalist, clean lines, open plan, natural light, spacious, serene, contemporary, elegant, architectural, inviting, sophisticated, light-filled, harmonious, texture, shadows, high ceilings."], + ["examples/pose.jpg", "A fit, athletic young woman, squatting low, glancing confidently at the camera. She's on a picturesque tropical beach with gentle waves lapping the shore. The image has the crisp, high-contrast look of a fashion magazine cover. Dynamic pose, bright and inviting."], + ["examples/pose2.jpg", "A majestic female paladin in gleaming plate armor, standing tall and proud, bathed in a celestial glow, with a determined expression, holding a radiant sword aloft against a backdrop of a sun-drenched, ancient castle."] + ] +} diff --git a/static/footer.html b/static/footer.html new file mode 100644 index 0000000000000000000000000000000000000000..1b4982ac4fa2871921122e706eb9364760ac29ef --- /dev/null +++ b/static/footer.html @@ -0,0 +1,16 @@ +
+ I made this Space after seeing a Reddit post about ControlNet editing with Alibaba's Z-Image. + The code looks solid and serves as a great example, and I believe there is a lot of potential to build on top of it, add new features, and explore even more creative ideas with this technique. + +

Usage

+ Adjust control_context_scale to set how strongly the control image guides the result. + For best results, use a detailed prompt. + The recommended control_context_scale range is 0.65 to 0.80; a short usage sketch follows further below. +

Reference

+ +
\ No newline at end of file diff --git a/static/header.html b/static/header.html new file mode 100644 index 0000000000000000000000000000000000000000..8d8698d34ddea8bf7dc0ee1a84d6ad39d8c96859 --- /dev/null +++ b/static/header.html @@ -0,0 +1,11 @@ +
+

+ Z Image Turbo (ZIT) - Controlnet +

+
+

+ Supports multiple control conditions, including Canny, HED, Depth, Pose, and MLSD. +
+ If you like my spaces, please support me by visiting AiSudo for more image generation 😊 +

+
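For reference, here is a condensed sketch of how control_context_scale reaches the pipeline, pieced together from the inference call in app.py and the helpers in image_utils.py. It is not part of the diff: it assumes the pipe object (ZImageControlPipeline) is the one app.py builds at import time, and it reuses a bundled example control map and prompt.

```python
# Hedged usage sketch: reuses `pipe` built in app.py and the helpers from
# image_utils.py in this repo; only the control-strength knob is highlighted.
import torch
from PIL import Image

from app import pipe                                  # app.py constructs the pipeline at import time (heavy)
from image_utils import get_image_latent, scale_image

control = Image.open("examples/pose.jpg")             # any Canny/HED/Depth/Pose/MLSD map
control, width, height = scale_image(control, 1.0)    # snaps both sides to multiples of 32
control_latent = get_image_latent(control, sample_size=[height, width])[:, :, 0]

image = pipe(
    prompt="A fit, athletic young woman, squatting low, glancing confidently at the camera.",
    height=height,
    width=width,
    generator=torch.Generator().manual_seed(42),
    guidance_scale=1.5,                                # Turbo checkpoint: low CFG, few steps
    control_image=control_latent,
    num_inference_steps=8,
    control_context_scale=0.75,                        # recommended range: 0.65 to 0.80
).images[0]
image.save("result.png")
```

As with other ControlNet-style conditioning scales, lower control_context_scale values loosen the match to the control map and higher values follow it more strictly; the footer above recommends staying within 0.65 to 0.80.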
diff --git a/videox_fun/__init__.py b/videox_fun/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/videox_fun/api/api.py b/videox_fun/api/api.py new file mode 100755 index 0000000000000000000000000000000000000000..a3c0238a7ce3e8c87e61fd3c4e062b956a456a1f --- /dev/null +++ b/videox_fun/api/api.py @@ -0,0 +1,226 @@ +import base64 +import gc +import hashlib +import io +import os +import tempfile +from io import BytesIO + +import gradio as gr +import requests +import torch +from fastapi import FastAPI +from PIL import Image + + +# Function to encode a file to Base64 +def encode_file_to_base64(file_path): + with open(file_path, "rb") as file: + # Encode the data to Base64 + file_base64 = base64.b64encode(file.read()) + return file_base64 + +def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller): + @app.post("/videox_fun/update_diffusion_transformer") + def _update_diffusion_transformer_api( + datas: dict, + ): + diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none') + + try: + controller.update_diffusion_transformer( + diffusion_transformer_path + ) + comment = "Success" + except Exception as e: + torch.cuda.empty_cache() + comment = f"Error. error information is {str(e)}" + + return {"message": comment} + +def download_from_url(url, timeout=10): + try: + response = requests.get(url, timeout=timeout) + response.raise_for_status() # 检查请求是否成功 + return response.content + except requests.exceptions.RequestException as e: + print(f"Error downloading from {url}: {e}") + return None + +def save_base64_video(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.mp4" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + + return file_path + +def save_base64_image(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.jpg" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + + return file_path + +def save_url_video(url): + video_data = download_from_url(url) + if video_data: + return save_base64_video(base64.b64encode(video_data)) + return None + +def save_url_image(url): + image_data = download_from_url(url) + if image_data: + return save_base64_image(base64.b64encode(image_data)) + return None + +def infer_forward_api(_: gr.Blocks, app: FastAPI, controller): + @app.post("/videox_fun/infer_forward") + def _infer_forward_api( + datas: dict, + ): + base_model_path = datas.get('base_model_path', 'none') + base_model_2_path = datas.get('base_model_2_path', 'none') + lora_model_path = datas.get('lora_model_path', 'none') + lora_model_2_path = datas.get('lora_model_2_path', 'none') + lora_alpha_slider = datas.get('lora_alpha_slider', 0.55) + prompt_textbox = datas.get('prompt_textbox', None) + negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. 
') + sampler_dropdown = datas.get('sampler_dropdown', 'Euler') + sample_step_slider = datas.get('sample_step_slider', 30) + resize_method = datas.get('resize_method', "Generate by") + width_slider = datas.get('width_slider', 672) + height_slider = datas.get('height_slider', 384) + base_resolution = datas.get('base_resolution', 512) + is_image = datas.get('is_image', False) + generation_method = datas.get('generation_method', False) + length_slider = datas.get('length_slider', 49) + overlap_video_length = datas.get('overlap_video_length', 4) + partial_video_length = datas.get('partial_video_length', 72) + cfg_scale_slider = datas.get('cfg_scale_slider', 6) + start_image = datas.get('start_image', None) + end_image = datas.get('end_image', None) + validation_video = datas.get('validation_video', None) + validation_video_mask = datas.get('validation_video_mask', None) + control_video = datas.get('control_video', None) + denoise_strength = datas.get('denoise_strength', 0.70) + seed_textbox = datas.get("seed_textbox", 43) + + ref_image = datas.get('ref_image', None) + enable_teacache = datas.get('enable_teacache', True) + teacache_threshold = datas.get('teacache_threshold', 0.10) + num_skip_start_steps = datas.get('num_skip_start_steps', 1) + teacache_offload = datas.get('teacache_offload', False) + cfg_skip_ratio = datas.get('cfg_skip_ratio', 0) + enable_riflex = datas.get('enable_riflex', False) + riflex_k = datas.get('riflex_k', 6) + fps = datas.get('fps', None) + + generation_method = "Image Generation" if is_image else generation_method + + if start_image is not None: + if start_image.startswith('http'): + start_image = save_url_image(start_image) + start_image = [Image.open(start_image).convert("RGB")] + else: + start_image = base64.b64decode(start_image) + start_image = [Image.open(BytesIO(start_image)).convert("RGB")] + + if end_image is not None: + if end_image.startswith('http'): + end_image = save_url_image(end_image) + end_image = [Image.open(end_image).convert("RGB")] + else: + end_image = base64.b64decode(end_image) + end_image = [Image.open(BytesIO(end_image)).convert("RGB")] + + if validation_video is not None: + if validation_video.startswith('http'): + validation_video = save_url_video(validation_video) + else: + validation_video = save_base64_video(validation_video) + + if validation_video_mask is not None: + if validation_video_mask.startswith('http'): + validation_video_mask = save_url_image(validation_video_mask) + else: + validation_video_mask = save_base64_image(validation_video_mask) + + if control_video is not None: + if control_video.startswith('http'): + control_video = save_url_video(control_video) + else: + control_video = save_base64_video(control_video) + + if ref_image is not None: + if ref_image.startswith('http'): + ref_image = save_url_image(ref_image) + ref_image = [Image.open(ref_image).convert("RGB")] + else: + ref_image = base64.b64decode(ref_image) + ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")] + + try: + save_sample_path, comment = controller.generate( + "", + base_model_path, + lora_model_path, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = ref_image, + enable_teacache 
= enable_teacache, + teacache_threshold = teacache_threshold, + num_skip_start_steps = num_skip_start_steps, + teacache_offload = teacache_offload, + cfg_skip_ratio = cfg_skip_ratio, + enable_riflex = enable_riflex, + riflex_k = riflex_k, + base_model_2_dropdown = base_model_2_path, + lora_model_2_dropdown = lora_model_2_path, + fps = fps, + is_api = True, + ) + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_sample_path = "" + comment = f"Error. error information is {str(e)}" + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + + if save_sample_path != "": + return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)} + else: + return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None} \ No newline at end of file diff --git a/videox_fun/api/api_multi_nodes.py b/videox_fun/api/api_multi_nodes.py new file mode 100755 index 0000000000000000000000000000000000000000..ebd98d9f8435127f216da441baf29b1b7c916689 --- /dev/null +++ b/videox_fun/api/api_multi_nodes.py @@ -0,0 +1,320 @@ +# This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py +import base64 +import gc +import hashlib +import io +import os +import tempfile +from io import BytesIO + +import gradio as gr +import requests +import torch +import torch.distributed as dist +from fastapi import FastAPI, HTTPException +from PIL import Image + +from .api import download_from_url, encode_file_to_base64 + +try: + import ray +except: + print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.") + ray = None + +def save_base64_video_dist(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.mp4" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + if dist.is_initialized(): + if dist.get_rank() == 0: + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + dist.barrier() + else: + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + return file_path + +def save_base64_image_dist(base64_string): + video_data = base64.b64decode(base64_string) + + md5_hash = hashlib.md5(video_data).hexdigest() + filename = f"{md5_hash}.jpg" + + temp_dir = tempfile.gettempdir() + file_path = os.path.join(temp_dir, filename) + + if dist.is_initialized(): + if dist.get_rank() == 0: + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + dist.barrier() + else: + with open(file_path, 'wb') as video_file: + video_file.write(video_data) + return file_path + +def save_url_video_dist(url): + video_data = download_from_url(url) + if video_data: + return save_base64_video_dist(base64.b64encode(video_data)) + return None + +def save_url_image_dist(url): + image_data = download_from_url(url) + if image_data: + return save_base64_image_dist(base64.b64encode(image_data)) + return None + +if ray is not None: + @ray.remote(num_gpus=1) + class MultiNodesGenerator: + def __init__( + self, rank: int, world_size: int, Controller, + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=None, ulysses_degree=1, ring_degree=1, + fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False, + weight_dtype=None, savedir_sample=None, + ): + # Set PyTorch distributed environment variables + os.environ["RANK"] = str(rank) 
+ os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + + self.rank = rank + self.controller = Controller( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path, + ulysses_degree=ulysses_degree, ring_degree=ring_degree, + fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + def generate(self, datas): + try: + base_model_path = datas.get('base_model_path', 'none') + base_model_2_path = datas.get('base_model_2_path', 'none') + lora_model_path = datas.get('lora_model_path', 'none') + lora_model_2_path = datas.get('lora_model_2_path', 'none') + lora_alpha_slider = datas.get('lora_alpha_slider', 0.55) + prompt_textbox = datas.get('prompt_textbox', None) + negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ') + sampler_dropdown = datas.get('sampler_dropdown', 'Euler') + sample_step_slider = datas.get('sample_step_slider', 30) + resize_method = datas.get('resize_method', "Generate by") + width_slider = datas.get('width_slider', 672) + height_slider = datas.get('height_slider', 384) + base_resolution = datas.get('base_resolution', 512) + is_image = datas.get('is_image', False) + generation_method = datas.get('generation_method', False) + length_slider = datas.get('length_slider', 49) + overlap_video_length = datas.get('overlap_video_length', 4) + partial_video_length = datas.get('partial_video_length', 72) + cfg_scale_slider = datas.get('cfg_scale_slider', 6) + start_image = datas.get('start_image', None) + end_image = datas.get('end_image', None) + validation_video = datas.get('validation_video', None) + validation_video_mask = datas.get('validation_video_mask', None) + control_video = datas.get('control_video', None) + denoise_strength = datas.get('denoise_strength', 0.70) + seed_textbox = datas.get("seed_textbox", 43) + + ref_image = datas.get('ref_image', None) + enable_teacache = datas.get('enable_teacache', True) + teacache_threshold = datas.get('teacache_threshold', 0.10) + num_skip_start_steps = datas.get('num_skip_start_steps', 1) + teacache_offload = datas.get('teacache_offload', False) + cfg_skip_ratio = datas.get('cfg_skip_ratio', 0) + enable_riflex = datas.get('enable_riflex', False) + riflex_k = datas.get('riflex_k', 6) + fps = datas.get('fps', None) + + generation_method = "Image Generation" if is_image else generation_method + + if start_image is not None: + if start_image.startswith('http'): + start_image = save_url_image_dist(start_image) + start_image = [Image.open(start_image).convert("RGB")] + else: + start_image = base64.b64decode(start_image) + start_image = [Image.open(BytesIO(start_image)).convert("RGB")] + + if end_image is not None: + if end_image.startswith('http'): + end_image = save_url_image_dist(end_image) + end_image = [Image.open(end_image).convert("RGB")] + else: + end_image = base64.b64decode(end_image) + end_image = [Image.open(BytesIO(end_image)).convert("RGB")] + + if validation_video is not None: + if validation_video.startswith('http'): + validation_video = save_url_video_dist(validation_video) + else: + validation_video = save_base64_video_dist(validation_video) + + if validation_video_mask is not None: + if validation_video_mask.startswith('http'): + 
validation_video_mask = save_url_image_dist(validation_video_mask) + else: + validation_video_mask = save_base64_image_dist(validation_video_mask) + + if control_video is not None: + if control_video.startswith('http'): + control_video = save_url_video_dist(control_video) + else: + control_video = save_base64_video_dist(control_video) + + if ref_image is not None: + if ref_image.startswith('http'): + ref_image = save_url_image_dist(ref_image) + ref_image = [Image.open(ref_image).convert("RGB")] + else: + ref_image = base64.b64decode(ref_image) + ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")] + + try: + save_sample_path, comment = self.controller.generate( + "", + base_model_path, + lora_model_path, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = ref_image, + enable_teacache = enable_teacache, + teacache_threshold = teacache_threshold, + num_skip_start_steps = num_skip_start_steps, + teacache_offload = teacache_offload, + cfg_skip_ratio = cfg_skip_ratio, + enable_riflex = enable_riflex, + riflex_k = riflex_k, + base_model_2_dropdown = base_model_2_path, + lora_model_2_dropdown = lora_model_2_path, + fps = fps, + is_api = True, + ) + except Exception as e: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + save_sample_path = "" + comment = f"Error. error information is {str(e)}" + if dist.is_initialized(): + if dist.get_rank() == 0: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + else: + return None + else: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + + + if dist.is_initialized(): + if dist.get_rank() == 0: + if save_sample_path != "": + return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)} + else: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + else: + return None + else: + if save_sample_path != "": + return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)} + else: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + + except Exception as e: + print(f"Error generating: {str(e)}") + comment = f"Error generating: {str(e)}" + if dist.is_initialized(): + if dist.get_rank() == 0: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + else: + return None + else: + return {"message": comment, "save_sample_path": None, "base64_encoding": None} + + class MultiNodesEngine: + def __init__( + self, + world_size, + Controller, + GPU_memory_mode, + scheduler_dict, + model_name, + model_type, + config_path, + ulysses_degree=1, + ring_degree=1, + fsdp_dit=False, + fsdp_text_encoder=False, + compile_dit=False, + weight_dtype=torch.bfloat16, + savedir_sample="samples" + ): + # Ensure Ray is initialized + if not ray.is_initialized(): + ray.init() + + num_workers = world_size + self.workers = [ + MultiNodesGenerator.remote( + rank, world_size, Controller, + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path, + ulysses_degree=ulysses_degree, 
ring_degree=ring_degree, + fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + for rank in range(num_workers) + ] + print("Update workers done") + + async def generate(self, data): + results = ray.get([ + worker.generate.remote(data) + for worker in self.workers + ]) + + return next(path for path in results if path is not None) + + def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine): + + @app.post("/videox_fun/infer_forward") + async def _multi_nodes_infer_forward_api( + datas: dict, + ): + try: + result = await engine.generate(datas) + return result + except Exception as e: + if isinstance(e, HTTPException): + raise e + raise HTTPException(status_code=500, detail=str(e)) +else: + MultiNodesEngine = None + MultiNodesGenerator = None + multi_nodes_infer_forward_api = None \ No newline at end of file diff --git a/videox_fun/data/__init__.py b/videox_fun/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..babf155e18d36fafd2d4f41183392a58ab87dba9 --- /dev/null +++ b/videox_fun/data/__init__.py @@ -0,0 +1,9 @@ +from .dataset_image import CC15M, ImageEditDataset +from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, TextDataset, + ImageVideoSampler) +from .dataset_video import VideoDataset, VideoSpeechDataset, VideoAnimateDataset, WebVid10M +from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager, + custom_meshgrid, get_random_mask, get_relative_pose, + get_video_reader_batch, padding_image, process_pose_file, + process_pose_params, ray_condition, resize_frame, + resize_image_with_target_area) diff --git a/videox_fun/data/bucket_sampler.py b/videox_fun/data/bucket_sampler.py new file mode 100755 index 0000000000000000000000000000000000000000..24b4160f3d2bbadca1d23e90c23e887ea6d15f70 --- /dev/null +++ b/videox_fun/data/bucket_sampler.py @@ -0,0 +1,379 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os +from typing import (Generic, Iterable, Iterator, List, Optional, Sequence, + Sized, TypeVar, Union) + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils.data import BatchSampler, Dataset, Sampler + +ASPECT_RATIO_512 = { + '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0], + '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0], + '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0], + '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0], + '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0], + '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0], + '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0], + '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0], + '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0], + '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0] +} +ASPECT_RATIO_RANDOM_CROP_512 = { + '0.42': [320.0, 768.0], '0.5': [352.0, 704.0], + '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0], + '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], + '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0], + '2.0': [704.0, 352.0], '2.4': [768.0, 320.0] +} +ASPECT_RATIO_RANDOM_CROP_PROB = [ + 1, 2, + 4, 4, 4, 4, + 8, 8, 8, + 4, 4, 4, 4, + 2, 1 +] +ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB) + +def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512): + aspect_ratio = height / width + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio)) + return ratios[closest_ratio], float(closest_ratio) + +def get_image_size_without_loading(path): + with Image.open(path) as img: + return img.size # (width, height) + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Dataset): dataset to sample from + replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` + num_samples (int): number of samples to draw, default=`len(dataset)`. + generator (Generator): Generator used in sampling. 
+ """ + + data_source: Sized + replacement: bool + + def __init__(self, data_source: Sized, replacement: bool = False, + num_samples: Optional[int] = None, generator=None) -> None: + self.data_source = data_source + self.replacement = replacement + self._num_samples = num_samples + self.generator = generator + self._pos_start = 0 + + if not isinstance(self.replacement, bool): + raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}") + + if not isinstance(self.num_samples, int) or self.num_samples <= 0: + raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}") + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + if self.replacement: + for _ in range(self.num_samples // 32): + yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist() + yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist() + else: + for _ in range(self.num_samples // n): + xx = torch.randperm(n, generator=generator).tolist() + if self._pos_start >= n: + self._pos_start = 0 + print("xx top 10", xx[:10], self._pos_start) + for idx in range(self._pos_start, n): + yield xx[idx] + self._pos_start = (self._pos_start + 1) % n + self._pos_start = 0 + yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n] + + def __len__(self) -> int: + return self.num_samples + +class AspectRatioBatchImageSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. 
+ """ + def __init__( + self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + train_folder: str = None, + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False, + config=None, + **kwargs + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.train_folder = train_folder + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + self.config = config + # buckets for each aspect ratio + self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} + # [str(k) for k, v in aspect_ratios] + self.current_available_bucket_keys = list(aspect_ratios.keys()) + + def __iter__(self): + for idx in self.sampler: + try: + image_dict = self.dataset[idx] + + width, height = image_dict.get("width", None), image_dict.get("height", None) + if width is None or height is None: + image_id, name = image_dict['file_path'], image_dict['text'] + if self.train_folder is None: + image_dir = image_id + else: + image_dir = os.path.join(self.train_folder, image_id) + + width, height = get_image_size_without_loading(image_dir) + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e) + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self._aspect_ratio_buckets[closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + +class AspectRatioBatchSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. 
+ """ + def __init__( + self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + video_folder: str = None, + train_data_format: str = "webvid", + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False, + config=None, + **kwargs + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.video_folder = video_folder + self.train_data_format = train_data_format + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + self.config = config + # buckets for each aspect ratio + self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios} + # [str(k) for k, v in aspect_ratios] + self.current_available_bucket_keys = list(aspect_ratios.keys()) + + def __iter__(self): + for idx in self.sampler: + try: + video_dict = self.dataset[idx] + width, more = video_dict.get("width", None), video_dict.get("height", None) + + if width is None or height is None: + if self.train_data_format == "normal": + video_id, name = video_dict['file_path'], video_dict['text'] + if self.video_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.video_folder, video_id) + else: + videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] + video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") + cap = cv2.VideoCapture(video_dir) + + # 获取视频尺寸 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数 + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数 + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self._aspect_ratio_buckets[closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + +class AspectRatioBatchImageVideoSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. 
+ """ + + def __init__(self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + train_folder: str = None, + aspect_ratios: dict = ASPECT_RATIO_512, + drop_last: bool = False + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.train_folder = train_folder + self.batch_size = batch_size + self.aspect_ratios = aspect_ratios + self.drop_last = drop_last + + # buckets for each aspect ratio + self.current_available_bucket_keys = list(aspect_ratios.keys()) + self.bucket = { + 'image':{ratio: [] for ratio in aspect_ratios}, + 'video':{ratio: [] for ratio in aspect_ratios} + } + + def __iter__(self): + for idx in self.sampler: + content_type = self.dataset[idx].get('type', 'image') + if content_type == 'image': + try: + image_dict = self.dataset[idx] + + width, height = image_dict.get("width", None), image_dict.get("height", None) + if width is None or height is None: + image_id, name = image_dict['file_path'], image_dict['text'] + if self.train_folder is None: + image_dir = image_id + else: + image_dir = os.path.join(self.train_folder, image_id) + + width, height = get_image_size_without_loading(image_dir) + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self.bucket['image'][closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] + else: + try: + video_dict = self.dataset[idx] + width, height = video_dict.get("width", None), video_dict.get("height", None) + + if width is None or height is None: + video_id, name = video_dict['file_path'], video_dict['text'] + if self.train_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.train_folder, video_id) + cap = cv2.VideoCapture(video_dir) + + # 获取视频尺寸 + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数 + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数 + + ratio = height / width # self.dataset[idx] + else: + height = int(height) + width = int(width) + ratio = height / width # self.dataset[idx] + except Exception as e: + print(e, self.dataset[idx], "This item is error, please check it.") + continue + # find the closest aspect ratio + closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio)) + if closest_ratio not in self.current_available_bucket_keys: + continue + bucket = self.bucket['video'][closest_ratio] + bucket.append(idx) + # yield a batch of indices in the same aspect ratio group + if len(bucket) == self.batch_size: + yield bucket[:] + del bucket[:] \ No newline at end of file diff --git a/videox_fun/data/dataset_image.py b/videox_fun/data/dataset_image.py new file mode 100644 index 0000000000000000000000000000000000000000..18c672a9ca71a7f57f211a3aedd1f6a499b5e487 --- /dev/null +++ b/videox_fun/data/dataset_image.py @@ -0,0 +1,191 @@ +import json +import 
os +import random + +import numpy as np +import torch +import torchvision.transforms as transforms +from PIL import Image +from torch.utils.data.dataset import Dataset + + +class CC15M(Dataset): + def __init__( + self, + json_path, + video_folder=None, + resolution=512, + enable_bucket=False, + ): + print(f"loading annotations from {json_path} ...") + self.dataset = json.load(open(json_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.enable_bucket = enable_bucket + self.video_folder = video_folder + + resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution) + self.pixel_transforms = transforms.Compose([ + transforms.Resize(resolution[0]), + transforms.CenterCrop(resolution), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, name = video_dict['file_path'], video_dict['text'] + + if self.video_folder is None: + video_dir = video_id + else: + video_dir = os.path.join(self.video_folder, video_id) + + pixel_values = Image.open(video_dir).convert("RGB") + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + except Exception as e: + print(e) + idx = random.randint(0, self.length-1) + + if not self.enable_bucket: + pixel_values = self.pixel_transforms(pixel_values) + else: + pixel_values = np.array(pixel_values) + + sample = dict(pixel_values=pixel_values, text=name) + return sample + +class ImageEditDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + image_sample_size=512, + text_drop_ratio=0.1, + enable_bucket=False, + enable_inpaint=False, + return_file_name=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + self.dataset = dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + self.return_file_name = return_file_name + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + + if not self.enable_bucket: + raise ValueError("Not enable_bucket is not supported now. 
") + else: + image = np.expand_dims(np.array(image), 0) + + source_image_path = data_info.get('source_file_path', []) + source_image = [] + if isinstance(source_image_path, list): + for _source_image_path in source_image_path: + if self.data_root is not None: + _source_image_path = os.path.join(self.data_root, _source_image_path) + _source_image = Image.open(_source_image_path).convert('RGB') + source_image.append(_source_image) + else: + if self.data_root is not None: + _source_image_path = os.path.join(self.data_root, source_image_path) + _source_image = Image.open(_source_image_path).convert('RGB') + source_image.append(_source_image) + + if not self.enable_bucket: + raise ValueError("Not enable_bucket is not supported now. ") + else: + source_image = [np.array(_source_image) for _source_image in source_image] + + if random.random() < self.text_drop_ratio: + text = '' + return image, source_image, text, 'image', image_path + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, source_pixel_values, name, data_type, file_path = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["source_pixel_values"] = source_pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + if self.return_file_name: + sample["file_name"] = os.path.basename(file_path) + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + +if __name__ == "__main__": + dataset = CC15M( + csv_path="./cc15m_add_index.json", + resolution=512, + ) + + dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,) + for idx, batch in enumerate(dataloader): + print(batch["pixel_values"].shape, len(batch["text"])) \ No newline at end of file diff --git a/videox_fun/data/dataset_image_video.py b/videox_fun/data/dataset_image_video.py new file mode 100755 index 0000000000000000000000000000000000000000..449a2f7b1df4a4c70e236658a93e435f6c1ff9d0 --- /dev/null +++ b/videox_fun/data/dataset_image_video.py @@ -0,0 +1,657 @@ +import csv +import gc +import io +import json +import math +import os +import random +from contextlib import contextmanager +from random import shuffle +from threading import Thread + +import albumentations +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms +from decord import VideoReader +from einops import rearrange +from func_timeout import FunctionTimedOut, func_timeout +from packaging import version as pver +from PIL import Image +from safetensors.torch import load_file +from torch.utils.data import BatchSampler, Sampler +from torch.utils.data.dataset 
import Dataset + +from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager, + custom_meshgrid, get_random_mask, get_relative_pose, + get_video_reader_batch, padding_image, process_pose_file, + process_pose_params, ray_condition, resize_frame, + resize_image_with_target_area) + + +class ImageVideoSampler(BatchSampler): + """A sampler wrapper for grouping images with similar aspect ratio into a same batch. + + Args: + sampler (Sampler): Base sampler. + dataset (Dataset): Dataset providing data information. + batch_size (int): Size of mini-batch. + drop_last (bool): If ``True``, the sampler will drop the last batch if + its size would be less than ``batch_size``. + aspect_ratios (dict): The predefined aspect ratios. + """ + + def __init__(self, + sampler: Sampler, + dataset: Dataset, + batch_size: int, + drop_last: bool = False + ) -> None: + if not isinstance(sampler, Sampler): + raise TypeError('sampler should be an instance of ``Sampler``, ' + f'but got {sampler}') + if not isinstance(batch_size, int) or batch_size <= 0: + raise ValueError('batch_size should be a positive integer value, ' + f'but got batch_size={batch_size}') + self.sampler = sampler + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + + # buckets for each aspect ratio + self.bucket = {'image':[], 'video':[]} + + def __iter__(self): + for idx in self.sampler: + content_type = self.dataset.dataset[idx].get('type', 'image') + self.bucket[content_type].append(idx) + + # yield a batch of indices in the same aspect ratio group + if len(self.bucket['video']) == self.batch_size: + bucket = self.bucket['video'] + yield bucket[:] + del bucket[:] + elif len(self.bucket['image']) == self.batch_size: + bucket = self.bucket['image'] + yield bucket[:] + del bucket[:] + + +class ImageVideoDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + image_sample_size=512, + video_repeat=0, + text_drop_ratio=0.1, + enable_bucket=False, + video_length_drop_start=0.0, + video_length_drop_end=1.0, + enable_inpaint=False, + return_file_name=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. 
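+ # When video_repeat > 0, every video entry is appended video_repeat times while image entries are kept once, which rebalances the image/video mix of the annotation file.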
+ if video_repeat > 0: + self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + else: + self.dataset = dataset + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + self.return_file_name = return_file_name + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + + if data_info.get('type', 'image')=='video': + video_id, text = data_info['file_path'], data_info['text'] + + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
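+ # Decoded frames are uint8 THWC numpy arrays; they were converted to float (T, C, H, W) tensors in [0, 1] above, and video_transforms below resizes, center-crops and normalizes them to [-1, 1].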
+ del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + return pixel_values, text, 'video', video_dir + else: + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + if not self.enable_bucket: + image = self.image_transforms(image).unsqueeze(0) + else: + image = np.expand_dims(np.array(image), 0) + if random.random() < self.text_drop_ratio: + text = '' + return image, text, 'image', image_path + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, name, data_type, file_path = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + if self.return_file_name: + sample["file_name"] = os.path.basename(file_path) + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class ImageVideoControlDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + image_sample_size=512, + video_repeat=0, + text_drop_ratio=0.1, + enable_bucket=False, + video_length_drop_start=0.1, + video_length_drop_end=0.9, + enable_inpaint=False, + enable_camera_info=False, + return_file_name=False, + enable_subject_info=False, + padding_subject_info=True, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. 
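+ # Same balancing rule as ImageVideoDataset above: with video_repeat > 0 each video entry is appended video_repeat times and image entries are kept once.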
+ if video_repeat > 0: + self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + else: + self.dataset = dataset + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + self.enable_inpaint = enable_inpaint + self.enable_camera_info = enable_camera_info + self.enable_subject_info = enable_subject_info + self.padding_subject_info = padding_subject_info + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + if self.enable_camera_info: + self.video_transforms_camera = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size) + ] + ) + + # Image params + self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size) + self.image_transforms = transforms.Compose([ + transforms.Resize(min(self.image_sample_size)), + transforms.CenterCrop(self.image_sample_size), + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5]) + ]) + + self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size)) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + video_id, text = data_info['file_path'], data_info['text'] + + if data_info.get('type', 'image')=='video': + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract 
frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + + control_video_id = data_info['control_file_path'] + + if control_video_id is not None: + if self.data_root is None: + control_video_id = control_video_id + else: + control_video_id = os.path.join(self.data_root, control_video_id) + + if self.enable_camera_info: + if control_video_id.lower().endswith('.txt'): + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + + control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0]) + control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous() + control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True) + control_camera_values = self.video_transforms_camera(control_camera_values) + else: + control_pixel_values = np.zeros_like(pixel_values) + + control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True) + control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0) + control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0] + control_camera_values = np.array([control_camera_values[index] for index in batch_index]) + else: + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + control_camera_values = None + else: + control_pixel_values = np.zeros_like(pixel_values) + control_camera_values = None + else: + if control_video_id is not None: + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, batch_index) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + control_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. 
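+ # The control clip was read with the same batch_index as the main clip, so the two streams stay frame-aligned.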
+ del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + else: + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + else: + control_pixel_values = np.zeros_like(pixel_values) + control_camera_values = None + + if self.enable_subject_info: + if not self.enable_bucket: + visual_height, visual_width = pixel_values.shape[-2:] + else: + visual_height, visual_width = pixel_values.shape[1:3] + + subject_id = data_info.get('object_file_path', []) + shuffle(subject_id) + subject_images = [] + for i in range(min(len(subject_id), 4)): + subject_image = Image.open(subject_id[i]) + width, height = subject_image.size + total_pixels = width * height + + if self.padding_subject_info: + img = padding_image(subject_image, visual_width, visual_height) + else: + img = resize_image_with_target_area(subject_image, 1024 * 1024) + + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + subject_images.append(np.array(img)) + if self.padding_subject_info: + subject_image = np.array(subject_images) + else: + subject_image = subject_images + else: + subject_image = None + + return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video" + else: + image_path, text = data_info['file_path'], data_info['text'] + if self.data_root is not None: + image_path = os.path.join(self.data_root, image_path) + image = Image.open(image_path).convert('RGB') + if not self.enable_bucket: + image = self.image_transforms(image).unsqueeze(0) + else: + image = np.expand_dims(np.array(image), 0) + + if random.random() < self.text_drop_ratio: + text = '' + + control_image_id = data_info['control_file_path'] + + if self.data_root is None: + control_image_id = control_image_id + else: + control_image_id = os.path.join(self.data_root, control_image_id) + + control_image = Image.open(control_image_id).convert('RGB') + if not self.enable_bucket: + control_image = self.image_transforms(control_image).unsqueeze(0) + else: + control_image = np.expand_dims(np.array(control_image), 0) + + if self.enable_subject_info: + if not self.enable_bucket: + visual_height, visual_width = image.shape[-2:] + else: + visual_height, visual_width = image.shape[1:3] + + subject_id = data_info.get('object_file_path', []) + shuffle(subject_id) + subject_images = [] + for i in range(min(len(subject_id), 4)): + subject_image = Image.open(subject_id[i]).convert('RGB') + width, height = subject_image.size + total_pixels = width * height + + if self.padding_subject_info: + img = padding_image(subject_image, visual_width, visual_height) + else: + img = resize_image_with_target_area(subject_image, 1024 * 1024) + + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + subject_images.append(np.array(img)) + if self.padding_subject_info: + subject_image = np.array(subject_images) + else: + subject_image = subject_images + else: + subject_image = None + + return image, control_image, subject_image, None, text, 'image' + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, 
control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx) + + sample["pixel_values"] = pixel_values + sample["control_pixel_values"] = control_pixel_values + sample["subject_image"] = subject_image + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + + if self.enable_camera_info: + sample["control_camera_values"] = control_camera_values + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class ImageVideoSafetensorsDataset(Dataset): + def __init__( + self, + ann_path, + data_root=None, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + if ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + self.dataset = dataset + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + def __len__(self): + return self.length + + def __getitem__(self, idx): + if self.data_root is None: + path = self.dataset[idx]["file_path"] + else: + path = os.path.join(self.data_root, self.dataset[idx]["file_path"]) + state_dict = load_file(path) + return state_dict + + +class TextDataset(Dataset): + def __init__(self, ann_path, text_drop_ratio=0.0): + print(f"loading annotations from {ann_path} ...") + with open(ann_path, 'r') as f: + self.dataset = json.load(f) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + self.text_drop_ratio = text_drop_ratio + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + item = self.dataset[idx] + text = item['text'] + + # Randomly drop text (for classifier-free guidance) + if random.random() < self.text_drop_ratio: + text = '' + + sample = { + "text": text, + "idx": idx + } + return sample + + except Exception as e: + print(f"Error at index {idx}: {e}, retrying with random index...") + idx = np.random.randint(0, self.length - 1) \ No newline at end of file diff --git a/videox_fun/data/dataset_video.py b/videox_fun/data/dataset_video.py new file mode 100644 index 0000000000000000000000000000000000000000..230528eaab912d3151ee5bbf663f1de544c99543 --- /dev/null +++ b/videox_fun/data/dataset_video.py @@ -0,0 +1,901 @@ +import csv +import gc +import io +import json +import math +import os +import random +from contextlib import contextmanager +from threading import Thread + +import albumentations +import cv2 +import librosa +import numpy as np +import torch +import torchvision.transforms as transforms +from decord import VideoReader +from einops import rearrange +from func_timeout import FunctionTimedOut, func_timeout +from PIL import Image +from torch.utils.data import BatchSampler, Sampler +from torch.utils.data.dataset import Dataset + +from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager, + custom_meshgrid, get_random_mask, get_relative_pose, + get_video_reader_batch, padding_image, process_pose_file, + process_pose_params, ray_condition, resize_frame, + 
resize_image_with_target_area) + + +class WebVid10M(Dataset): + def __init__( + self, + csv_path, video_folder, + sample_size=256, sample_stride=4, sample_n_frames=16, + enable_bucket=False, enable_inpaint=False, is_image=False, + ): + print(f"loading annotations from {csv_path} ...") + with open(csv_path, 'r') as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.video_folder = video_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + self.is_image = is_image + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose([ + transforms.Resize(sample_size[0]), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] + + video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") + video_reader = VideoReader(video_dir) + video_length = len(video_reader) + + if not self.is_image: + clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + batch_index = [random.randint(0, video_length - 1)] + + if not self.enable_bucket: + pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
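+ # With enable_bucket the raw uint8 frames are returned as-is and no resize/normalization is applied inside the dataset.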
+ del video_reader + else: + pixel_values = video_reader.get_batch(batch_index).asnumpy() + + if self.is_image: + pixel_values = pixel_values[0] + return pixel_values, name + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + try: + pixel_values, name = self.get_batch(idx) + break + + except Exception as e: + print("Error info:", e) + idx = random.randint(0, self.length-1) + + if not self.enable_bucket: + pixel_values = self.pixel_transforms(pixel_values) + if self.enable_inpaint: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask + sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name) + else: + sample = dict(pixel_values=pixel_values, text=name) + return sample + + +class VideoDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + sample_size=256, sample_stride=4, sample_n_frames=16, + enable_bucket=False, enable_inpaint=False + ): + print(f"loading annotations from {ann_path} ...") + self.dataset = json.load(open(ann_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.data_root = data_root + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose( + [ + transforms.Resize(sample_size[0]), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, text = video_dict['file_path'], video_dict['text'] + + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
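+ # batch_index above selects min_sample_n_frames frames at roughly video_sample_stride spacing, starting at a random offset inside the [video_length_drop_start, video_length_drop_end] window of the video.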
+ del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Randomly drop the text for classifier-free guidance + if random.random() < self.text_drop_ratio: + text = '' + return pixel_values, text + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + sample = {} + try: + pixel_values, name = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["text"] = name + sample["idx"] = idx + if len(sample) > 0: + break + + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size()) + mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class VideoSpeechDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + enable_bucket=False, enable_inpaint=False, + audio_sr=16000, # target sampling rate of the audio + text_drop_ratio=0.1 # probability of dropping the text prompt + ): + print(f"loading annotations from {ann_path} ...") + self.dataset = json.load(open(ann_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.data_root = data_root + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + self.audio_sr = audio_sr + self.text_drop_ratio = text_drop_ratio + + video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.pixel_transforms = transforms.Compose( + [ + transforms.Resize(video_sample_size[0]), + transforms.CenterCrop(video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, text = video_dict['file_path'], video_dict['text'] + audio_id = video_dict['audio_path'] + + if self.data_root is None: + video_path = video_id + else: + video_path = os.path.join(self.data_root, video_id) + + if self.data_root is None: + audio_path = audio_id + else: + audio_path = os.path.join(self.data_root, audio_id) + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found for {video_path}") + + with VideoReader_contextmanager(video_path, num_threads=2) as video_reader: + total_frames = len(video_reader) + fps = video_reader.get_avg_fps() # original video frame rate + + # Compute how many frames can actually be sampled (respecting the video length) + max_possible_frames = (total_frames - 1) // self.video_sample_stride + 1 + actual_n_frames = min(self.video_sample_n_frames, max_possible_frames) + if actual_n_frames <= 0: + raise ValueError(f"Video too short: {video_path}") + + # Randomly choose the start frame + max_start = total_frames - (actual_n_frames - 1) * self.video_sample_stride - 1 + start_frame = random.randint(0, max_start) if max_start > 0 else 0 + frame_indices = [start_frame + i * self.video_sample_stride for i in range(actual_n_frames)] + + # Read the video frames + try: + sample_args = (video_reader, frame_indices) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) 
+ except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + # Video post-processing + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. + pixel_values = self.pixel_transforms(pixel_values) + + # === Load and crop the matching audio segment === + # Start and end time of the video clip in seconds + start_time = start_frame / fps + end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps + duration = end_time - start_time + + # Load the whole audio with librosa (or only the needed part; librosa.load cannot seek precisely, so load first and slice afterwards) + audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) # resample to the target sr + + # Convert times to sample indices + start_sample = int(start_time * self.audio_sr) + end_sample = int(end_time * self.audio_sr) + + # Slice safely + if start_sample >= len(audio_input): + # Audio is too short; fall back to zeros of the expected length + audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32) + else: + audio_segment = audio_input[start_sample:end_sample] + # Pad with zeros if the segment is too short + target_len = int(duration * self.audio_sr) + if len(audio_segment) < target_len: + audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant') + + # === Random text drop === + if random.random() < self.text_drop_ratio: + text = '' + + return pixel_values, text, audio_segment, sample_rate + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + sample = {} + try: + pixel_values, text, audio, sample_rate = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["text"] = text + sample["audio"] = torch.from_numpy(audio).float() # convert to a tensor + sample["sample_rate"] = sample_rate + sample["idx"] = idx + break + except Exception as e: + print(f"Error processing {idx}: {e}, retrying with random idx...") + idx = random.randint(0, self.length - 1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size(), image_start_only=True) + mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class VideoSpeechControlDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16, + enable_bucket=False, enable_inpaint=False, + audio_sr=16000, + text_drop_ratio=0.1, + enable_motion_info=False, + motion_frames=73, + ): + print(f"loading annotations from {ann_path} ...") + self.dataset = json.load(open(ann_path, 'r')) + self.length = len(self.dataset) + print(f"data scale: {self.length}") + + self.data_root = data_root + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.enable_bucket = enable_bucket + self.enable_inpaint = enable_inpaint + self.audio_sr = audio_sr + self.text_drop_ratio = text_drop_ratio + self.enable_motion_info = enable_motion_info + self.motion_frames = motion_frames + + video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.pixel_transforms = transforms.Compose( + [ + transforms.Resize(video_sample_size[0]), + transforms.CenterCrop(video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], 
inplace=True), + ] + ) + + self.video_sample_size = video_sample_size + + def get_batch(self, idx): + video_dict = self.dataset[idx] + video_id, text = video_dict['file_path'], video_dict['text'] + audio_id = video_dict['audio_path'] + control_video_id = video_dict['control_file_path'] + + if self.data_root is None: + video_path = video_id + else: + video_path = os.path.join(self.data_root, video_id) + + if self.data_root is None: + audio_path = audio_id + else: + audio_path = os.path.join(self.data_root, audio_id) + + if self.data_root is None: + control_video_id = control_video_id + else: + control_video_id = os.path.join(self.data_root, control_video_id) + + if not os.path.exists(audio_path): + raise FileNotFoundError(f"Audio file not found for {video_path}") + + # Video information + with VideoReader_contextmanager(video_path, num_threads=2) as video_reader: + total_frames = len(video_reader) + fps = video_reader.get_avg_fps() + if fps <= 0: + raise ValueError(f"Video has negative fps: {video_path}") + local_video_sample_stride = self.video_sample_stride + new_fps = int(fps // local_video_sample_stride) + while new_fps > 30: + local_video_sample_stride = local_video_sample_stride + 1 + new_fps = int(fps // local_video_sample_stride) + + max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1 + actual_n_frames = min(self.video_sample_n_frames, max_possible_frames) + if actual_n_frames <= 0: + raise ValueError(f"Video too short: {video_path}") + + max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1 + start_frame = random.randint(0, max_start) if max_start > 0 else 0 + frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)] + + try: + sample_args = (video_reader, frame_indices) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + _, height, width, channel = np.shape(pixel_values) + if self.enable_motion_info: + motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5 + if start_frame > 0: + motion_max_possible_frames = (start_frame - 1) // local_video_sample_stride + 1 + motion_frame_indices = [0 + i * local_video_sample_stride for i in range(motion_max_possible_frames)] + motion_frame_indices = motion_frame_indices[-self.motion_frames:] + + _motion_sample_args = (video_reader, motion_frame_indices) + _motion_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=_motion_sample_args + ) + motion_pixel_values[-len(motion_frame_indices):] = _motion_pixel_values + + if not self.enable_bucket: + motion_pixel_values = torch.from_numpy(motion_pixel_values).permute(0, 3, 1, 2).contiguous() + motion_pixel_values = motion_pixel_values / 255. + motion_pixel_values = self.pixel_transforms(motion_pixel_values) + else: + motion_pixel_values = None + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
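+ # local_video_sample_stride was increased above until the effective frame rate (fps // stride) is at most 30; this new_fps is returned with the clip so downstream code knows the real sampling rate.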
+ pixel_values = self.pixel_transforms(pixel_values) + + # Audio information + start_time = start_frame / fps + end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps + duration = end_time - start_time + + audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) + start_sample = int(start_time * self.audio_sr) + end_sample = int(end_time * self.audio_sr) + + if start_sample >= len(audio_input): + raise ValueError(f"Audio file too short: {audio_path}") + else: + audio_segment = audio_input[start_sample:end_sample] + target_len = int(duration * self.audio_sr) + if len(audio_segment) < target_len: + raise ValueError(f"Audio file too short: {audio_path}") + + # Control information + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, frame_indices) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, max(self.video_sample_size)) + resized_frames.append(resized_frame) + control_pixel_values = np.array(control_pixel_values) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + + if random.random() < self.text_drop_ratio: + text = '' + + return pixel_values, motion_pixel_values, control_pixel_values, text, audio_segment, sample_rate, new_fps + + def __len__(self): + return self.length + + def __getitem__(self, idx): + while True: + sample = {} + try: + pixel_values, motion_pixel_values, control_pixel_values, text, audio, sample_rate, new_fps = self.get_batch(idx) + sample["pixel_values"] = pixel_values + sample["motion_pixel_values"] = motion_pixel_values + sample["control_pixel_values"] = control_pixel_values + sample["text"] = text + sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor + sample["sample_rate"] = sample_rate + sample["fps"] = new_fps + sample["idx"] = idx + break + except Exception as e: + print(f"Error processing {idx}: {e}, retrying with random idx...") + idx = random.randint(0, self.length - 1) + + if self.enable_inpaint and not self.enable_bucket: + mask = get_random_mask(pixel_values.size(), image_start_only=True) + mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask + sample["mask_pixel_values"] = mask_pixel_values + sample["mask"] = mask + + clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous() + clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255 + sample["clip_pixel_values"] = clip_pixel_values + + return sample + + +class VideoAnimateDataset(Dataset): + def __init__( + self, + ann_path, data_root=None, + video_sample_size=512, + video_sample_stride=4, + video_sample_n_frames=16, + video_repeat=0, + text_drop_ratio=0.1, + enable_bucket=False, + video_length_drop_start=0.1, + video_length_drop_end=0.9, + return_file_name=False, + ): + # Loading annotations from files + print(f"loading annotations from {ann_path} ...") + 
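# Annotations may be a CSV (parsed with csv.DictReader) or a JSON list; each entry is a dict with 'file_path' and 'text' plus the control/face/background/mask/ref paths used in get_batch below. + 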
if ann_path.endswith('.csv'): + with open(ann_path, 'r') as csvfile: + dataset = list(csv.DictReader(csvfile)) + elif ann_path.endswith('.json'): + dataset = json.load(open(ann_path)) + + self.data_root = data_root + + # It's used to balance num of images and videos. + if video_repeat > 0: + self.dataset = [] + for data in dataset: + if data.get('type', 'image') != 'video': + self.dataset.append(data) + + for _ in range(video_repeat): + for data in dataset: + if data.get('type', 'image') == 'video': + self.dataset.append(data) + else: + self.dataset = dataset + del dataset + + self.length = len(self.dataset) + print(f"data scale: {self.length}") + # TODO: enable bucket training + self.enable_bucket = enable_bucket + self.text_drop_ratio = text_drop_ratio + + self.video_length_drop_start = video_length_drop_start + self.video_length_drop_end = video_length_drop_end + + # Video params + self.video_sample_stride = video_sample_stride + self.video_sample_n_frames = video_sample_n_frames + self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size) + self.video_transforms = transforms.Compose( + [ + transforms.Resize(min(self.video_sample_size)), + transforms.CenterCrop(self.video_sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ] + ) + + self.larger_side_of_image_and_video = min(self.video_sample_size) + + def get_batch(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + video_id, text = data_info['file_path'], data_info['text'] + + if self.data_root is None: + video_dir = video_id + else: + video_dir = os.path.join(self.data_root, video_id) + + with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader: + min_sample_n_frames = min( + self.video_sample_n_frames, + int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride) + ) + if min_sample_n_frames == 0: + raise ValueError(f"No Frames in video.") + + video_length = int(self.video_length_drop_end * len(video_reader)) + clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1) + start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0 + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int) + + try: + sample_args = (video_reader, batch_index) + pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(pixel_values)): + frame = pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
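+ # The same batch_index is reused below for the control, face, background and mask videos, so every auxiliary stream stays frame-aligned with the main clip.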
+ del video_reader + else: + pixel_values = pixel_values + + if not self.enable_bucket: + pixel_values = self.video_transforms(pixel_values) + + # Random use no text generation + if random.random() < self.text_drop_ratio: + text = '' + + control_video_id = data_info['control_file_path'] + + if control_video_id is not None: + if self.data_root is None: + control_video_id = control_video_id + else: + control_video_id = os.path.join(self.data_root, control_video_id) + + if control_video_id is not None: + with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader: + try: + sample_args = (control_video_reader, batch_index) + control_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(control_pixel_values)): + frame = control_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + control_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous() + control_pixel_values = control_pixel_values / 255. + del control_video_reader + else: + control_pixel_values = control_pixel_values + + if not self.enable_bucket: + control_pixel_values = self.video_transforms(control_pixel_values) + else: + if not self.enable_bucket: + control_pixel_values = torch.zeros_like(pixel_values) + else: + control_pixel_values = np.zeros_like(pixel_values) + + face_video_id = data_info['face_file_path'] + + if face_video_id is not None: + if self.data_root is None: + face_video_id = face_video_id + else: + face_video_id = os.path.join(self.data_root, face_video_id) + + if face_video_id is not None: + with VideoReader_contextmanager(face_video_id, num_threads=2) as face_video_reader: + try: + sample_args = (face_video_reader, batch_index) + face_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(face_pixel_values)): + frame = face_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + face_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + face_pixel_values = torch.from_numpy(face_pixel_values).permute(0, 3, 1, 2).contiguous() + face_pixel_values = face_pixel_values / 255. 
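+ # When no face video is provided, the else branch below substitutes an all-zero tensor/array with the same shape as pixel_values.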
+ del face_video_reader + else: + face_pixel_values = face_pixel_values + + if not self.enable_bucket: + face_pixel_values = self.video_transforms(face_pixel_values) + else: + if not self.enable_bucket: + face_pixel_values = torch.zeros_like(pixel_values) + else: + face_pixel_values = np.zeros_like(pixel_values) + + background_video_id = data_info.get('background_file_path', None) + + if background_video_id is not None: + if self.data_root is None: + background_video_id = background_video_id + else: + background_video_id = os.path.join(self.data_root, background_video_id) + + if background_video_id is not None: + with VideoReader_contextmanager(background_video_id, num_threads=2) as background_video_reader: + try: + sample_args = (background_video_reader, batch_index) + background_pixel_values = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(background_pixel_values)): + frame = background_pixel_values[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + background_pixel_values = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + background_pixel_values = torch.from_numpy(background_pixel_values).permute(0, 3, 1, 2).contiguous() + background_pixel_values = background_pixel_values / 255. + del background_video_reader + else: + background_pixel_values = background_pixel_values + + if not self.enable_bucket: + background_pixel_values = self.video_transforms(background_pixel_values) + else: + if not self.enable_bucket: + background_pixel_values = torch.ones_like(pixel_values) * 127.5 + else: + background_pixel_values = np.ones_like(pixel_values) * 127.5 + + mask_video_id = data_info.get('mask_file_path', None) + + if mask_video_id is not None: + if self.data_root is None: + mask_video_id = mask_video_id + else: + mask_video_id = os.path.join(self.data_root, mask_video_id) + + if mask_video_id is not None: + with VideoReader_contextmanager(mask_video_id, num_threads=2) as mask_video_reader: + try: + sample_args = (mask_video_reader, batch_index) + mask = func_timeout( + VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args + ) + resized_frames = [] + for i in range(len(mask)): + frame = mask[i] + resized_frame = resize_frame(frame, self.larger_side_of_image_and_video) + resized_frames.append(resized_frame) + mask = np.array(resized_frames) + except FunctionTimedOut: + raise ValueError(f"Read {idx} timeout.") + except Exception as e: + raise ValueError(f"Failed to extract frames from video. Error is {e}.") + + if not self.enable_bucket: + mask = torch.from_numpy(mask).permute(0, 3, 1, 2).contiguous() + mask = mask / 255. + del mask_video_reader + else: + mask = mask + else: + if not self.enable_bucket: + mask = torch.ones_like(pixel_values) + else: + mask = np.ones_like(pixel_values) * 255 + mask = mask[:, :, :, :1] + + ref_pixel_values_path = data_info.get('ref_file_path', []) + if self.data_root is not None: + ref_pixel_values_path = os.path.join(self.data_root, ref_pixel_values_path) + ref_pixel_values = Image.open(ref_pixel_values_path).convert('RGB') + + if not self.enable_bucket: + raise ValueError("Not enable_bucket is not supported now. 
") + else: + ref_pixel_values = np.array(ref_pixel_values) + + return pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, text, "video" + + def __len__(self): + return self.length + + def __getitem__(self, idx): + data_info = self.dataset[idx % len(self.dataset)] + data_type = data_info.get('type', 'image') + while True: + sample = {} + try: + data_info_local = self.dataset[idx % len(self.dataset)] + data_type_local = data_info_local.get('type', 'image') + if data_type_local != data_type: + raise ValueError("data_type_local != data_type") + + pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, name, data_type = \ + self.get_batch(idx) + + sample["pixel_values"] = pixel_values + sample["control_pixel_values"] = control_pixel_values + sample["face_pixel_values"] = face_pixel_values + sample["background_pixel_values"] = background_pixel_values + sample["mask"] = mask + sample["ref_pixel_values"] = ref_pixel_values + sample["clip_pixel_values"] = ref_pixel_values + sample["text"] = name + sample["data_type"] = data_type + sample["idx"] = idx + + if len(sample) > 0: + break + except Exception as e: + print(e, self.dataset[idx % len(self.dataset)]) + idx = random.randint(0, self.length-1) + + return sample + + +if __name__ == "__main__": + if 1: + dataset = VideoDataset( + json_path="./webvidval/results_2M_val.json", + sample_size=256, + sample_stride=4, sample_n_frames=16, + ) + + if 0: + dataset = WebVid10M( + csv_path="./webvid/results_2M_val.csv", + video_folder="./webvid/2M_val", + sample_size=256, + sample_stride=4, sample_n_frames=16, + is_image=False, + ) + + dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,) + for idx, batch in enumerate(dataloader): + print(batch["pixel_values"].shape, len(batch["text"])) \ No newline at end of file diff --git a/videox_fun/data/utils.py b/videox_fun/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..514a41bae567fbf2eb4d80f7163fefceaf9e8974 --- /dev/null +++ b/videox_fun/data/utils.py @@ -0,0 +1,347 @@ +import csv +import gc +import io +import json +import math +import os +import random +from contextlib import contextmanager +from random import shuffle +from threading import Thread + +import albumentations +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms as transforms +from decord import VideoReader +from einops import rearrange +from func_timeout import FunctionTimedOut, func_timeout +from packaging import version as pver +from PIL import Image +from safetensors.torch import load_file +from torch.utils.data import BatchSampler, Sampler +from torch.utils.data.dataset import Dataset + +VIDEO_READER_TIMEOUT = 20 + +def get_random_mask(shape, image_start_only=False): + f, c, h, w = shape + mask = torch.zeros((f, 1, h, w), dtype=torch.uint8) + + if not image_start_only: + if f != 1: + mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05]) + else: + mask_index = np.random.choice([0, 1, 7, 8], p = [0.2, 0.7, 0.05, 0.05]) + if mask_index == 0: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围 + block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围 + + start_x = max(center_x - block_size_x // 2, 0) + end_x = min(center_x + block_size_x // 2, 
w) + start_y = max(center_y - block_size_y // 2, 0) + end_y = min(center_y + block_size_y // 2, h) + mask[:, :, start_y:end_y, start_x:end_x] = 1 + elif mask_index == 1: + mask[:, :, :, :] = 1 + elif mask_index == 2: + mask_frame_index = np.random.randint(1, 5) + mask[mask_frame_index:, :, :, :] = 1 + elif mask_index == 3: + mask_frame_index = np.random.randint(1, 5) + mask[mask_frame_index:-mask_frame_index, :, :, :] = 1 + elif mask_index == 4: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # 方块的宽度范围 + block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # 方块的高度范围 + + start_x = max(center_x - block_size_x // 2, 0) + end_x = min(center_x + block_size_x // 2, w) + start_y = max(center_y - block_size_y // 2, 0) + end_y = min(center_y + block_size_y // 2, h) + + mask_frame_before = np.random.randint(0, f // 2) + mask_frame_after = np.random.randint(f // 2, f) + mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1 + elif mask_index == 5: + mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8) + elif mask_index == 6: + num_frames_to_mask = random.randint(1, max(f // 2, 1)) + frames_to_mask = random.sample(range(f), num_frames_to_mask) + + for i in frames_to_mask: + block_height = random.randint(1, h // 4) + block_width = random.randint(1, w // 4) + top_left_y = random.randint(0, h - block_height) + top_left_x = random.randint(0, w - block_width) + mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1 + elif mask_index == 7: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # 长半轴 + b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # 短半轴 + + for i in range(h): + for j in range(w): + if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1: + mask[:, :, i, j] = 1 + elif mask_index == 8: + center_x = torch.randint(0, w, (1,)).item() + center_y = torch.randint(0, h, (1,)).item() + radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() + for i in range(h): + for j in range(w): + if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2: + mask[:, :, i, j] = 1 + elif mask_index == 9: + for idx in range(f): + if np.random.rand() > 0.5: + mask[idx, :, :, :] = 1 + else: + raise ValueError(f"The mask_index {mask_index} is not define") + else: + if f != 1: + mask[1:, :, :, :] = 1 + else: + mask[:, :, :, :] = 1 + return mask + +@contextmanager +def VideoReader_contextmanager(*args, **kwargs): + vr = VideoReader(*args, **kwargs) + try: + yield vr + finally: + del vr + gc.collect() + +def get_video_reader_batch(video_reader, batch_index): + frames = video_reader.get_batch(batch_index).asnumpy() + return frames + +def resize_frame(frame, target_short_side): + h, w, _ = frame.shape + if h < w: + if target_short_side > h: + return frame + new_h = target_short_side + new_w = int(target_short_side * w / h) + else: + if target_short_side > w: + return frame + new_w = target_short_side + new_h = int(target_short_side * h / w) + + resized_frame = cv2.resize(frame, (new_w, new_h)) + return resized_frame + +def padding_image(images, new_width, new_height): + new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255)) + + aspect_ratio = images.width / images.height + if new_width / new_height > 1: + if aspect_ratio > new_width / new_height: + new_img_width = new_width + 
new_img_height = int(new_img_width / aspect_ratio) + else: + new_img_height = new_height + new_img_width = int(new_img_height * aspect_ratio) + else: + if aspect_ratio > new_width / new_height: + new_img_width = new_width + new_img_height = int(new_img_width / aspect_ratio) + else: + new_img_height = new_height + new_img_width = int(new_img_height * aspect_ratio) + + resized_img = images.resize((new_img_width, new_img_height)) + + paste_x = (new_width - new_img_width) // 2 + paste_y = (new_height - new_img_height) // 2 + + new_image.paste(resized_img, (paste_x, paste_y)) + + return new_image + +def resize_image_with_target_area(img: Image.Image, target_area: int = 1024 * 1024) -> Image.Image: + """ + 将 PIL 图像缩放到接近指定像素面积(target_area),保持原始宽高比, + 并确保新宽度和高度均为 32 的整数倍。 + + 参数: + img (PIL.Image.Image): 输入图像 + target_area (int): 目标像素总面积,例如 1024*1024 = 1048576 + + 返回: + PIL.Image.Image: Resize 后的图像 + """ + orig_w, orig_h = img.size + if orig_w == 0 or orig_h == 0: + raise ValueError("Input image has zero width or height.") + + ratio = orig_w / orig_h + ideal_width = math.sqrt(target_area * ratio) + ideal_height = ideal_width / ratio + + new_width = round(ideal_width / 32) * 32 + new_height = round(ideal_height / 32) * 32 + + new_width = max(32, new_width) + new_height = max(32, new_height) + + new_width = int(new_width) + new_height = int(new_height) + + resized_img = img.resize((new_width, new_height), Image.LANCZOS) + return resized_img + +class Camera(object): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + def __init__(self, entry): + fx, fy, cx, cy = entry[1:5] + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + w2c_mat = np.array(entry[7:]).reshape(3, 4) + w2c_mat_4x4 = np.eye(4) + w2c_mat_4x4[:3, :] = w2c_mat + self.w2c_mat = w2c_mat_4x4 + self.c2w_mat = np.linalg.inv(w2c_mat_4x4) + +def custom_meshgrid(*args): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid + if pver.parse(torch.__version__) < pver.parse('1.10'): + return torch.meshgrid(*args) + else: + return torch.meshgrid(*args, indexing='ij') + +def get_relative_pose(cam_params): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params] + abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params] + cam_to_origin = 0 + target_cam_c2w = np.array([ + [1, 0, 0, 0], + [0, 1, 0, -cam_to_origin], + [0, 0, 1, 0], + [0, 0, 0, 1] + ]) + abs2rel = target_cam_c2w @ abs_w2cs[0] + ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]] + ret_poses = np.array(ret_poses, dtype=np.float32) + return ret_poses + +def ray_condition(K, c2w, H, W, device): + """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + # c2w: B, V, 4, 4 + # K: B, V, 4 + + B = K.shape[0] + + j, i = custom_meshgrid( + torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype), + torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype), + ) + i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW] + + fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1 + + zs = torch.ones_like(i) # [B, HxW] + xs = (i - cx) / fx * zs + ys = (j - cy) / fy * zs + zs = zs.expand_as(ys) + + directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3 + directions = directions / directions.norm(dim=-1, 
keepdim=True) # B, V, HW, 3 + + rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW + rays_o = c2w[..., :3, 3] # B, V, 3 + rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW + # c2w @ dirctions + rays_dxo = torch.cross(rays_o, rays_d) + plucker = torch.cat([rays_dxo, rays_d], dim=-1) + plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6 + # plucker = plucker.permute(0, 1, 4, 2, 3) + return plucker + +def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False): + """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + with open(pose_file_path, 'r') as f: + poses = f.readlines() + + poses = [pose.strip().split(' ') for pose in poses[1:]] + cam_params = [[float(x) for x in pose] for pose in poses] + if return_poses: + return cam_params + else: + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding + +def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'): + """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py + """ + cam_params = [Camera(cam_param) for cam_param in cam_params] + + sample_wh_ratio = width / height + pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed + + if pose_wh_ratio > sample_wh_ratio: + resized_ori_w = height * pose_wh_ratio + for cam_param in cam_params: + cam_param.fx = resized_ori_w * cam_param.fx / width + else: + resized_ori_h = width / pose_wh_ratio + for cam_param in cam_params: + cam_param.fy = resized_ori_h * cam_param.fy / height + + intrinsic = np.asarray([[cam_param.fx * width, + cam_param.fy * height, + cam_param.cx * width, + cam_param.cy * height] + for cam_param in cam_params], dtype=np.float32) + + K = torch.as_tensor(intrinsic)[None] # [1, 1, 4] + c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere + c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4] + plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W + plucker_embedding = plucker_embedding[None] + plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0] + return plucker_embedding \ No newline at end of file diff --git a/videox_fun/pipeline/__init__.py b/videox_fun/pipeline/__init__.py new 
file mode 100755 index 0000000000000000000000000000000000000000..7409cfc1a1fca422330a84e5d7e2cf89c919cc40 --- /dev/null +++ b/videox_fun/pipeline/__init__.py @@ -0,0 +1,62 @@ +# from .pipeline_cogvideox_fun import CogVideoXFunPipeline +# from .pipeline_cogvideox_fun_control import CogVideoXFunControlPipeline +# from .pipeline_cogvideox_fun_inpaint import CogVideoXFunInpaintPipeline +# from .pipeline_fantasy_talking import FantasyTalkingPipeline +# from .pipeline_flux import FluxPipeline +# from .pipeline_flux2 import Flux2Pipeline +# from .pipeline_flux2_control import Flux2ControlPipeline +# from .pipeline_hunyuanvideo import HunyuanVideoPipeline +# from .pipeline_hunyuanvideo_i2v import HunyuanVideoI2VPipeline +# from .pipeline_qwenimage import QwenImagePipeline +# from .pipeline_qwenimage_edit import QwenImageEditPipeline +# from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline +# from .pipeline_wan import WanPipeline +# from .pipeline_wan2_2 import Wan2_2Pipeline +# from .pipeline_wan2_2_animate import Wan2_2AnimatePipeline +# from .pipeline_wan2_2_fun_control import Wan2_2FunControlPipeline +# from .pipeline_wan2_2_fun_inpaint import Wan2_2FunInpaintPipeline +# from .pipeline_wan2_2_s2v import Wan2_2S2VPipeline +# from .pipeline_wan2_2_ti2v import Wan2_2TI2VPipeline +# from .pipeline_wan2_2_vace_fun import Wan2_2VaceFunPipeline +# from .pipeline_wan_fun_control import WanFunControlPipeline +# from .pipeline_wan_fun_inpaint import WanFunInpaintPipeline +# from .pipeline_wan_phantom import WanFunPhantomPipeline +# from .pipeline_wan_vace import WanVacePipeline +from .pipeline_z_image import ZImagePipeline +from .pipeline_z_image_control import ZImageControlPipeline + +# WanFunPipeline = WanPipeline +# WanI2VPipeline = WanFunInpaintPipeline + +# Wan2_2FunPipeline = Wan2_2Pipeline +# Wan2_2I2VPipeline = Wan2_2FunInpaintPipeline + +# import importlib.util + +# if importlib.util.find_spec("paifuser") is not None: +# # --------------------------------------------------------------- # +# # Sparse Attention +# # --------------------------------------------------------------- # +# from paifuser.ops import sparse_reset + +# # Wan2.1 +# WanFunInpaintPipeline.__call__ = sparse_reset(WanFunInpaintPipeline.__call__) +# WanFunPipeline.__call__ = sparse_reset(WanFunPipeline.__call__) +# WanFunControlPipeline.__call__ = sparse_reset(WanFunControlPipeline.__call__) +# WanI2VPipeline.__call__ = sparse_reset(WanI2VPipeline.__call__) +# WanPipeline.__call__ = sparse_reset(WanPipeline.__call__) +# WanVacePipeline.__call__ = sparse_reset(WanVacePipeline.__call__) + +# # Phantom +# WanFunPhantomPipeline.__call__ = sparse_reset(WanFunPhantomPipeline.__call__) + +# # Wan2.2 +# Wan2_2FunInpaintPipeline.__call__ = sparse_reset(Wan2_2FunInpaintPipeline.__call__) +# Wan2_2FunPipeline.__call__ = sparse_reset(Wan2_2FunPipeline.__call__) +# Wan2_2FunControlPipeline.__call__ = sparse_reset(Wan2_2FunControlPipeline.__call__) +# Wan2_2Pipeline.__call__ = sparse_reset(Wan2_2Pipeline.__call__) +# Wan2_2I2VPipeline.__call__ = sparse_reset(Wan2_2I2VPipeline.__call__) +# Wan2_2TI2VPipeline.__call__ = sparse_reset(Wan2_2TI2VPipeline.__call__) +# Wan2_2S2VPipeline.__call__ = sparse_reset(Wan2_2S2VPipeline.__call__) +# Wan2_2VaceFunPipeline.__call__ = sparse_reset(Wan2_2VaceFunPipeline.__call__) +# Wan2_2AnimatePipeline.__call__ = sparse_reset(Wan2_2AnimatePipeline.__call__) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun.py b/videox_fun/pipeline/pipeline_cogvideox_fun.py new 
file mode 100644 index 0000000000000000000000000000000000000000..68568a6069ddc30ad4b11bb833b6307b5de3ceb6 --- /dev/null +++ b/videox_fun/pipeline/pipeline_cogvideox_fun.py @@ -0,0 +1,862 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLCogVideoX, + CogVideoXTransformer3DModel, T5EncoderModel, + T5Tokenizer) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed +def get_3d_rotary_pos_embed( + embed_dim, + crops_coords, + grid_size, + temporal_size, + theta: int = 10000, + use_real: bool = True, + grid_type: str = "linspace", + max_size: Optional[Tuple[int, int]] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + grid_type (`str`): + Whether to use "linspace" or "slice" to compute grids. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. 
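For orientation, a minimal call sketch for the RoPE helper documented here; the grid and head-dim values are assumptions (roughly a 480x720 video with 8x VAE spatial compression, patch size 2, 13 latent frames and 64-dim attention heads), not values taken from this repository:

```python
# Hedged sketch: illustrative grid/head-dim values only.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,                      # attention head dim (assumed)
    crops_coords=((0, 0), (30, 45)),   # full grid, no crop
    grid_size=(30, 45),                # 480/(8*2) x 720/(8*2) latent patches
    temporal_size=13,                  # (49 frames - 1) / 4 + 1
)
# cos and sin carry one row per video token: 13 * 30 * 45 = 17550 positions.
print(cos.shape[0], sin.shape[0])
```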
+ """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + + if grid_type == "linspace": + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + elif grid_type == "slice": + max_h, max_w = max_size + grid_size_h, grid_size_w = grid_size + grid_h = np.arange(max_h, dtype=np.float32) + grid_w = np.arange(max_w, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + else: + raise ValueError("Invalid value passed for `grid_type`.") + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + + if grid_type == "slice": + t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size] + h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h] + w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w] + + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: 
Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class CogVideoXFunPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class CogVideoXFunPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX_Fun. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
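As a usage note for the `retrieve_timesteps` helper defined above, a minimal sketch; the scheduler is built from its default config here purely for illustration, whereas a real pipeline loads it from the checkpoint:

```python
from diffusers.schedulers import CogVideoXDDIMScheduler

scheduler = CogVideoXDDIMScheduler()  # assumed default config, illustration only

# Standard path: let the scheduler build its own 50-step schedule.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cpu")
print(num_steps, timesteps[:3])

# Passing `timesteps=[...]` or `sigmas=[...]` instead only works when the scheduler's
# `set_timesteps` accepts those keywords; otherwise this helper raises a ValueError.
```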
+ text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: 
Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
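The shape tuple assembled just below in `prepare_latents` follows the VAE's compression factors; a worked example with values assumed for a CogVideoX-style setup (8x spatial compression, 4x temporal compression, 16 latent channels):

```python
batch_size, num_frames, height, width = 1, 49, 480, 720
vae_scale_factor_spatial, vae_scale_factor_temporal, latent_channels = 8, 4, 16

latent_shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,  # 13 latent frames
    latent_channels,
    height // vae_scale_factor_spatial,                 # 60
    width // vae_scale_factor_spatial,                  # 90
)
print(latent_shape)  # (1, 13, 16, 60, 90)
```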
+ ) + + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
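`prepare_extra_step_kwargs` above probes `scheduler.step` with `inspect.signature` so that `eta` and `generator` are only forwarded when the scheduler supports them; a standalone sketch of that pattern (the helper and the stand-in `step` below are illustrative, not part of the pipeline):

```python
import inspect

def filter_supported_kwargs(fn, **kwargs):
    # Keep only the keyword arguments that `fn` actually declares.
    accepted = set(inspect.signature(fn).parameters.keys())
    return {k: v for k, v in kwargs.items() if k in accepted}

def step(sample, eta=0.0):  # stand-in for a scheduler.step that accepts `eta`
    return sample, eta

print(filter_supported_kwargs(step, eta=0.5, generator=None))  # {'eta': 0.5}
```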
+ ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool 
= False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + ) -> Union[CogVideoXFunPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
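Pulling the documented parameters together, a hedged invocation sketch; it assumes `pipe` is an already-constructed `CogVideoXFunPipeline` with its components loaded, and the prompt and seed are placeholders:

```python
import torch

output = pipe(
    prompt="A panda strumming a guitar in a bamboo forest",
    height=480,
    width=720,
    num_frames=49,                 # 12 * vae_scale_factor_temporal (4) + 1
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator().manual_seed(42),
)
video = output.videos  # decoded frames, as assembled at the end of __call__
```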
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 5. Prepare latents + latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + + # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if num_frames != 1 and patch_size_t is not None and latent_frames % patch_size_t != 0: + additional_frames = patch_size_t - latent_frames % patch_size_t + num_frames += additional_frames * self.vae_scale_factor_temporal + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + 
callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun_control.py b/videox_fun/pipeline/pipeline_cogvideox_fun_control.py new file mode 100644 index 0000000000000000000000000000000000000000..e91df20dab9553896394a42af3e698dfea33f333 --- /dev/null +++ b/videox_fun/pipeline/pipeline_cogvideox_fun_control.py @@ -0,0 +1,956 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import (get_1d_rotary_pos_embed, + get_3d_rotary_pos_embed) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange + +from ..models import (AutoencoderKLCogVideoX, + CogVideoXTransformer3DModel, T5EncoderModel, + T5Tokenizer) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed +def get_3d_rotary_pos_embed( + embed_dim, + crops_coords, + grid_size, + temporal_size, + theta: int = 10000, + use_real: bool = True, + grid_type: str = "linspace", + max_size: Optional[Tuple[int, int]] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. 
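The `callback_on_step_end` hook wired up at the top of this hunk hands the callback the tensors named in `callback_on_step_end_tensor_inputs` and expects a dict of (possibly modified) tensors back; an illustrative callback with assumed names:

```python
def log_latents(pipe, step_index, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    print(f"step {step_index:03d}  t={int(timestep)}  latent std={latents.std().item():.4f}")
    return {"latents": latents}  # returned tensors are popped back into the denoising loop

# e.g. pipe(..., callback_on_step_end=log_latents,
#           callback_on_step_end_tensor_inputs=["latents"])
```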
+ crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + grid_type (`str`): + Whether to use "linspace" or "slice" to compute grids. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. + """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + + if grid_type == "linspace": + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + elif grid_type == "slice": + max_h, max_w = max_size + grid_size_h, grid_size_w = grid_size + grid_h = np.arange(max_h, dtype=np.float32) + grid_w = np.arange(max_w, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + else: + raise ValueError("Invalid value passed for `grid_type`.") + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + + if grid_type == "slice": + t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size] + h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h] + w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w] + + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = 
int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class CogVideoXFunPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
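A quick worked example for `get_resize_crop_region_for_grid` above: `src` is the latent-patch grid of the current input, the target is the base grid size, and the returned pair of corners is the centered region later passed as `crops_coords` to the rotary embedding (the numbers below are illustrative):

```python
src_grid = (60, 90)              # (height, width) of the input's latent-patch grid
base_width, base_height = 48, 48

(top, left), (bottom, right) = get_resize_crop_region_for_grid(src_grid, base_width, base_height)
print((top, left), (bottom, right))  # (8, 0) (40, 48)
```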
+ """ + + videos: torch.Tensor + + +class CogVideoXFunControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX_Fun uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" 
{max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + shape = ( + batch_size, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
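+        # For example, `DDIMScheduler.step` accepts both `eta` and `generator`, so both end up in the returned
+        # dict; schedulers whose `step()` accepts neither argument get an empty dict back.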
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. 
Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + comfyui_progressbar: bool = False, + ) -> Union[CogVideoXFunPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_frames (`int`, defaults to `48`): + Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will + contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where + num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that + needs to be satisfied is that of divisibility mentioned above. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. 
Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, defaults to `226`): + Maximum sequence length in encoded prompt. Must be consistent with + `self.transformer.config.max_text_seq_length` otherwise may lead to poor results. + + Examples: + + Returns: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`: + [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + num_frames = num_frames or self.transformer.config.sample_frames + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + + # 4. 
Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + else: + control_video = None + + # Magvae needs the number of frames to be 4n + 1. + local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and local_latent_length % patch_size_t != 0: + additional_frames = local_latent_length % patch_size_t + num_frames -= additional_frames * self.vae_scale_factor_temporal + if num_frames <= 0: + num_frames = 1 + if video_length > num_frames: + logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ") + video_length = num_frames + control_video = control_video[:, :, :video_length] + + # 5. Prepare latents. + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance + )[1] + control_video_latents_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ) + control_latents = rearrange(control_video_latents_input, "b c f h w -> b f c h w") + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + control_latents=control_latents, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py b/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..7044d9d034079b18d862c9d80df631204e30569a --- /dev/null +++ b/videox_fun/pipeline/pipeline_cogvideox_fun_inpaint.py @@ -0,0 +1,1136 @@ +# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange + +from ..models import (AutoencoderKLCogVideoX, + CogVideoXTransformer3DModel, T5EncoderModel, + T5Tokenizer) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + +# Copied from diffusers.models.embeddings.get_3d_rotary_pos_embed +def get_3d_rotary_pos_embed( + embed_dim, + crops_coords, + grid_size, + temporal_size, + theta: int = 10000, + use_real: bool = True, + grid_type: str = "linspace", + max_size: Optional[Tuple[int, int]] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + RoPE for video tokens with 3D structure. + + Args: + embed_dim: (`int`): + The embedding dimension size, corresponding to hidden_size_head. + crops_coords (`Tuple[int]`): + The top-left and bottom-right coordinates of the crop. + grid_size (`Tuple[int]`): + The grid size of the spatial positional embedding (height, width). + temporal_size (`int`): + The size of the temporal dimension. + theta (`float`): + Scaling factor for frequency computation. + grid_type (`str`): + Whether to use "linspace" or "slice" to compute grids. + + Returns: + `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`. 
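+
+    Example:
+        `get_3d_rotary_pos_embed(64, ((0, 0), (30, 45)), (30, 45), temporal_size=13)` returns a `(cos, sin)`
+        pair whose first dimension covers `13 * 30 * 45` video tokens (the argument values are illustrative).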
+ """ + if use_real is not True: + raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed") + + if grid_type == "linspace": + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + elif grid_type == "slice": + max_h, max_w = max_size + grid_size_h, grid_size_w = grid_size + grid_h = np.arange(max_h, dtype=np.float32) + grid_w = np.arange(max_w, dtype=np.float32) + grid_t = np.arange(temporal_size, dtype=np.float32) + else: + raise ValueError("Invalid value passed for `grid_type`.") + + # Compute dimensions for each axis + dim_t = embed_dim // 4 + dim_h = embed_dim // 8 * 3 + dim_w = embed_dim // 8 * 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) + freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) + + # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor + def combine_time_height_width(freqs_t, freqs_h, freqs_w): + freqs_t = freqs_t[:, None, None, :].expand( + -1, grid_size_h, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_w, dim_t + freqs_h = freqs_h[None, :, None, :].expand( + temporal_size, -1, grid_size_w, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_h + freqs_w = freqs_w[None, None, :, :].expand( + temporal_size, grid_size_h, -1, -1 + ) # temporal_size, grid_size_h, grid_size_2, dim_w + + freqs = torch.cat( + [freqs_t, freqs_h, freqs_w], dim=-1 + ) # temporal_size, grid_size_h, grid_size_w, (dim_t + dim_h + dim_w) + freqs = freqs.view( + temporal_size * grid_size_h * grid_size_w, -1 + ) # (temporal_size * grid_size_h * grid_size_w), (dim_t + dim_h + dim_w) + return freqs + + t_cos, t_sin = freqs_t # both t_cos and t_sin has shape: temporal_size, dim_t + h_cos, h_sin = freqs_h # both h_cos and h_sin has shape: grid_size_h, dim_h + w_cos, w_sin = freqs_w # both w_cos and w_sin has shape: grid_size_w, dim_w + + if grid_type == "slice": + t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size] + h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h] + w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w] + + cos = combine_time_height_width(t_cos, h_cos, w_cos) + sin = combine_time_height_width(t_sin, h_sin, w_sin) + return cos, sin + + +# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid +def get_resize_crop_region_for_grid(src, tgt_width, tgt_height): + tw = tgt_width + th = tgt_height + h, w = src + r = h / w + if r > (th / tw): + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: 
Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +def add_noise_to_reference_video(image, ratio=None): + if ratio is None: + sigma = torch.normal(mean=-3.0, std=0.5, size=(image.shape[0],)).to(image.device) + sigma = torch.exp(sigma).to(image.dtype) + else: + sigma = torch.ones((image.shape[0],)).to(image.device, image.dtype) * ratio + + image_noise = torch.randn_like(image) * sigma[:, None, None, None, None] + image_noise = torch.where(image==-1, torch.zeros_like(image), image_noise) + image = image + image_noise + return image + + +@dataclass +class CogVideoXFunPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class CogVideoXFunInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using CogVideoX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. CogVideoX_Fun uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`CogVideoXTransformer3DModel`]): + A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
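+
+    In addition to the text prompt, `__call__` conditions generation on `video`, `mask_video` and optional
+    `masked_video_latents` inputs. A minimal sketch of a call is shown below; the tensor layout noted in the
+    comments is an assumption carried over from the sibling control pipeline, not something this class enforces:
+
+    ```python
+    # Illustrative only: `pipeline`, `video` and `mask_video` are user-provided objects.
+    output = pipeline(
+        prompt="a cat playing piano",
+        video=video,            # pixel-space video, assumed layout (batch, channels, frames, height, width)
+        mask_video=mask_video,  # mask video in the same layout; marked regions are repainted (assumption)
+        num_frames=49,
+        num_inference_steps=50,
+    )
+    frames = output.videos
+    ```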
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLCogVideoX, + transformer: CogVideoXTransformer3DModel, + scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler], + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder(text_input_ids.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + video_length, + dtype, + device, + generator, + latents=None, + video=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_video_latents=False, + ): + shape = ( + batch_size, + (video_length - 1) // self.vae_scale_factor_temporal + 1, + num_channels_latents, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if return_video_latents or (latents is None and not is_strength_max): + video = video.to(device=device, dtype=self.vae.dtype) + + bs = 1 + new_video = [] + for i in range(0, video.shape[0], bs): + video_bs = video[i : i + bs] + video_bs = self.vae.encode(video_bs)[0] + video_bs = video_bs.sample() + new_video.append(video_bs) + video = torch.cat(new_video, dim = 0) + video = video * self.vae.config.scaling_factor + + video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1) + video_latents = video_latents.to(device=device, dtype=dtype) + video_latents = rearrange(video_latents, "b c f h w -> b f c h w") + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + # scale the initial noise by the standard deviation required by the scheduler + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_video_latents: + outputs += (video_latents,) + + return outputs + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + if self.transformer.config.add_noise_in_inpaint_model: + masked_image = add_noise_to_reference_video(masked_image, ratio=noise_aug_strength) + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width] + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # 
prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def fuse_qkv_projections(self) -> None: + r"""Enables fused QKV projections.""" + self.fusing_transformer = True + self.transformer.fuse_qkv_projections() + + def unfuse_qkv_projections(self) -> None: + r"""Disable QKV projection fusion if enabled.""" + if not self.fusing_transformer: + logger.warning("The Transformer was not initially fused for QKV projections. 
Doing nothing.") + else: + self.transformer.unfuse_qkv_projections() + self.fusing_transformer = False + + def _prepare_rotary_positional_embeddings( + self, + height: int, + width: int, + num_frames: int, + device: torch.device, + ) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + p = self.transformer.config.patch_size + p_t = self.transformer.config.patch_size_t + + base_size_width = self.transformer.config.sample_width // p + base_size_height = self.transformer.config.sample_height // p + + if p_t is None: + # CogVideoX 1.0 + grid_crops_coords = get_resize_crop_region_for_grid( + (grid_height, grid_width), base_size_width, base_size_height + ) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + else: + # CogVideoX 1.5 + base_num_frames = (num_frames + p_t - 1) // p_t + + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=None, + grid_size=(grid_height, grid_width), + temporal_size=base_num_frames, + grid_type="slice", + max_size=(base_size_height, base_size_width), + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + masked_video_latents: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + use_dynamic_cfg: bool = False, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 226, + strength: float = 1, + noise_aug_strength: float = 0.0563, + 
comfyui_progressbar: bool = False,
+    ) -> Union[CogVideoXFunPipelineOutput, Tuple]:
+        """
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            height (`int`, *optional*, defaults to `480`):
+                The height in pixels of the generated video.
+            width (`int`, *optional*, defaults to `720`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `49`):
+                Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
+                contain 1 extra frame because CogVideoX_Fun is conditioned with (num_seconds * fps + 1) frames where
+                num_seconds is 6 and fps is 4. However, since videos can be saved at any fps, the only condition that
+                needs to be satisfied is that of divisibility mentioned above.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            guidance_scale (`float`, *optional*, defaults to `6`):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. A higher guidance scale encourages the model to generate images that are closely linked to the
+                text `prompt`, usually at the expense of lower image quality.
+            num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
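+            strength (`float`, *optional*, defaults to `1`):
+                How strongly the optional input `video` is noised before denoising. Must be between 0 and 1. With a
+                value of `1` the latents are initialised with pure noise; lower values run only
+                `int(num_inference_steps * strength)` denoising steps on top of the encoded input (see
+                `get_timesteps`).
+            noise_aug_strength (`float`, *optional*, defaults to `0.0563`):
+                Strength of the noise augmentation passed to `prepare_mask_latents` when encoding the masked video.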
+            output_type (`str`, *optional*, defaults to `"numpy"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `False`):
+                Whether or not to return a [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`]
+                instead of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+            max_sequence_length (`int`, defaults to `226`):
+                Maximum sequence length in encoded prompt. Must be consistent with
+                `self.transformer.config.max_text_seq_length`, otherwise it may lead to poor results.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] or `tuple`:
+            [`~pipelines.cogvideo.pipeline_cogvideox.CogVideoXFunPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+        num_frames = num_frames or self.transformer.config.sample_frames
+
+        num_videos_per_prompt = 1
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            height,
+            width,
+            negative_prompt,
+            callback_on_step_end_tensor_inputs,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+        self._guidance_scale = guidance_scale
+        self._attention_kwargs = attention_kwargs
+        self._interrupt = False
+
+        # 2. Default call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+            prompt,
+            negative_prompt,
+            do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+
+        # 4. 
set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + # Magvae needs the number of frames to be 4n + 1. + local_latent_length = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + # For CogVideoX 1.5, the latent frames should be clipped to make it divisible by patch_size_t + patch_size_t = self.transformer.config.patch_size_t + additional_frames = 0 + if patch_size_t is not None and local_latent_length % patch_size_t != 0: + additional_frames = local_latent_length % patch_size_t + num_frames -= additional_frames * self.vae_scale_factor_temporal + if num_frames <= 0: + num_frames = 1 + if video_length > num_frames: + logger.warning("The length of condition video is not right, the latent frames should be clipped to make it divisible by patch_size_t. ") + video_length = num_frames + video = video[:, :, :video_length] + init_video = init_video[:, :, :video_length] + mask_video = mask_video[:, :, :video_length] + + num_channels_latents = self.vae.config.latent_channels + num_channels_transformer = self.transformer.config.in_channels + return_image_latents = num_channels_transformer == num_channels_latents + + latents_outputs = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + video_length, + prompt_embeds.dtype, + device, + generator, + latents, + video=init_video, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_video_latents=return_image_latents, + ) + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + if comfyui_progressbar: + pbar.update(1) + + if mask_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.zeros_like(latents)[:, :, :1].to(latents.device, latents.dtype) + masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) + + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype) + else: + # Prepare mask latent variables + video_length = video.shape[2] + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + if 
num_channels_transformer != num_channels_latents: + mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1]) + if masked_video_latents is None: + masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1 + else: + masked_video = masked_video_latents + + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=noise_aug_strength, + ) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents) + mask_latents = mask_latents.to(masked_video_latents.device) * self.vae.config.scaling_factor + + mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + + mask = rearrange(mask, "b c f h w -> b f c h w") + mask_input = rearrange(mask_input, "b c f h w -> b f c h w") + masked_video_latents_input = rearrange(masked_video_latents_input, "b c f h w -> b f c h w") + + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=2).to(latents.dtype) + else: + mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + mask = rearrange(mask, "b c f h w -> b f c h w") + + inpaint_latents = None + else: + if num_channels_transformer != num_channels_latents: + mask = torch.zeros_like(latents).to(latents.device, latents.dtype) + masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype) + + mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype) + else: + mask = torch.zeros_like(init_video[:, :1]) + mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1]) + mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype) + mask = rearrange(mask, "b c f h w -> b f c h w") + + inpaint_latents = None + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Create rotary embeds if required + image_rotary_emb = ( + self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device) + if self.transformer.config.use_rotary_positional_embeddings + else None + ) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + # for DPM-solver++ + old_pred_original_sample = None + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + inpaint_latents=inpaint_latents, + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if use_dynamic_cfg: + self._guidance_scale = 1 + guidance_scale * ( + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 + ) + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + if not isinstance(self.scheduler, CogVideoXDPMScheduler): + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + else: + latents, old_pred_original_sample = self.scheduler.step( + noise_pred, + old_pred_original_sample, + t, + timesteps[i - 1] if i > 0 else None, + latents, + **extra_step_kwargs, + return_dict=False, + ) + latents = latents.to(prompt_embeds.dtype) + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return CogVideoXFunPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_fantasy_talking.py b/videox_fun/pipeline/pipeline_fantasy_talking.py new file mode 100644 index 0000000000000000000000000000000000000000..114bd2280f5d8b6b7feecf66ddb5c1ba00c31631 --- /dev/null +++ b/videox_fun/pipeline/pipeline_fantasy_talking.py @@ -0,0 +1,754 @@ +import inspect +import math +import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor 
import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from torchvision import transforms +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + Wan2_2Transformer3DModel_S2V, WanAudioEncoder, + WanT5EncoderModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class FantasyTalkingPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
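+
+    Args:
+        tokenizer (`AutoTokenizer`):
+            Tokenizer used together with `text_encoder` to tokenize the input prompt.
+        text_encoder (`WanT5EncoderModel`):
+            T5-based text encoder that produces the prompt embeddings.
+        audio_encoder (`WanAudioEncoder`):
+            Audio encoder whose `extract_audio_feat` output (`audio_wav2vec_fea`) conditions the transformer.
+        vae (`AutoencoderKLWan`):
+            Variational Auto-Encoder used to encode and decode videos to and from latent representations.
+        transformer (`Wan2_2Transformer3DModel_S2V`):
+            Conditional Transformer that denoises the video latents.
+        clip_image_encoder (`CLIPModel`):
+            CLIP image encoder used to compute the reference-image features (`clip_fea`).
+        transformer_2 (`Wan2_2Transformer3DModel_S2V`, *optional*):
+            Optional second transformer, registered as an optional component.
+        scheduler (`FlowMatchEulerDiscreteScheduler`):
+            Scheduler used together with `transformer` to denoise the encoded video latents.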
+ """ + + _optional_components = ["transformer_2", "audio_encoder"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + audio_encoder: WanAudioEncoder, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel_S2V, + clip_image_encoder: CLIPModel, + transformer_2: Wan2_2Transformer3DModel_S2V = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler, clip_image_encoder=clip_image_encoder, audio_encoder=audio_encoder + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + audio_path = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_image: Image = None, + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + fps: int = 16, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + # Prepare clip latent variables + if clip_image is not None: + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + else: + clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0)) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = torch.zeros_like(clip_context) + + # Extract audio emb + audio_wav2vec_fea = self.audio_encoder.extract_audio_feat(audio_path, num_frames=num_frames, fps=fps) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if init_video is not None: + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + + clip_context_input = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + + audio_wav2vec_fea_input = ( + torch.cat([audio_wav2vec_fea] * 2) if do_classifier_free_guidance else audio_wav2vec_fea + ) + + audio_scale = torch.tensor( + [0.75, 1] + ).to(latent_model_input.device, latent_model_input.dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y, + audio_wav2vec_fea=audio_wav2vec_fea_input, + audio_scale=audio_scale, + clip_fea=clip_context_input, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video 
= torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_flux.py b/videox_fun/pipeline/pipeline_flux.py new file mode 100644 index 0000000000000000000000000000000000000000..15cea7226fc319609dcb246b2a905c46f28e9e44 --- /dev/null +++ b/videox_fun/pipeline/pipeline_flux.py @@ -0,0 +1,978 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux.py +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor, PipelineImageInput +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (CLIPImageProcessor, CLIPTextModel, + CLIPTokenizer, CLIPVisionModelWithProjection, + FluxTransformer2DModel, T5EncoderModel, AutoencoderKL, + T5TokenizerFast) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import FluxPipeline + + >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0] + >>> image.save("flux.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class FluxPipelineOutput(BaseOutput): + """ + Output class for Flux image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +@dataclass +class FluxPriorReduxPipelineOutput(BaseOutput): + """ + Output class for Flux Prior Redux pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + prompt_embeds: torch.Tensor + pooled_prompt_embeds: torch.Tensor + + +class FluxPipeline( + DiffusionPipeline, +): + r""" + The Flux pipeline for text-to-image generation. 
+ + Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ + + Args: + transformer ([`FluxTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`T5TokenizerFast`): + Second Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + text_encoder_2: T5EncoderModel, + tokenizer_2: T5TokenizerFast, + transformer: FluxTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = ( + self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 + ) + self.default_sample_size = 128 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] + + dtype = self.text_encoder_2.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer_max_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) + + # Use pooled output of CLIPTextModel + prompt_embeds = prompt_embeds.pooler_output + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds + + def 
encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Optional[Union[str, List[str]]] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in all text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # We only use the pooled prompt output from the CLIPTextModel + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + ) + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt_2, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + return prompt_embeds, pooled_prompt_embeds, text_ids + + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." 
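+        # A minimal sketch of what `encode_prompt` produces, assuming batch_size=1,
+        # num_images_per_prompt=1, max_sequence_length=512 and the checkpoints named in
+        # the class docstring (CLIP hidden size 768, T5-XXL hidden size 4096); `pipe`
+        # stands for an instance of this pipeline and is named only for illustration:
+        #
+        #   prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
+        #       prompt="a photo of a cat", prompt_2=None, device=torch.device("cuda")
+        #   )
+        #   prompt_embeds.shape         # (1, 512, 4096)  per-token T5 features
+        #   pooled_prompt_embeds.shape  # (1, 768)        pooled CLIP vector
+        #   text_ids.shape              # (512, 3)        all-zero text position ids
+        #
+        # The hidden sizes depend on the checkpoint; the values above are the common
+        # Flux.1 configuration and are given only for orientation.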
+ ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. 
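+        # A minimal sketch of the 2x2 packing performed by `_pack_latents` /
+        # `_unpack_latents`, assuming a 1024x1024 image, an 8x VAE and 16 latent
+        # channels (the usual Flux configuration; `pipe` is an instance of this
+        # pipeline and is named here for illustration only):
+        #
+        #   latents = torch.randn(1, 16, 128, 128)                 # (B, C, H/8, W/8)
+        #   packed = pipe._pack_latents(latents, 1, 16, 128, 128)  # (1, 4096, 64)
+        #   restored = pipe._unpack_latents(packed, 1024, 1024, pipe.vae_scale_factor)
+        #   torch.equal(restored, latents)                         # True
+        #
+        # Each latent token thus covers a 2x2 latent patch (16x16 pixels), which is
+        # why height and width must be divisible by `vae_scale_factor * 2`.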
+ """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + true_cfg_scale: float = 1.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 28, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 3.5, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. 
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+                not greater than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+            true_cfg_scale (`float`, *optional*, defaults to 1.0):
+                True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+                `negative_prompt` is provided.
+            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
+            guidance_scale (`float`, *optional*, defaults to 3.5):
+                Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+                a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+                Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
+                the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` + is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated + images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + ( + prompt_embeds, + pooled_prompt_embeds, + text_ids, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + if do_true_cfg: + ( + negative_prompt_embeds, + negative_pooled_prompt_embeds, + negative_text_ids, + ) = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = 
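+        # `calculate_shift` is defined outside this diff; the sketch below assumes the
+        # usual diffusers-style helper, whose defaults match the values passed in the
+        # call above. It linearly interpolates the flow-matching shift `mu` between
+        # `base_shift` and `max_shift` as the packed sequence length grows:
+        #
+        #   def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096,
+        #                       base_shift=0.5, max_shift=1.15):
+        #       m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+        #       b = base_shift - m * base_seq_len
+        #       return image_seq_len * m + b
+        #
+        #   calculate_shift(4096)  # 1024x1024 -> 64*64 tokens -> mu = 1.15
+        #   calculate_shift(1024)  # 512x512   -> 32*32 tokens -> mu = 0.63
+        #
+        # Larger images are therefore shifted more strongly toward high-noise timesteps.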
self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. + # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
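+                # With `true_cfg_scale > 1` and a negative prompt, the two predictions
+                # above are combined with standard classifier-free guidance on top of
+                # the embedded (distilled) guidance already passed via `guidance`:
+                #
+                #   guided = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+                #
+                # so `true_cfg_scale = 1` reduces to the positive prediction alone, and
+                # larger values push the sample away from the negative prompt.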
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return FluxPipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_flux2.py b/videox_fun/pipeline/pipeline_flux2.py new file mode 100644 index 0000000000000000000000000000000000000000..26a8741cf93a9977aee49a2887aa2d60683acf8a --- /dev/null +++ b/videox_fun/pipeline/pipeline_flux2.py @@ -0,0 +1,900 @@ +# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLFlux2, Flux2ImageProcessor, + Flux2Transformer2DModel, Mistral3ForConditionalGeneration, AutoProcessor) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2Pipeline + + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. 
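+    >>> # (Illustrative, optional) pass a seeded torch.Generator for reproducible
+    >>> # results; everything else in the call below stays the same:
+    >>> # generator = torch.Generator("cuda").manual_seed(0)
+    >>> # image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5,
+    >>> #              generator=generator).images[0]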
+ >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("flux.png") + ``` +""" + + +def format_text_input(prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
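+# The two linear fits above correspond to a 10-step and a 200-step regime; for a given
+# sequence length `compute_empirical_mu` interpolates linearly in `num_steps` between
+# them, and above 4300 tokens it uses the 200-step fit alone. For a 1024x1024
+# generation (64*64 = 4096 packed tokens, assuming the usual 8x VAE), rounded values:
+#
+#   compute_empirical_mu(4096, 200)  # ~1.15
+#   compute_empirical_mu(4096, 50)   # ~2.02
+#   compute_empirical_mu(4096, 10)   # ~2.26
+#
+# i.e. fewer inference steps push `mu` (and hence the sigma shift) higher.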
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +@dataclass +class Flux2PipelineOutput(BaseOutput): + """ + Output class for Flux2 image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class Flux2Pipeline(DiffusionPipeline): + r""" + The Flux2 pipeline for text-to-image generation. + + Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Mistral3ForConditionalGeneration`]): + [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration) + tokenizer (`AutoProcessor`): + Tokenizer of class + [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor). 
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + transformer: Flux2Transformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + # fmt: off + self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: List[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 
1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. 
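+        # A small illustration of the (T, H, W, L) ids built with `torch.cartesian_prod`:
+        # noise latents use T = 0 and reference images use T = scale * (i + 1), so the
+        # transformer can tell the streams apart while sharing the same spatial grid.
+        # For a 4x3 grid (sizes chosen only for illustration):
+        #
+        #   ids = torch.cartesian_prod(torch.arange(1), torch.arange(4),
+        #                              torch.arange(3), torch.arange(1))
+        #   ids.shape         # (12, 4) -> one (T, H, W, L) row per token
+        #   ids[:3].tolist()  # [[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 2, 0]]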
Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
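+        # `_encode_vae_image` above normalizes latents with the VAE's stored BatchNorm
+        # running statistics rather than a scaling/shift factor; decoding inverts it.
+        # A sketch of the symmetric pair (variable names are illustrative):
+        #
+        #   mean = vae.bn.running_mean.view(1, -1, 1, 1)
+        #   std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
+        #   z_norm = (z - mean) / std      # applied after encoding
+        #   z = z_norm * std + mean        # applied before `vae.decode` in `__call__`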
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 5. prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. 
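+        # A worked example of the reference-image sizing above (numbers illustrative):
+        # condition images are downscaled so their area stays <= 1024*1024 and then
+        # snapped down to a multiple of vae_scale_factor * 2 (16 with the usual 8x VAE).
+        # For a 1920x1080 input:
+        #
+        #   scale = (1024 * 1024 / (1920 * 1080)) ** 0.5   # ~0.711
+        #   w, h = 1920 * scale, 1080 * scale              # ~1365 x ~768
+        #   w, h = int(w) // 16 * 16, int(h) // 16 * 16    # 1360 x 768 fed to the VAE
+        #
+        # The exact rounding inside `_resize_to_target_area` is not shown in this diff;
+        # the point is that both sides end up divisible by 16 with area <= 1024^2.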
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_flux2_control.py b/videox_fun/pipeline/pipeline_flux2_control.py new file mode 100644 index 0000000000000000000000000000000000000000..f2c5aee517406464be1f6442921b02cd92140eb2 --- /dev/null +++ b/videox_fun/pipeline/pipeline_flux2_control.py @@ -0,0 +1,973 @@ +# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/pipeline_flux2.py +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.image_processor import VaeImageProcessor +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLFlux2, Flux2ImageProcessor, + Flux2ControlTransformer2DModel, Mistral3ForConditionalGeneration, AutoProcessor) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2Pipeline + + >>> pipe = Flux2Pipeline.from_pretrained("black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50, guidance_scale=2.5).images[0] + >>> image.save("flux.png") + ``` +""" + + +def format_text_input(prompts: List[str], system_message: str = None): + # Remove [IMG] tokens from prompts to avoid Pixtral validation issues + # when truncation is enabled. The processor counts [IMG] tokens and fails + # if the count changes after truncation. + cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts] + + return [ + [ + { + "role": "system", + "content": [{"type": "text", "text": system_message}], + }, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + for prompt in cleaned_txt + ] + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. 
+ num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +@dataclass +class Flux2PipelineOutput(BaseOutput): + """ + Output class for Flux2 image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class Flux2ControlPipeline(DiffusionPipeline): + r""" + The Flux2 pipeline for text-to-image generation. 
+ + Reference: [https://bfl.ai/blog/flux-2](https://bfl.ai/blog/flux-2) + + Args: + transformer ([`Flux2ControlTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Mistral3ForConditionalGeneration`]): + [Mistral3ForConditionalGeneration](https://huggingface.co/docs/transformers/en/model_doc/mistral3#transformers.Mistral3ForConditionalGeneration) + tokenizer (`AutoProcessor`): + Tokenizer of class + [PixtralProcessor](https://huggingface.co/docs/transformers/en/model_doc/pixtral#transformers.PixtralProcessor). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + transformer: Flux2ControlTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.diffusers_image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + # fmt: off + self.system_message = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation." + # fmt: on + + @staticmethod + def _get_mistral_3_small_prompt_embeds( + text_encoder: Mistral3ForConditionalGeneration, + tokenizer: AutoProcessor, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + # fmt: off + system_message: str = "You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object attribution and actions without speculation.", + # fmt: on + hidden_states_layers: List[int] = (10, 20, 30), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Format input messages + messages_batch = format_text_input(prompts=prompt, system_message=system_message) + + # Process all messages at once + inputs = tokenizer.apply_chat_template( + messages_batch, + add_generation_prompt=False, + tokenize=True, + return_dict=True, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + # Move to device + input_ids = inputs["input_ids"].to(device) + attention_mask = inputs["attention_mask"].to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_mistral_3_small_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + system_message=self.system_message, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + @staticmethod + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. 
+ + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + + @staticmethod + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + def 
_unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + latent_ids = self._prepare_latent_ids(latents) + latent_ids = latent_ids.to(device) + + latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C] + return latents, latent_ids + + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + imagge_latent = self._encode_vae_image(image=image, generator=generator) + image_latents.append(imagge_latent) # (1, 128, 32, 32) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + # latent: (1, 128, 32, 32) + packed = self._pack_latents(latent) # (1, 1024, 128) + packed = packed.squeeze(0) # (1024, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128) + + image_latents = image_latents.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1) + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + + def check_inputs( + self, + prompt, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + + image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + inpaint_image: Optional[Union[List[PIL.Image.Image], PIL.Image.Image]] = None, + mask_image: Union[torch.FloatTensor] = None, + control_image: Union[torch.FloatTensor] = None, + control_context_scale: float = 1.0, + + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (10, 20, 30), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + guidance_scale (`float`, *optional*, defaults to 1.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + height=height, + width=width, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device, weight_dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + device, weight_dtype + ) + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + num_channels_latents = self.transformer.config.in_channels // 4 + + # Prepare mask latent variables + if mask_image is not None: + mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1]).to(dtype=weight_dtype, device=device) + + if inpaint_image is not None: + init_image = self.diffusers_image_processor.preprocess(inpaint_image, height=height, width=width) + init_image = init_image.to(dtype=weight_dtype, device=device) * (mask_condition < 0.5) + inpaint_latent = self.vae.encode(init_image)[0].mode() + else: + inpaint_latent = torch.zeros((batch_size, num_channels_latents * 4, height // 2 // self.vae_scale_factor, width // 2 // self.vae_scale_factor)).to(device, weight_dtype) + + if control_image is not None: + control_image = self.diffusers_image_processor.preprocess(control_image, height=height, width=width) + control_image = control_image.to(dtype=weight_dtype, device=device) + control_latents = self.vae.encode(control_image)[0].mode() + else: + control_latents = torch.zeros_like(inpaint_latent) + + mask_condition = F.interpolate(1 - mask_condition[:, :1], size=control_latents.size()[-2:], mode='nearest').to(device, weight_dtype) + mask_condition = self._patchify_latents(mask_condition) + mask_condition = self._pack_latents(mask_condition) + + if inpaint_image is not None: + inpaint_latent = self._patchify_latents(inpaint_latent) + inpaint_latent = (inpaint_latent - latents_bn_mean) / latents_bn_std + inpaint_latent = self._pack_latents(inpaint_latent) + else: + inpaint_latent = self._patchify_latents(inpaint_latent) + inpaint_latent = self._pack_latents(inpaint_latent) + + if control_image is not None: + control_latents = self._patchify_latents(control_latents) + control_latents = (control_latents - latents_bn_mean) / latents_bn_std + control_latents = self._pack_latents(control_latents) + else: + control_latents = self._patchify_latents(control_latents) + control_latents = self._pack_latents(control_latents) + control_context = torch.concat([control_latents, mask_condition, inpaint_latent], dim=2) + + # 3. prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 4. 
process images + if image is not None and not isinstance(image, list): + image = [image] + + condition_images = None + if image is not None: + for img in image: + self.image_processor.check_image_input(img) + + condition_images = [] + for img in image: + image_width, image_height = img.size + if image_width * image_height > 1024 * 1024: + img = self.image_processor._resize_to_target_area(img, 1024 * 1024) + image_width, image_height = img.size + + multiple_of = self.vae_scale_factor * 2 + image_width = (image_width // multiple_of) * multiple_of + image_height = (image_height // multiple_of) * multiple_of + img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop") + condition_images.append(img) + height = height or image_height + width = width or image_width + + # 5. prepare latent variables + latents, latent_ids = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latents_channels=num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=latents, + ) + + image_latents = None + image_latent_ids = None + if condition_images is not None: + image_latents, image_latent_ids = self.prepare_image_latents( + images=condition_images, + batch_size=batch_size * num_images_per_prompt, + generator=generator, + device=device, + dtype=self.vae.dtype, + ) + + # 6. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = latents.shape[1] + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + + # 7. Denoising loop + # We set the index here to remove DtoH sync, helpful especially during compilation. 
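+                # As in the base pipeline, the begin index is pinned up front so each scheduler.step() call advances a host-side counter rather than searching the timestep schedule on the device.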
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + control_context_input = control_context.to(self.transformer.dtype) + latent_image_ids = latent_ids + + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) + latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) + + local_bs, local_length, local_c = control_context.size() + control_context_input = torch.cat( + [ + control_context, + torch.zeros( + [ + local_bs, + image_latents.size()[1], + local_c + ] + ).to(control_context.device, control_context.dtype)], + dim=1 + ).to(self.transformer.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self._attention_kwargs, + control_context=control_context_input, + control_context_scale=control_context_scale, + return_dict=False, + )[0] + + noise_pred = noise_pred[:, : latents.size(1) :] + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents_with_ids(latents, latent_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_hunyuanvideo.py b/videox_fun/pipeline/pipeline_hunyuanvideo.py new file mode 100644 index 0000000000000000000000000000000000000000..9afe5c7939b822d4b40f5f8402b2fad08e424f2f --- /dev/null +++ b/videox_fun/pipeline/pipeline_hunyuanvideo.py @@ -0,0 +1,805 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
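+ +# This module defines HunyuanVideoPipeline, a text-to-video pipeline: prompts are encoded by a Llama-based +# text encoder through a chat-style template plus a pooled CLIP embedding, a 3D transformer denoises the +# video latents with a flow-matching scheduler, and the HunyuanVideo VAE decodes the latents into frames.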
+ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import HunyuanVideoLoraLoaderMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available, + logging, replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLHunyuanVideo, CLIPImageProcessor, + CLIPTextModel, CLIPTokenizer, + HunyuanVideoTransformer3DModel, LlamaModel, + LlamaTokenizerFast, LlavaForConditionalGeneration) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + >>> import torch + >>> from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel + >>> from diffusers.utils import export_to_video + + >>> model_id = "hunyuanvideo-community/HunyuanVideo" + >>> transformer = HunyuanVideoTransformer3DModel.from_pretrained( + ... model_id, subfolder="transformer", torch_dtype=torch.bfloat16 + ... ) + >>> pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16) + >>> pipe.vae.enable_tiling() + >>> pipe.to("cuda") + + >>> output = pipe( + ... prompt="A cat walks on the grass, realistic", + ... height=320, + ... width=512, + ... num_frames=61, + ... num_inference_steps=30, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=15) + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + ), + "crop_start": 95, +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. 
If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for video pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class HunyuanVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
+ tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlamaModel, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + ) + + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|eot_id|> token and placeholder {} + crop_start -= 2 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_attention_mask, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start is not None and crop_start > 0: + prompt_embeds = prompt_embeds[:, crop_start:] + prompt_attention_mask = prompt_attention_mask[:, crop_start:] + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_videos_per_prompt) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_videos_per_prompt, seq_len) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: 
Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1) + + return prompt_embeds + + def encode_prompt( + self, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + ): + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." 
+ ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. 
This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + max_sequence_length: int = 256, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. 
+ height (`int`, defaults to `720`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `1280`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `129`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+ `negative_prompt` is provided.
+ guidance_scale (`float`, defaults to `6.0`):
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximate true classifier-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"numpy"`):
+ The output format of the generated video. Choose between `"numpy"`, `"latent"`, or any format supported
+ by the pipeline's video processor (such as `"pil"`).
+ return_dict (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Prepare guidance condition + guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + + if do_true_cfg: + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return HunyuanVideoPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py b/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py new file mode 100644 index 0000000000000000000000000000000000000000..e2628ef48c7ba244e8c47e341db960c72454e844 --- /dev/null +++ b/videox_fun/pipeline/pipeline_hunyuanvideo_i2v.py @@ -0,0 +1,972 @@ +# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +# Copyright 2025 The HunyuanVideo Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.loaders import HunyuanVideoLoraLoaderMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, deprecate, is_torch_xla_available, + logging, replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLHunyuanVideo, CLIPImageProcessor, + CLIPTextModel, CLIPTokenizer, + HunyuanVideoTransformer3DModel, LlamaModel, + LlamaTokenizerFast, LlavaForConditionalGeneration) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + ``` +""" + + +DEFAULT_PROMPT_TEMPLATE = { + "template": ( + "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " + "1. The main content and theme of the video." + "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects." + "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects." + "4. background environment, light, style and atmosphere." + "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n" + "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ), + "crop_start": 103, + "image_emb_start": 5, + "image_emb_end": 581, + "image_emb_len": 576, + "double_return_token_id": 271, +} + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. 
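+
+ Example (an illustrative sketch, not part of the upstream code; it assumes the
+ `FlowMatchEulerDiscreteScheduler` configuration used elsewhere in this repository and an
+ arbitrary 8-step linear sigma schedule, mirroring how `__call__` builds `sigmas`):
+
+ ```py
+ >>> import numpy as np
+ >>> from diffusers import FlowMatchEulerDiscreteScheduler
+
+ >>> scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3)
+ >>> sigmas = np.linspace(1.0, 0.0, 8 + 1)[:-1]  # same linear schedule the pipeline uses by default
+ >>> timesteps, num_inference_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=sigmas)
+ >>> num_inference_steps
+ 8
+ ```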
+ """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def _expand_input_ids_with_image_tokens( + text_input_ids, + prompt_attention_mask, + max_sequence_length, + image_token_index, + image_emb_len, + image_emb_start, + image_emb_end, + pad_token_id, +): + special_image_token_mask = text_input_ids == image_token_index + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index) + + max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1)) + new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1 + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + + expanded_input_ids = torch.full( + (text_input_ids.shape[0], max_expanded_length), + pad_token_id, + dtype=text_input_ids.dtype, + device=text_input_ids.device, + ) + expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices] + expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index + + expanded_attention_mask = torch.zeros( + (text_input_ids.shape[0], max_expanded_length), + dtype=prompt_attention_mask.dtype, + device=prompt_attention_mask.device, + ) + attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id) + expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0 + expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype) + position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1) + + return { + "input_ids": expanded_input_ids, + "attention_mask": expanded_attention_mask, + "position_ids": position_ids, + } + + +@dataclass +class HunyuanVideoPipelineOutput(BaseOutput): + r""" + Output class for video pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + videos: torch.Tensor + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +class HunyuanVideoI2VPipeline(DiffusionPipeline): + r""" + Pipeline for image-to-video generation using HunyuanVideo. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + text_encoder ([`LlamaModel`]): + [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + tokenizer (`LlamaTokenizer`): + Tokenizer from [Llava Llama3-8B](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers). + transformer ([`HunyuanVideoTransformer3DModel`]): + Conditional Transformer to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder_2 ([`CLIPTextModel`]): + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer_2 (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). 
+ """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + text_encoder: LlavaForConditionalGeneration, + tokenizer: LlamaTokenizerFast, + transformer: HunyuanVideoTransformer3DModel, + vae: AutoencoderKLHunyuanVideo, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + image_processor=image_processor, + ) + + self.vae_scaling_factor = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.476986 + self.vae_scale_factor_temporal = self.vae.temporal_compression_ratio if getattr(self, "vae", None) else 4 + self.vae_scale_factor_spatial = self.vae.spatial_compression_ratio if getattr(self, "vae", None) else 8 + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + def _get_llama_prompt_embeds( + self, + image: torch.Tensor, + prompt: Union[str, List[str]], + prompt_template: Dict[str, Any], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + num_hidden_layers_to_skip: int = 2, + image_embed_interleave: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor]: + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_template["template"].format(p) for p in prompt] + + crop_start = prompt_template.get("crop_start", None) + + image_emb_len = prompt_template.get("image_emb_len", 576) + image_emb_start = prompt_template.get("image_emb_start", 5) + image_emb_end = prompt_template.get("image_emb_end", 581) + double_return_token_id = prompt_template.get("double_return_token_id", 271) + + if crop_start is None: + prompt_template_input = self.tokenizer( + prompt_template["template"], + padding="max_length", + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=False, + ) + crop_start = prompt_template_input["input_ids"].shape[-1] + # Remove <|start_header_id|>, <|end_header_id|>, assistant, <|eot_id|>, and placeholder {} + crop_start -= 5 + + max_sequence_length += crop_start + text_inputs = self.tokenizer( + prompt, + max_length=max_sequence_length, + padding="max_length", + truncation=True, + return_tensors="pt", + return_length=False, + return_overflowing_tokens=False, + return_attention_mask=True, + ) + text_input_ids = text_inputs.input_ids.to(device=device) + prompt_attention_mask = text_inputs.attention_mask.to(device=device) + + image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device) + + image_token_index = self.text_encoder.config.image_token_index + pad_token_id = self.text_encoder.config.pad_token_id + expanded_inputs = _expand_input_ids_with_image_tokens( + text_input_ids, + prompt_attention_mask, + max_sequence_length, + image_token_index, + image_emb_len, + image_emb_start, + image_emb_end, + pad_token_id, + ) + prompt_embeds = self.text_encoder( + **expanded_inputs, + pixel_values=image_embeds, + output_hidden_states=True, + ).hidden_states[-(num_hidden_layers_to_skip + 1)] + prompt_embeds = prompt_embeds.to(dtype=dtype) + + if crop_start 
is not None and crop_start > 0: + text_crop_start = crop_start - 1 + image_emb_len + batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id) + + if last_double_return_token_indices.shape[0] == 3: + # in case the prompt is too long + last_double_return_token_indices = torch.cat( + (last_double_return_token_indices, torch.tensor([text_input_ids.shape[-1]])) + ) + batch_indices = torch.cat((batch_indices, torch.tensor([0]))) + + last_double_return_token_indices = last_double_return_token_indices.reshape(text_input_ids.shape[0], -1)[ + :, -1 + ] + batch_indices = batch_indices.reshape(text_input_ids.shape[0], -1)[:, -1] + assistant_crop_start = last_double_return_token_indices - 1 + image_emb_len - 4 + assistant_crop_end = last_double_return_token_indices - 1 + image_emb_len + attention_mask_assistant_crop_start = last_double_return_token_indices - 4 + attention_mask_assistant_crop_end = last_double_return_token_indices + + prompt_embed_list = [] + prompt_attention_mask_list = [] + image_embed_list = [] + image_attention_mask_list = [] + + for i in range(text_input_ids.shape[0]): + prompt_embed_list.append( + torch.cat( + [ + prompt_embeds[i, text_crop_start : assistant_crop_start[i].item()], + prompt_embeds[i, assistant_crop_end[i].item() :], + ] + ) + ) + prompt_attention_mask_list.append( + torch.cat( + [ + prompt_attention_mask[i, crop_start : attention_mask_assistant_crop_start[i].item()], + prompt_attention_mask[i, attention_mask_assistant_crop_end[i].item() :], + ] + ) + ) + image_embed_list.append(prompt_embeds[i, image_emb_start:image_emb_end]) + image_attention_mask_list.append( + torch.ones(image_embed_list[-1].shape[0]).to(prompt_embeds.device).to(prompt_attention_mask.dtype) + ) + + prompt_embed_list = torch.stack(prompt_embed_list) + prompt_attention_mask_list = torch.stack(prompt_attention_mask_list) + image_embed_list = torch.stack(image_embed_list) + image_attention_mask_list = torch.stack(image_attention_mask_list) + + if 0 < image_embed_interleave < 6: + image_embed_list = image_embed_list[:, ::image_embed_interleave, :] + image_attention_mask_list = image_attention_mask_list[:, ::image_embed_interleave] + + assert ( + prompt_embed_list.shape[0] == prompt_attention_mask_list.shape[0] + and image_embed_list.shape[0] == image_attention_mask_list.shape[0] + ) + + prompt_embeds = torch.cat([image_embed_list, prompt_embed_list], dim=1) + prompt_attention_mask = torch.cat([image_attention_mask_list, prompt_attention_mask_list], dim=1) + + return prompt_embeds, prompt_attention_mask + + def _get_clip_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 77, + ) -> torch.Tensor: + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + text_inputs = self.tokenizer_2( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP 
can only handle sequences up to" + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output + return prompt_embeds + + def encode_prompt( + self, + image: torch.Tensor, + prompt: Union[str, List[str]], + prompt_2: Union[str, List[str]] = None, + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + max_sequence_length: int = 256, + image_embed_interleave: int = 2, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_llama_prompt_embeds( + image, + prompt, + prompt_template, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, + ) + + if pooled_prompt_embeds is None: + if prompt_2 is None: + prompt_2 = prompt + pooled_prompt_embeds = self._get_clip_prompt_embeds( + prompt, + num_videos_per_prompt, + device=device, + dtype=dtype, + max_sequence_length=77, + ) + + return prompt_embeds, pooled_prompt_embeds, prompt_attention_mask + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + prompt_template=None, + true_cfg_scale=1.0, + guidance_scale=1.0, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if prompt_template is not None: + if not isinstance(prompt_template, dict): + raise ValueError(f"`prompt_template` has to be of type `dict` but is {type(prompt_template)}") + if "template" not in prompt_template: + raise ValueError( + f"`prompt_template` has to contain a key `template` but only found {prompt_template.keys()}" + ) + + if true_cfg_scale > 1.0 and guidance_scale > 1.0: + logger.warning( + "Both `true_cfg_scale` and `guidance_scale` are greater than 1.0. This will result in both " + "classifier-free guidance and embedded-guidance to be applied. This is not recommended " + "as it may lead to higher memory usage, slower inference and potentially worse results." + ) + + def prepare_latents( + self, + image: torch.Tensor, + batch_size: int, + num_channels_latents: int = 32, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + latent_height, latent_width = height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial + shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width) + + image = image.unsqueeze(2) # [B, C, 1, H, W] + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i], "argmax") + for i in range(batch_size) + ] + else: + image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator, "argmax") for img in image] + + image_latents = torch.cat(image_latents, dim=0).to(dtype) * self.vae_scaling_factor + image_latents = image_latents.repeat(1, 1, num_latent_frames, 1, 1) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device=device, dtype=dtype) + + t = torch.tensor([0.999]).to(device=device) + latents = latents * t + image_latents * (1 - t) + + image_latents = image_latents[:, :, :1] + return latents, image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + + frames = self.vae.decode(latents).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 
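+
+ A minimal sketch of the non-deprecated equivalent, assuming `pipe` is an instance of this
+ pipeline, is to call the VAE directly:
+
+ ```py
+ >>> pipe.vae.enable_slicing()
+ ```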
+ """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." + deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + negative_prompt_2: Union[str, List[str]] = None, + height: int = 720, + width: int = 1280, + num_frames: int = 129, + num_inference_steps: int = 50, + sigmas: List[float] = None, + true_cfg_scale: float = 1.0, + guidance_scale: float = 6.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + prompt_template: Dict[str, Any] = DEFAULT_PROMPT_TEMPLATE, + image: PIL.Image.Image = 
None, + max_sequence_length: int = 256, + image_embed_interleave: Optional[int] = None, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + will be used instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. + height (`int`, defaults to `720`): + The height in pixels of the generated image. + width (`int`, defaults to `1280`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `129`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + guidance_scale (`float`, defaults to `1.0`): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. Note that the only available + HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and + conditional latent is not applied. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
+ If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~HunyuanVideoPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + prompt_embeds, + callback_on_step_end_tensor_inputs, + prompt_template, + true_cfg_scale, + guidance_scale, + ) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + image_embed_interleave = ( + image_embed_interleave + if image_embed_interleave is not None + else 4 + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Prepare latent variables + vae_dtype = self.vae.dtype + image_tensor = self.video_processor.preprocess(image, height, width).to(device, vae_dtype) + + num_channels_latents = self.transformer.config.in_channels + + latents, image_latents = self.prepare_latents( + image_tensor, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 4. Encode input prompt + transformer_dtype = self.transformer.dtype + prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = self.encode_prompt( + image=image, + prompt=prompt, + prompt_2=prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + image_embed_interleave=image_embed_interleave, + ) + prompt_embeds = prompt_embeds.to(transformer_dtype) + prompt_attention_mask = prompt_attention_mask.to(transformer_dtype) + pooled_prompt_embeds = pooled_prompt_embeds.to(transformer_dtype) + + if do_true_cfg: + black_image = PIL.Image.new("RGB", (width, height), 0) + negative_prompt_embeds, negative_pooled_prompt_embeds, negative_prompt_attention_mask = self.encode_prompt( + image=black_image, + prompt=negative_prompt, + prompt_2=negative_prompt_2, + prompt_template=prompt_template, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=negative_pooled_prompt_embeds, + prompt_attention_mask=negative_prompt_attention_mask, + device=device, + max_sequence_length=max_sequence_length, + ) + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + negative_prompt_attention_mask = negative_prompt_attention_mask.to(transformer_dtype) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(transformer_dtype) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) + + # 6. Prepare guidance condition + guidance = None + if self.transformer.config.guidance_embeds: + guidance = ( + torch.tensor([guidance_scale] * latents.shape[0], dtype=transformer_dtype, device=device) * 1000.0 + ) + + # 7. 
Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if do_true_cfg:
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False
+ )[0]
+ latents = torch.cat([image_latents, latents], dim=2)
+ latents = latents.to(self.vae.dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "numpy":
+ video = self.decode_latents(latents)
+ elif output_type != "latent":
+ video = self.decode_latents(latents)
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ video = torch.from_numpy(video)
+
+ return HunyuanVideoPipelineOutput(videos=video)
diff --git a/videox_fun/pipeline/pipeline_qwenimage.py b/videox_fun/pipeline/pipeline_qwenimage.py
new file mode 100644
index 0000000000000000000000000000000000000000..e61ac0e700bb2ee5757eb5118f0d29325cca0fdf
--- /dev/null
+++ b/videox_fun/pipeline/pipeline_qwenimage.py
@@ -0,0 +1,767 @@
+# Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLQwenImage, + Qwen2_5_VLForConditionalGeneration, + Qwen2Tokenizer, QwenImageTransformer2DModel) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import QwenImagePipeline + + >>> pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "A cat holding a sign that says hello world" + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimage.png") + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class QwenImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class QwenImagePipeline(DiffusionPipeline): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 34 + self.default_sample_size = 128 + + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + txt_tokens = self.tokenizer( + txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt" + ).to(device) + encoder_hidden_states = self.text_encoder( + input_ids=txt_tokens.input_ids, + attention_mask=txt_tokens.attention_mask, + output_hidden_states=True, + ) + hidden_states = encoder_hidden_states.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. 
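+
+        Example (illustrative sketch, not from the upstream docs; assumes `pipe` is an already
+        constructed `QwenImagePipeline` moved to the desired device):
+
+        ```py
+        # Encode one prompt for two images per prompt.
+        prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
+            prompt="A cat holding a sign that says hello world",
+            num_images_per_prompt=2,
+            max_sequence_length=1024,
+        )
+        # prompt_embeds has shape (batch_size * num_images_per_prompt, seq_len, hidden_dim)
+        # prompt_embeds_mask has shape (batch_size * num_images_per_prompt, seq_len); 1 = real token, 0 = padding
+        ```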
+ """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device) + + prompt_embeds = prompt_embeds[:, :max_sequence_length] + prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length] + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." 
+ ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + if latents is not None: + return latents.to(device=device, dtype=dtype) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 1.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. 
Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + + This parameter in the pipeline is there to support future guidance-distilled models when they come up. + Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance, + please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should + enable classifier-free guidance computations. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. 
Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size + if comfyui_progressbar: + pbar.update(1) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + else: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. 
Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + if self.interrupt: + continue + + if do_true_cfg: + latent_model_input = torch.cat([latents] * 2) + prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask] + prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds] + img_shapes_input = img_shapes * 2 + txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens + else: + latent_model_input = latents + prompt_embeds_mask_input = prompt_embeds_mask + prompt_embeds_input = prompt_embeds + img_shapes_input = img_shapes + txt_seq_lens_input = txt_seq_lens + + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + noise_pred = self.transformer.forward_bs( + x=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask_input, + encoder_hidden_states=prompt_embeds_input, + img_shapes=img_shapes_input, + txt_seq_lens=txt_seq_lens_input, + attention_kwargs=self.attention_kwargs, + return_dict=False, + ) + + if do_true_cfg: + neg_noise_pred, noise_pred = noise_pred.chunk(2) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_qwenimage_edit.py b/videox_fun/pipeline/pipeline_qwenimage_edit.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0a3d3efed6241a78b5b4fc6c62602a1ee1ed6d --- /dev/null +++ b/videox_fun/pipeline/pipeline_qwenimage_edit.py @@ -0,0 +1,952 @@ +# Modified from https://github.com/naykun/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import math +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLQwenImage, + Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor, + Qwen2Tokenizer, QwenImageTransformer2DModel) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPipeline + + >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Change the cat to a dog" + >>> image = Image.open("cat.png") + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimageedit.png") + ``` +""" +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
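+
+    Example (illustrative sketch, not from the upstream docs; assumes `numpy` is imported as `np` and
+    `scheduler` is a `FlowMatchEulerDiscreteScheduler` with dynamic shifting enabled, as in the
+    pipelines below):
+
+    ```py
+    num_inference_steps = 8
+    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+    # `mu` is the resolution-dependent shift from `calculate_shift` above; with the default ranges,
+    # a 1024x1024 image packed into 2x2 patches gives image_seq_len = (128 // 2) * (128 // 2) = 4096 -> mu = 1.15.
+    mu = calculate_shift(image_seq_len=4096)
+    timesteps, num_inference_steps = retrieve_timesteps(
+        scheduler, num_inference_steps, device="cuda", sigmas=sigmas, mu=mu
+    )
+    ```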
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +@dataclass +class QwenImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class QwenImageEditPipeline(DiffusionPipeline): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + + template = self.prompt_template_encode + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, 
device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 
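+
+        Example (illustrative usage sketch; `pipe` is assumed to be an already constructed pipeline):
+
+        ```py
+        pipe.enable_vae_slicing()   # or, equivalently, pipe.vae.enable_slicing()
+        # ... run the pipeline; latents are decoded slice by slice to lower peak memory ...
+        pipe.disable_vae_slicing()
+        ```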
+        """
+        # Deprecated: use `pipe.vae.enable_slicing()` instead; this wrapper is scheduled for removal in 0.40.0.
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        # Deprecated: use `pipe.vae.disable_slicing()` instead; this wrapper is scheduled for removal in 0.40.0.
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        # Deprecated: use `pipe.vae.enable_tiling()` instead; this wrapper is scheduled for removal in 0.40.0.
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        # Deprecated: use `pipe.vae.disable_tiling()` instead; this wrapper is scheduled for removal in 0.40.0.
+        self.vae.disable_tiling()
+
+    def prepare_latents(
+        self,
+        image,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+        shape = (batch_size, 1, num_channels_latents, height, width)
+
+        image_latents = None
+        if image is not None:
+            image = image.to(device=device, dtype=dtype)
+            if image.shape[1] != self.latent_channels:
+                image_latents = self._encode_vae_image(image=image, generator=generator)
+            else:
+                image_latents = image
+            if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+                # expand init_latents for batch_size
+                additional_image_per_prompt = batch_size // image_latents.shape[0]
+                image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+            elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+                )
+            else:
+                image_latents = torch.cat([image_latents], dim=0)
+
+            image_latent_height, image_latent_width = image_latents.shape[3:]
+            image_latents = self._pack_latents(
+                image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+            )
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}.
Make sure the batch size matches the length of the generators." + ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. 
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. 
+ attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[0].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 3. Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + image = self.image_processor.resize(image, calculated_height, calculated_width) + prompt_image = image + image = self.image_processor.preprocess(image, calculated_height, calculated_width) + image = image.unsqueeze(2) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." 
+ ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=prompt_image, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if comfyui_progressbar: + pbar.update(1) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + image, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2), + ] + ] * batch_size + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. 
Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + if self.interrupt: + continue + + if image_latents is not None: + latents_and_image_latents = torch.cat([latents, image_latents], dim=1) + else: + latents_and_image_latents = latents + + if do_true_cfg: + latent_model_input = torch.cat([latents_and_image_latents] * 2) + prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask] + prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds] + img_shapes_input = img_shapes * 2 + txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens + else: + latent_model_input = latents_and_image_latents + prompt_embeds_mask_input = prompt_embeds_mask + prompt_embeds_input = prompt_embeds + img_shapes_input = img_shapes + txt_seq_lens_input = txt_seq_lens + + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + noise_pred = self.transformer.forward_bs( + x=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask_input, + encoder_hidden_states=prompt_embeds_input, + img_shapes=img_shapes_input, + txt_seq_lens=txt_seq_lens_input, + attention_kwargs=self.attention_kwargs, + return_dict=False, + ) + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + neg_noise_pred, noise_pred = noise_pred.chunk(2) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_qwenimage_edit_plus.py b/videox_fun/pipeline/pipeline_qwenimage_edit_plus.py new file mode 100644 index 0000000000000000000000000000000000000000..885550cc2d3529360e09b89c13b96511bbe1c9f1 --- /dev/null +++ b/videox_fun/pipeline/pipeline_qwenimage_edit_plus.py @@ -0,0 +1,937 @@ +# Modified from https://github.com/naykun/diffusers/blob/main/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
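The Edit-Plus pipeline added in this file prepares every reference image at two resolutions: a small copy with roughly `CONDITION_IMAGE_SIZE` (384 × 384) pixels of area is handed to the Qwen2.5-VL text encoder as part of the prompt, while a larger copy with roughly `VAE_IMAGE_SIZE` (1024 × 1024) pixels of area is encoded by the VAE and used for latent conditioning. Both sizes come from the `calculate_dimensions` helper defined later in the file, which keeps the aspect ratio and snaps width and height to multiples of 32. A minimal standalone sketch of that sizing rule (the helper body mirrors the one in this file; the 16:9 input and the printed values are only an illustration):

```python
import math

def calculate_dimensions(target_area: int, ratio: float):
    # width * height ~= target_area with width / height == ratio, snapped to multiples of 32
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return round(width / 32) * 32, round(height / 32) * 32

CONDITION_IMAGE_SIZE = 384 * 384   # small copy for the Qwen2.5-VL text encoder
VAE_IMAGE_SIZE = 1024 * 1024       # large copy encoded by the VAE

for name, area in [("condition", CONDITION_IMAGE_SIZE), ("vae", VAE_IMAGE_SIZE)]:
    w, h = calculate_dimensions(area, 16 / 9)  # e.g. a 1920x1080 reference image
    print(f"{name}: {w}x{h}")                  # condition: 512x288, vae: 1376x768
```

`__call__` reuses the same helper at the 1024 × 1024 target to derive the default output `height`/`width` from the last reference image when the caller does not pass them.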
+ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import math +import PIL.Image +import torch +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor + +from ..models import (AutoencoderKLQwenImage, + Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor, + Qwen2Tokenizer, QwenImageTransformer2DModel) + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from PIL import Image + >>> from diffusers import QwenImageEditPipeline + + >>> pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + >>> prompt = "Change the cat to a dog" + >>> image = Image.open("cat.png") + >>> # Depending on the variant being used, the pipeline call will slightly vary. + >>> # Refer to the pipeline documentation for more details. + >>> image = pipe(image, prompt, num_inference_steps=50).images[0] + >>> image.save("qwenimageedit.png") + ``` +""" + +CONDITION_IMAGE_SIZE = 384 * 384 +VAE_IMAGE_SIZE = 1024 * 1024 + +PREFERRED_QWENIMAGE_RESOLUTIONS = [ + (672, 1568), + (688, 1504), + (720, 1456), + (752, 1392), + (800, 1328), + (832, 1248), + (880, 1184), + (944, 1104), + (1024, 1024), + (1104, 944), + (1184, 880), + (1248, 832), + (1328, 800), + (1392, 752), + (1456, 720), + (1504, 688), + (1568, 672), +] + +# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +def calculate_dimensions(target_area, ratio): + width = math.sqrt(target_area * ratio) + height = width / ratio + + width = round(width / 32) * 32 + height = round(height / 32) * 32 + + return width, height + + +@dataclass +class QwenImagePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class QwenImageEditPlusPipeline(DiffusionPipeline): + r""" + The QwenImage pipeline for text-to-image generation. + + Args: + transformer ([`QwenImageTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]): + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the + [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant. + tokenizer (`QwenTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLQwenImage, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2Tokenizer, + processor: Qwen2VLProcessor, + transformer: QwenImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + processor=processor, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 + self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16 + # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.tokenizer_max_length = 1024 + + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode_start_idx = 64 + self.default_sample_size = 128 + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden + def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor): + bool_mask = mask.bool() + valid_lengths = bool_mask.sum(dim=1) + selected = hidden_states[bool_mask] + split_result = torch.split(selected, valid_lengths.tolist(), dim=0) + + return split_result + + def _get_qwen_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + + model_inputs = self.processor( + text=txt, + images=image, + padding=True, + return_tensors="pt", + ).to(device) + + outputs = self.text_encoder( + input_ids=model_inputs.input_ids, + attention_mask=model_inputs.attention_mask, + pixel_values=model_inputs.pixel_values, + image_grid_thw=model_inputs.image_grid_thw, + 
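+            # output_hidden_states=True is required because the prompt embeddings are taken from
+            # the final hidden layer (outputs.hidden_states[-1]) immediately after this call.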
output_hidden_states=True, + ) + + hidden_states = outputs.hidden_states[-1] + split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask) + split_hidden_states = [e[drop_idx:] for e in split_hidden_states] + attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states] + max_seq_len = max([e.size(0) for e in split_hidden_states]) + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states] + ) + encoder_attention_mask = torch.stack( + [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list] + ) + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + return prompt_embeds, encoder_attention_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + image: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + max_sequence_length: int = 1024, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + image (`torch.Tensor`, *optional*): + image to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device) + + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1) + prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, prompt_embeds_mask + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_embeds_mask=None, + negative_prompt_embeds_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. 
Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_embeds_mask is None: + raise ValueError( + "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`." + ) + if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if max_sequence_length is not None and max_sequence_length > 1024: + raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}") + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents + def _unpack_latents(latents, height, width, vae_scale_factor): + batch_size, num_patches, channels = latents.shape + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
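+        # Worked example (illustrative numbers): with height = width = 1024 and vae_scale_factor = 8,
+        # the two lines below give a 128 x 128 latent grid, i.e. the packed sequence must contain
+        # (128 // 2) * (128 // 2) = 4096 tokens of length `channels`, and the unpacked latent comes
+        # out with shape (batch_size, channels // 4, 1, 128, 128).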
+ height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + + # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax") + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std) + .view(1, self.latent_channels, 1, 1, 1) + .to(image_latents.device, image_latents.dtype) + ) + image_latents = (image_latents - latents_mean) / latents_std + + return image_latents + + def prepare_latents( + self, + images, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, 1, num_channels_latents, height, width) + + image_latents = None + if images is not None: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latent_height, image_latent_width = image_latents.shape[3:] + image_latents = self._pack_latents( + image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width + ) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + else: + latents = latents.to(device=device, dtype=dtype) + + return latents, image_latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + image = None, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + true_cfg_scale: float = 4.0, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = None, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is + not greater than `1`). + true_cfg_scale (`float`, *optional*, defaults to 1.0): + true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free + Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of + equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is + enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale + encourages to generate images that are closely linked to the text `prompt`, usually at the expense of + lower image quality. 
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to None): + A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance + where the guidance scale is applied during inference through noise prediction rescaling, guidance + distilled models take the guidance scale directly as an input parameter during forward pass. Guidance + scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images + that are closely linked to the text `prompt`, usually at the expense of lower image quality. This + parameter in the pipeline is there to support future guidance-distilled models when they come up. It is + ignored when not using guidance distilled models. To enable traditional classifier-free guidance, + please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should + enable classifier-free guidance computations). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple. 
+ attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`: + [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images. + """ + image_size = image[-1].size if isinstance(image, list) else image.size + calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1]) + height = height or calculated_height + width = width or calculated_width + + multiple_of = self.vae_scale_factor * 2 + width = width // multiple_of * multiple_of + height = height // multiple_of * multiple_of + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + negative_prompt_embeds_mask=negative_prompt_embeds_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 3. 
Preprocess image + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if not isinstance(image, list): + image = [image] + condition_image_sizes = [] + condition_images = [] + vae_image_sizes = [] + vae_images = [] + for img in image: + image_width, image_height = img.size + condition_width, condition_height = calculate_dimensions( + CONDITION_IMAGE_SIZE, image_width / image_height + ) + vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height) + condition_image_sizes.append((condition_width, condition_height)) + vae_image_sizes.append((vae_width, vae_height)) + condition_images.append(self.image_processor.resize(img, condition_height, condition_width)) + vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2)) + + has_neg_prompt = negative_prompt is not None or ( + negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None + ) + + if true_cfg_scale > 1 and not has_neg_prompt: + logger.warning( + f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided." + ) + elif true_cfg_scale <= 1 and has_neg_prompt: + logger.warning( + " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" + ) + + do_true_cfg = true_cfg_scale > 1 and has_neg_prompt + prompt_embeds, prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=prompt, + prompt_embeds=prompt_embeds, + prompt_embeds_mask=prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if do_true_cfg: + negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( + image=condition_images, + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + prompt_embeds_mask=negative_prompt_embeds_mask, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + ) + if comfyui_progressbar: + pbar.update(1) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, image_latents = self.prepare_latents( + vae_images, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + img_shapes = [ + [ + (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] + ] * batch_size + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # handle guidance + if self.transformer.config.guidance_embeds and guidance_scale is None: + raise ValueError("guidance_scale is required for guidance-distilled model.") + elif self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latents.shape[0]) + elif not self.transformer.config.guidance_embeds and guidance_scale is not None: + logger.warning( + f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled." + ) + guidance = None + elif not self.transformer.config.guidance_embeds and guidance_scale is None: + guidance = None + + if self.attention_kwargs is None: + self._attention_kwargs = {} + + txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None + negative_txt_seq_lens = ( + negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Denoising loop + self.scheduler.set_begin_index(0) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + if self.interrupt: + continue + + if image_latents is not None: + latents_and_image_latents = torch.cat([latents, image_latents], dim=1) + else: + latents_and_image_latents = latents + + if do_true_cfg: + latent_model_input = torch.cat([latents_and_image_latents] * 2) + prompt_embeds_mask_input = [_negative_prompt_embeds_mask for _negative_prompt_embeds_mask in negative_prompt_embeds_mask] + [_prompt_embeds_mask for _prompt_embeds_mask in prompt_embeds_mask] + prompt_embeds_input = [_negative_prompt_embeds for _negative_prompt_embeds in negative_prompt_embeds] + [_prompt_embeds for _prompt_embeds in prompt_embeds] + img_shapes_input = img_shapes * 2 + txt_seq_lens_input = negative_txt_seq_lens + txt_seq_lens + else: + latent_model_input = latents_and_image_latents + prompt_embeds_mask_input = prompt_embeds_mask + prompt_embeds_input = prompt_embeds + img_shapes_input = img_shapes + txt_seq_lens_input = txt_seq_lens + + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # handle guidance + if self.transformer.config.guidance_embeds: + guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) + guidance = guidance.expand(latent_model_input.shape[0]) + else: + guidance = None + + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype) + + with torch.cuda.amp.autocast(dtype=latents.dtype), torch.cuda.device(device=latents.device): + noise_pred = self.transformer.forward_bs( + x=latent_model_input, + 
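+                        # The scheduler timesteps (on a 0-1000 grid by default for the flow-match
+                        # scheduler) are rescaled to a 0-1 range below before being passed to the
+                        # transformer.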
timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask=prompt_embeds_mask_input, + encoder_hidden_states=prompt_embeds_input, + img_shapes=img_shapes_input, + txt_seq_lens=txt_seq_lens_input, + attention_kwargs=self.attention_kwargs, + return_dict=False, + ) + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + neg_noise_pred, noise_pred = noise_pred.chunk(2) + comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if comfyui_progressbar: + pbar.update(1) + + self._current_timestep = None + if output_type == "latent": + image = latents + else: + latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return QwenImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/pipeline/pipeline_wan.py b/videox_fun/pipeline/pipeline_wan.py new file mode 100755 index 0000000000000000000000000000000000000000..f105c9a338bc608a571ba53a262bf667b538db32 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan.py @@ -0,0 +1,576 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, WanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import 
FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + videos: torch.Tensor + + +class WanPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 5. Prepare latents + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2.py b/videox_fun/pipeline/pipeline_wan2_2.py new file mode 100755 index 0000000000000000000000000000000000000000..e96287a98f5b336b115ccf693e94b36501416a30 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2.py @@ -0,0 +1,591 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, Wan2_2Transformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + 
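+        # A minimal, hypothetical usage sketch (not part of the original example docstring):
+        # it assumes `tokenizer`, `text_encoder`, `vae`, `transformer`, the optional
+        # `transformer_2`, and `scheduler` have already been loaded from a local Wan2.2
+        # checkpoint; only the constructor and call signatures defined in this file are used.
+        pipe = Wan2_2Pipeline(
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            vae=vae,
+            transformer=transformer,
+            transformer_2=transformer_2,
+            scheduler=scheduler,
+        ).to("cuda")
+        output = pipe(
+            prompt="a corgi running on a beach at sunset",
+            height=480,
+            width=720,
+            num_frames=49,
+            num_inference_steps=50,
+            guidance_scale=6.0,
+        )
+        video = output.videos  # with the default output_type="numpy" and return_dict=False, this is a tensor of decoded frames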
``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2Pipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. 
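+    When `transformer_2` is provided, timesteps at or above `boundary * num_train_timesteps` are denoised by `transformer_2` and the remaining lower-noise steps by `transformer` (see `__call__`).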
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel, + transformer_2: Wan2_2Transformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 5. Prepare latents + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + ) + + # perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_animate.py b/videox_fun/pipeline/pipeline_wan2_2_animate.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0df263374b2d6d15d57b46d8ab7a9ce2435d4d --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_animate.py @@ -0,0 +1,929 @@ +import inspect +import math +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import copy +import torch +import cv2 +import torch.nn.functional as F +from einops import rearrange +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import 
MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from decord import VideoReader + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, Wan2_2Transformer3DModel_Animate) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2AnimatePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2", "clip_image_encoder"] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel_Animate, + transformer_2: Wan2_2Transformer3DModel_Animate = None, + clip_image_encoder: CLIPModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, clip_image_encoder=clip_image_encoder, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), 
attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def padding_resize(self, img_ori, height=512, width=512, padding_color=(0, 0, 0), interpolation=cv2.INTER_LINEAR): + ori_height = img_ori.shape[0] + ori_width = img_ori.shape[1] + channel = img_ori.shape[2] + + img_pad = np.zeros((height, width, channel)) + if channel == 1: + img_pad[:, :, 0] = padding_color[0] + else: + img_pad[:, :, 0] = padding_color[0] + img_pad[:, :, 1] = padding_color[1] + img_pad[:, :, 2] = padding_color[2] + + if (ori_height / ori_width) > (height / width): + new_width = int(height / ori_height * ori_width) + img = cv2.resize(img_ori, (new_width, height), interpolation=interpolation) + padding = int((width - new_width) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[:, padding: padding + new_width, :] = img + else: + new_height = int(width / ori_width * ori_height) + img = cv2.resize(img_ori, (width, new_height), interpolation=interpolation) + padding = int((height - new_height) / 2) + if len(img.shape) == 2: + img = img[:, :, np.newaxis] + img_pad[padding: padding + new_height, :, :] = img + + img_pad = np.uint8(img_pad) + + return img_pad + + def inputs_padding(self, x, target_len): + ndim = x.ndim + + if ndim == 4: + f = x.shape[0] + if target_len <= f: + return [deepcopy(x[i]) for i in range(target_len)] + + idx = 0 + flip = False + target_array = [] + while len(target_array) < target_len: + target_array.append(deepcopy(x[idx])) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == f - 1: + flip = not flip + return target_array[:target_len] + + elif ndim == 5: + b, c, f, h, w = x.shape + + if target_len <= f: + return x[:, :, :target_len, :, :] + + indices = [] + idx = 0 + flip = False + while len(indices) < target_len: + indices.append(idx) + if flip: + idx -= 1 + else: + idx += 1 + if idx == 0 or idx == f - 1: + flip = not flip + indices = indices[:target_len] + + if isinstance(x, torch.Tensor): + indices_tensor = torch.tensor(indices, device=x.device, dtype=torch.long) + return x[:, :, indices_tensor, :, :] + else: + indices_array = np.array(indices) + return x[:, :, indices_array, :, :] + + else: + raise ValueError(f"Unsupported input dimension: {ndim}. 
Expected 4D or 5D.") + + def get_valid_len(self, real_len, clip_len=81, overlap=1): + real_clip_len = clip_len - overlap + last_clip_num = (real_len - overlap) % real_clip_len + if last_clip_num == 0: + extra = 0 + else: + extra = real_clip_len - last_clip_num + target_len = real_len + extra + return target_len + + def prepare_source(self, src_pose_path, src_face_path, src_ref_path): + pose_video_reader = VideoReader(src_pose_path) + pose_len = len(pose_video_reader) + pose_idxs = list(range(pose_len)) + pose_video = pose_video_reader.get_batch(pose_idxs).asnumpy() + + face_video_reader = VideoReader(src_face_path) + face_len = len(face_video_reader) + face_idxs = list(range(face_len)) + face_video = face_video_reader.get_batch(face_idxs).asnumpy() + height, width = pose_video[0].shape[:2] + + ref_image = cv2.imread(src_ref_path)[..., ::-1] + ref_image = self.padding_resize(ref_image, height=height, width=width) + return pose_video, face_video, ref_image + + def prepare_source_for_replace(self, src_bg_path, src_mask_path): + bg_video_reader = VideoReader(src_bg_path) + bg_len = len(bg_video_reader) + bg_idxs = list(range(bg_len)) + bg_video = bg_video_reader.get_batch(bg_idxs).asnumpy() + + mask_video_reader = VideoReader(src_mask_path) + mask_len = len(mask_video_reader) + mask_idxs = list(range(mask_len)) + mask_video = mask_video_reader.get_batch(mask_idxs).asnumpy() + mask_video = mask_video[:, :, :, 0] / 255 + return bg_video, mask_video + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2) + return msk + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + # frames = frames.cpu().float().numpy() + return frames + + # Copied 
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + clip_len=77, + num_frames: int = 49, + num_inference_steps: int = 50, + pose_video = None, + face_video = None, + ref_image = None, + bg_video = None, + mask_video = None, + replace_flag = True, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + refert_num = 1, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 1) + + # 4. 
Prepare latents + if pose_video is not None: + video_length = pose_video.shape[2] + pose_video = self.image_processor.preprocess(rearrange(pose_video, "b c f h w -> (b f) c h w"), height=height, width=width) + pose_video = pose_video.to(dtype=torch.float32) + pose_video = rearrange(pose_video, "(b f) c h w -> b c f h w", f=video_length) + else: + pose_video = None + + if face_video is not None: + video_length = face_video.shape[2] + face_video = self.image_processor.preprocess(rearrange(face_video, "b c f h w -> (b f) c h w")) + face_video = face_video.to(dtype=torch.float32) + face_video = rearrange(face_video, "(b f) c h w -> b c f h w", f=video_length) + else: + face_video = None + + real_frame_len = pose_video.size()[2] + target_len = self.get_valid_len(real_frame_len, clip_len, overlap=refert_num) + print('real frames: {} target frames: {}'.format(real_frame_len, target_len)) + pose_video = self.inputs_padding(pose_video, target_len).to(device, weight_dtype) + face_video = self.inputs_padding(face_video, target_len).to(device, weight_dtype) + ref_image = self.padding_resize(np.array(ref_image), height=height, width=width) + ref_image = torch.tensor(ref_image / 127.5 - 1).unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0).to(device, weight_dtype) + + if replace_flag: + if bg_video is not None: + video_length = bg_video.shape[2] + bg_video = self.image_processor.preprocess(rearrange(bg_video, "b c f h w -> (b f) c h w"), height=height, width=width) + bg_video = bg_video.to(dtype=torch.float32) + bg_video = rearrange(bg_video, "(b f) c h w -> b c f h w", f=video_length) + else: + bg_video = None + bg_video = self.inputs_padding(bg_video, target_len).to(device, weight_dtype) + mask_video = self.inputs_padding(mask_video, target_len).to(device, weight_dtype) + + if comfyui_progressbar: + pbar.update(1) + + # 5. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + + # 6. 
Denoising loop + start = 0 + end = clip_len + all_out_frames = [] + copy_timesteps = copy.deepcopy(timesteps) + copy_latents = copy.deepcopy(latents) + bs = pose_video.size()[0] + while True: + if start + refert_num >= pose_video.size()[2]: + break + + # Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps) + self._num_timesteps = len(timesteps) + + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + copy_latents, + ) + + if start == 0: + mask_reft_len = 0 + else: + mask_reft_len = refert_num + + conditioning_pixel_values = pose_video[:, :, start:end] + face_pixel_values = face_video[:, :, start:end] + ref_pixel_values = ref_image.clone().detach() + if start > 0: + refer_t_pixel_values = out_frames[:, :, -refert_num:].clone().detach() + refer_t_pixel_values = (refer_t_pixel_values - 0.5) / 0.5 + else: + refer_t_pixel_values = torch.zeros(bs, 3, refert_num, height, width) + refer_t_pixel_values = refer_t_pixel_values.to(device=device, dtype=weight_dtype) + + pose_latents, ref_latents = self.prepare_control_latents( + conditioning_pixel_values, + ref_pixel_values, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + ) + + mask_ref = self.get_i2v_mask(1, target_shape[-1], target_shape[-2], 1, device=device) + y_ref = torch.concat([mask_ref, ref_latents], dim=1).to(device=device, dtype=weight_dtype) + if mask_reft_len > 0: + if replace_flag: + # Image.fromarray(np.array((refer_t_pixel_values[0, :, 0].permute(1,2,0) * 0.5 + 0.5).float().cpu().numpy() *255, np.uint8)).save("1.jpg") + bg_pixel_values = bg_video[:, :, start:end] + y_reft = self.vae.encode( + torch.concat( + [ + refer_t_pixel_values[:, :, :mask_reft_len], + bg_pixel_values[:, :, mask_reft_len:] + ], dim=2 + ).to(device=device, dtype=weight_dtype) + )[0].mode() + + mask_pixel_values = 1 - mask_video[:, :, start:end] + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0] + msk_reft = self.get_i2v_mask( + int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device + ) + else: + refer_t_pixel_values = rearrange(refer_t_pixel_values[:, :, :mask_reft_len], "b c t h w -> (b t) c h w") + refer_t_pixel_values = F.interpolate(refer_t_pixel_values, size=(height, width), mode="bicubic") + refer_t_pixel_values = rearrange(refer_t_pixel_values, "(b t) c h w -> b c t h w", b = bs) + + y_reft = self.vae.encode( + 
torch.concat( + [ + refer_t_pixel_values, + torch.zeros(bs, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype), + ], dim=2, + ).to(device=device, dtype=weight_dtype) + )[0].mode() + msk_reft = self.get_i2v_mask( + int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device + ) + else: + if replace_flag: + bg_pixel_values = bg_video[:, :, start:end] + y_reft = self.vae.encode( + bg_pixel_values.to(device=device, dtype=weight_dtype) + )[0].mode() + + mask_pixel_values = 1 - mask_video[:, :, start:end] + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = F.interpolate(mask_pixel_values, size=(target_shape[-1], target_shape[-2]), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b c t h w", b = bs)[:, 0] + msk_reft = self.get_i2v_mask( + int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, mask_pixel_values=mask_pixel_values, device=device + ) + else: + y_reft = self.vae.encode( + torch.zeros(1, 3, clip_len - mask_reft_len, height, width).to(device=device, dtype=weight_dtype) + )[0].mode() + msk_reft = self.get_i2v_mask( + int((clip_len - 1) // self.vae.temporal_compression_ratio + 1), target_shape[-1], target_shape[-2], mask_reft_len, device=device + ) + + y_reft = torch.concat([msk_reft, y_reft], dim=1).to(device=device, dtype=weight_dtype) + y = torch.concat([y_ref, y_reft], dim=2) + + clip_context = self.clip_image_encoder([ref_pixel_values[0, :, :, :]]).to(device=device, dtype=weight_dtype) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + y_in = torch.cat([y] * 2) if do_classifier_free_guidance else y + clip_context_input = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + pose_latents_input = ( + torch.cat([pose_latents] * 2) if do_classifier_free_guidance else pose_latents + ) + face_pixel_values_input = ( + torch.cat([torch.ones_like(face_pixel_values) * -1] + [face_pixel_values]) if do_classifier_free_guidance else face_pixel_values + ) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y_in, + clip_fea=clip_context_input, + pose_latents=pose_latents_input, + face_pixel_values=face_pixel_values_input, + ) + + # Perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + 
sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # Compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + out_frames = self.decode_latents(latents[:, :, 1:]) + if start != 0: + out_frames = out_frames[:, :, refert_num:] + all_out_frames.append(out_frames.cpu()) + start += clip_len - refert_num + end += clip_len - refert_num + + videos = torch.cat(all_out_frames, dim=2)[:, :, :real_frame_len] + + # Offload all models + self.maybe_free_model_hooks() + + return WanPipelineOutput(videos=videos.float().cpu()) diff --git a/videox_fun/pipeline/pipeline_wan2_2_fun_control.py b/videox_fun/pipeline/pipeline_wan2_2_fun_control.py new file mode 100644 index 0000000000000000000000000000000000000000..5b923191c43165bdf825b4eb0cc2cbb08ffab6b5 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_fun_control.py @@ -0,0 +1,903 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, + Wan2_2Transformer3DModel, WanT5EncoderModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and 
retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. 
+ + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2FunControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel, + transformer_2: Wan2_2Transformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + 
negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + control_camera_video: Union[torch.FloatTensor] = None, + start_image: Union[torch.FloatTensor] = None, + ref_image: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + if self.vae.spatial_compression_ratio >= 16: + mask = torch.ones_like(latents).to(device, weight_dtype)[:, :1].to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + if self.vae.spatial_compression_ratio >= 16: + mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype) + if not mask[:, :, 0, :, :].any(): + mask[:, :, 1:, :, :] = 1 + latents = (1 - mask) * masked_video_latents + mask * latents + + # Prepare mask latent variables + if control_camera_video is not None: + control_latents = None + # Rearrange dimensions + # Concatenate and transpose dimensions + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], 
repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + + # Reshape, transpose, and view into desired shape + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + elif control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + control_camera_latents = None + else: + control_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + control_camera_latents = None + + if start_image is not None: + video_length = start_image.shape[2] + start_image = self.image_processor.preprocess(rearrange(start_image, "b c f h w -> (b f) c h w"), height=height, width=width) + start_image = start_image.to(dtype=torch.float32) + start_image = rearrange(start_image, "(b f) c h w -> b c f h w", f=video_length) + + start_image_latentes = self.prepare_control_latents( + None, + start_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + + start_image_latentes_conv_in = torch.zeros_like(latents) + if latents.size()[2] != 1: + start_image_latentes_conv_in[:, :, :1] = start_image_latentes + else: + start_image_latentes_conv_in = torch.zeros_like(latents) + + if self.transformer.config.get("add_ref_conv", False): + if ref_image is not None: + video_length = ref_image.shape[2] + ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length) + + ref_image_latentes = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + ref_image_latentes = ref_image_latentes[:, :, 0] + else: + ref_image_latentes = torch.zeros_like(latents)[:, :, 0] + else: + if ref_image is not None: + raise ValueError("The add_ref_conv is False, but ref_image is not None") + else: + ref_image_latentes = None + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Prepare mask latent variables + if control_camera_video is not None: + control_latents_input = None + control_camera_latents_input = ( + torch.cat([control_camera_latents] * 2) if do_classifier_free_guidance else control_camera_latents + ).to(device, weight_dtype) + else: + control_latents_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ).to(device, weight_dtype) + control_camera_latents_input = None + + if init_video is not None: + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + control_latents_input = y if control_latents_input is None else \ + torch.cat([control_latents_input, y], dim = 1) + else: + start_image_latentes_conv_in_input = ( + torch.cat([start_image_latentes_conv_in] * 2) if do_classifier_free_guidance else start_image_latentes_conv_in + ).to(device, weight_dtype) + control_latents_input = start_image_latentes_conv_in_input if control_latents_input is None else \ + torch.cat([control_latents_input, start_image_latentes_conv_in_input], dim = 1) + + if ref_image_latentes is not None: + full_ref = ( + torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes + ).to(device, weight_dtype) + else: + full_ref = None + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + if self.vae.spatial_compression_ratio >= 16 and init_video is not None: + temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten() + temp_ts = torch.cat([ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * t + ]) + temp_ts = temp_ts.unsqueeze(0) + timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1)) + else: + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=control_latents_input, + y_camera=control_camera_latents_input, + full_ref=full_ref, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if self.vae.spatial_compression_ratio >= 16 and not mask[:, :, 0, :, 
:].any(): + latents = (1 - mask) * masked_video_latents + mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py b/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py new file mode 100644 index 0000000000000000000000000000000000000000..8eaf68e67d624b4d84e0cbfe62ef81836ee83d18 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_fun_inpaint.py @@ -0,0 +1,752 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, Wan2_2Transformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. 
If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2FunInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. 
Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel, + transformer_2: Wan2_2Transformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + if self.vae.spatial_compression_ratio >= 16: + mask = torch.ones_like(latents).to(device, weight_dtype)[:, :1].to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + if self.vae.spatial_compression_ratio >= 16: + mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype) + if not mask[:, :, 0, :, :].any(): + mask[:, :, 1:, :, :] = 1 + latents = (1 - mask) * masked_video_latents + mask * latents + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if init_video is not None: + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + if self.vae.spatial_compression_ratio >= 16 and init_video is not None: + temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten() + temp_ts = torch.cat([ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * t + ]) + temp_ts = temp_ts.unsqueeze(0) + timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1)) + else: + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y, + ) + + # perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if self.vae.spatial_compression_ratio >= 16 and not mask[:, :, 0, :, :].any(): + latents = (1 - mask) * masked_video_latents + mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = 
callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_s2v.py b/videox_fun/pipeline/pipeline_wan2_2_s2v.py new file mode 100644 index 0000000000000000000000000000000000000000..eb421af38e26b8bed8903bdf9e4642f23c194506 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_s2v.py @@ -0,0 +1,815 @@ +import inspect +import math +import copy +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from torchvision import transforms +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + Wan2_2Transformer3DModel_S2V, WanAudioEncoder, + WanT5EncoderModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+    latent_size = latent.size()
+    batch_size, channels, num_frames, height, width = mask.shape
+
+    if process_first_frame_only:
+        target_size = list(latent_size[2:])
+        target_size[0] = 1
+        first_frame_resized = F.interpolate(
+            mask[:, :, 0:1, :, :],
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+
+        target_size = list(latent_size[2:])
+        target_size[0] = target_size[0] - 1
+        if target_size[0] != 0:
+            remaining_frames_resized = F.interpolate(
+                mask[:, :, 1:, :, :],
+                size=target_size,
+                mode='trilinear',
+                align_corners=False
+            )
+            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+        else:
+            resized_mask = first_frame_resized
+    else:
+        target_size = list(latent_size[2:])
+        resized_mask = F.interpolate(
+            mask,
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+    return resized_mask
+
+
+@dataclass
+class WanPipelineOutput(BaseOutput):
+    r"""
+    Output class for Wan pipelines.
+
+    Args:
+        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    videos: torch.Tensor
+
+
+class Wan2_2S2VPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for audio-driven (speech-to-video) generation from a reference image using Wan 2.2.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
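+
+    In addition to the usual Wan components, the pipeline registers an audio encoder
+    (`WanAudioEncoder`) so that generation is driven by a speech track, and it accepts an
+    optional pose video for pose guidance. A minimal call sketch follows; it is illustrative
+    only and assumes `pipe` is an instantiated `Wan2_2S2VPipeline` and `ref_image` is an
+    already-prepared reference tensor:
+
+    ```python
+    out = pipe(
+        prompt="a person speaking to the camera",
+        ref_image=ref_image,       # (B, C, F, H, W) reference-frame tensor
+        audio_path="speech.wav",   # consumed by the audio encoder
+        num_frames=49,
+        fps=16,
+    )
+    video = out.videos             # (B, C, F, H, W) tensor with values in [0, 1]
+    ```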
+ """ + + _optional_components = ["transformer_2", "audio_encoder"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + audio_encoder: WanAudioEncoder, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel_S2V, + transformer_2: Wan2_2Transformer3DModel_S2V = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler, audio_encoder=audio_encoder + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + self.motion_frames = 73 + self.audio_sample_m = 0 + self.drop_first_motion = True + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def encode_audio_embeddings(self, audio_path, num_frames, fps, weight_dtype, device): + z = self.audio_encoder.extract_audio_feat( + audio_path, return_all_layers=True) + audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps( + z, fps=fps, batch_frames=num_frames, m=self.audio_sample_m) + audio_embed_bucket = audio_embed_bucket.to(device, + weight_dtype) + audio_embed_bucket = audio_embed_bucket.unsqueeze(0) + if len(audio_embed_bucket.shape) == 3: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1) + elif len(audio_embed_bucket.shape) == 4: + audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1) + return audio_embed_bucket, num_repeat + + def encode_pose_latents(self, pose_video, num_repeat, num_frames, size, fps, weight_dtype, device): + height, width = size + if not pose_video is None: + padding_frame_num = num_repeat * num_frames - pose_video.shape[2] + pose_video = torch.cat( + [ + pose_video, + -torch.ones([1, 3, padding_frame_num, height, width]) + ], + dim=2 + ) + + cond_tensors = torch.chunk(pose_video, num_repeat, dim=2) + else: + cond_tensors = [-torch.ones([1, 3, num_frames, height, width])] + + pose_latents = [] + for r in range(len(cond_tensors)): + cond = cond_tensors[r] + cond = torch.cat([cond[:, :, 0:1].repeat(1, 1, 1, 1, 1), cond], + dim=2) + cond_lat = self.vae.encode(cond.to(dtype=weight_dtype, device=device))[0].mode()[:, :, 1:] + pose_latents.append(cond_lat) + return pose_latents + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + # frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + ref_image: Union[torch.FloatTensor] = None, + audio_path = None, + pose_video = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + fps: int = 16, + init_first_frame: bool = False, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + lat_motion_frames = (self.motion_frames + 3) // 4 + lat_target_frames = (num_frames + 3 + self.motion_frames) // 4 - lat_motion_frames + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. 
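+        # (For S2V this step prepares, in order: the reference-image latents, the audio
+        # embeddings whose bucket count `num_repeat` sets how many clips are generated,
+        # an initial VAE-encoded motion-frame buffer, and pose latents when a pose video
+        # is provided.)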
+ latent_channels = self.vae.config.latent_channels + if comfyui_progressbar: + pbar.update(1) + + video_length = ref_image.shape[2] + ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length) + + ref_image_latentes = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + ref_image_latentes = ref_image_latentes[:, :, :1] + + # Extract audio emb + audio_emb, num_repeat = self.encode_audio_embeddings( + audio_path, num_frames=num_frames, fps=fps, weight_dtype=weight_dtype, device=device + ) + + # Encode the motion latents + motion_latents = torch.zeros( + [1, 3, self.motion_frames, height, width], + dtype=weight_dtype, + device=device + ) + videos_last_frames = motion_latents.detach() + drop_first_motion = self.drop_first_motion + if init_first_frame: + drop_first_motion = False + motion_latents[:, :, -6:] = ref_image + motion_latents = self.vae.encode(motion_latents)[0].mode() + + # Get pose cond input if need + if pose_video is not None: + video_length = pose_video.shape[2] + pose_video = self.image_processor.preprocess(rearrange(pose_video, "b c f h w -> (b f) c h w"), height=height, width=width) + pose_video = pose_video.to(dtype=torch.float32) + pose_video = rearrange(pose_video, "(b f) c h w -> b c f h w", f=video_length) + pose_latents = self.encode_pose_latents( + pose_video=pose_video, + num_repeat=num_repeat, + num_frames=num_frames, + size=(height, width), + fps=fps, + weight_dtype=weight_dtype, + device=device + ) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + videos = [] + copy_timesteps = copy.deepcopy(timesteps) + copy_latents = copy.deepcopy(latents) + for r in range(num_repeat): + # Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, copy_timesteps) + self._num_timesteps = len(timesteps) + + target_shape = (self.vae.latent_channels, lat_target_frames, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + copy_latents, + num_length_latents=target_shape[1] + ) + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + with torch.no_grad(): + left_idx = r * num_frames + right_idx = r * num_frames + num_frames + cond_latents = pose_latents[r] if pose_video is not None else pose_latents[0] * 0 + cond_latents = cond_latents.to(dtype=weight_dtype, device=device) + audio_input = audio_emb[..., left_idx:right_idx] + + pose_latents_input = torch.cat([cond_latents] * 2) if do_classifier_free_guidance else cond_latents + motion_latents_input = torch.cat([motion_latents] * 2) if do_classifier_free_guidance else motion_latents + audio_emb_input = torch.cat([audio_input * 0] + [audio_input]) if do_classifier_free_guidance else audio_input + ref_image_latentes_input = torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes + motion_frames=[[self.motion_frames, (self.motion_frames + 3) // 4]] * 2 if do_classifier_free_guidance else [[self.motion_frames, (self.motion_frames + 3) // 4]] + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + cond_states=pose_latents_input, + motion_latents=motion_latents_input, + ref_latents=ref_image_latentes_input, + audio_input=audio_emb_input, + motion_frames=motion_frames, + drop_motion_frames=drop_first_motion and r == 0, + ) + # perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if not (drop_first_motion and r == 0): + 
decode_latents = torch.cat([motion_latents, latents], dim=2) + else: + decode_latents = torch.cat([ref_image_latentes, latents], dim=2) + + image = self.vae.decode(decode_latents).sample + image = image[:, :, -(num_frames):] + if (drop_first_motion and r == 0): + image = image[:, :, 3:] + + overlap_frames_num = min(self.motion_frames, image.shape[2]) + videos_last_frames = torch.cat( + [ + videos_last_frames[:, :, overlap_frames_num:], + image[:, :, -overlap_frames_num:] + ], + dim=2 + ).to(dtype=motion_latents.dtype, device=motion_latents.device) + motion_latents = self.vae.encode(videos_last_frames)[0].mode() + videos.append(image) + + videos = torch.cat(videos, dim=2) + videos = (videos / 2 + 0.5).clamp(0, 1) + + # Offload all models + self.maybe_free_model_hooks() + + return WanPipelineOutput(videos=videos.float().cpu()) diff --git a/videox_fun/pipeline/pipeline_wan2_2_ti2v.py b/videox_fun/pipeline/pipeline_wan2_2_ti2v.py new file mode 100644 index 0000000000000000000000000000000000000000..12b3bd8cf0aa36a16b059d3eb55c2b8fcad61762 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_ti2v.py @@ -0,0 +1,732 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, Wan2_2Transformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+                `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+def resize_mask(mask, latent, process_first_frame_only=True):
+    latent_size = latent.size()
+    batch_size, channels, num_frames, height, width = mask.shape
+
+    if process_first_frame_only:
+        target_size = list(latent_size[2:])
+        target_size[0] = 1
+        first_frame_resized = F.interpolate(
+            mask[:, :, 0:1, :, :],
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+
+        target_size = list(latent_size[2:])
+        target_size[0] = target_size[0] - 1
+        if target_size[0] != 0:
+            remaining_frames_resized = F.interpolate(
+                mask[:, :, 1:, :, :],
+                size=target_size,
+                mode='trilinear',
+                align_corners=False
+            )
+            resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2)
+        else:
+            resized_mask = first_frame_resized
+    else:
+        target_size = list(latent_size[2:])
+        resized_mask = F.interpolate(
+            mask,
+            size=target_size,
+            mode='trilinear',
+            align_corners=False
+        )
+    return resized_mask
+
+
+@dataclass
+class WanPipelineOutput(BaseOutput):
+    r"""
+    Output class for Wan pipelines.
+
+    Args:
+        videos (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    videos: torch.Tensor
+
+
+class Wan2_2TI2VPipeline(DiffusionPipeline):
+    r"""
+    Pipeline for text- and image-to-video (TI2V) generation using Wan 2.2.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
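+
+    Unlike the Wan 2.2 speech-to-video pipeline, no audio encoder is registered here; the pipeline
+    is built from the tokenizer, text encoder, VAE, transformer(s) and scheduler only.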
+ """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: Wan2_2Transformer3DModel, + transformer_2: Wan2_2Transformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
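+ # (the flow-matching schedulers wired into this pipeline typically accept `generator` but not `eta`,
+ # which is why the signature checks below only forward what `scheduler.step` actually supports)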
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None and not (mask_video == 255).all(): + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + + mask = F.interpolate(mask_condition[:, :1], size=latents.size()[-3:], mode='trilinear', align_corners=True).to(device, weight_dtype) + latents = (1 - mask) * masked_video_latents + mask * latents + else: + init_video = None + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + if init_video is not None: + temp_ts = ((mask[0][0][:, ::2, ::2]) * t).flatten() + temp_ts = torch.cat([ + temp_ts, + temp_ts.new_ones(seq_len - temp_ts.size(0)) * t + ]) + temp_ts = temp_ts.unsqueeze(0) + timestep = temp_ts.expand(latent_model_input.shape[0], temp_ts.size(1)) + else: + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + ) + + # perform guidance + if do_classifier_free_guidance: + if self.transformer_2 is not None and (isinstance(self.guidance_scale, (list, tuple))): + sample_guide_scale = self.guidance_scale[1] if t >= self.transformer_2.config.boundary * self.scheduler.config.num_train_timesteps else self.guidance_scale[0] + else: + sample_guide_scale = self.guidance_scale + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + sample_guide_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if init_video is not None: + latents = (1 - mask) * masked_video_latents + mask * latents + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py b/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py new file mode 100644 index 0000000000000000000000000000000000000000..c8b0d9932c5797af2e20081e37b26fae0c599486 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan2_2_vace_fun.py @@ -0,0 +1,801 @@ 
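+# Wan2_2VaceFunPipeline: a VACE-style controllable video pipeline for Wan 2.2. Control/inpaint frames,
+# masks and optional subject reference images are VAE-encoded into a `vace_context` (see
+# `vace_encode_frames` / `vace_encode_masks` below), which is fed to the transformer alongside the
+# noisy latents and modulated by `vace_context_scale`.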
+import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, VaceWanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class Wan2_2VaceFunPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
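+ 
+ Args:
+ tokenizer (`AutoTokenizer`): tokenizer for the Wan T5 text encoder.
+ text_encoder (`WanT5EncoderModel`): text encoder that produces the prompt embeddings.
+ vae (`AutoencoderKLWan`): video VAE used to encode frames, masks and reference images to latents and to decode latents back to frames.
+ transformer (`VaceWanTransformer3DModel`): denoising transformer that consumes the noisy latents together with the VACE context.
+ transformer_2 (`VaceWanTransformer3DModel`, *optional*): optional second transformer, used for timesteps `t >= boundary * num_train_timesteps`.
+ scheduler (`FlowMatchEulerDiscreteScheduler`): scheduler used to step the latents during denoising.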
+ """ + + _optional_components = ["transformer_2"] + model_cpu_offload_seq = "text_encoder->transformer_2->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: VaceWanTransformer3DModel, + transformer_2: VaceWanTransformer3DModel = None, + scheduler: FlowMatchEulerDiscreteScheduler = None, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, + transformer_2=transformer_2, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): + vae = self.vae if vae is None else vae + weight_dtype = frames.dtype + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames)[0].mode() + else: + masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive)[0].mode() + reactive = vae.encode(reactive)[0].mode() + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs)[0].mode() + else: + ref_latent = vae.encode(refs)[0].mode() + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None, vae_stride=[4, 8, 8]): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + 
new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + subject_ref_images: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + boundary: float = 0.875, + comfyui_progressbar: bool = False, + shift: int = 5, + vace_context_scale: float = 1.0, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + latent_channels = self.vae.config.latent_channels + + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if mask_video is not None: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1, 1]).to(dtype=weight_dtype, device=device) + + + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + input_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + + input_video = input_video.to(dtype=weight_dtype, device=device) + + elif video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length).to(dtype=weight_dtype, device=device) + + input_video = init_video * (mask_condition < 0.5) + input_video = input_video.to(dtype=weight_dtype, device=device) + + if subject_ref_images is not None: + video_length = subject_ref_images.shape[2] + subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width) + subject_ref_images = subject_ref_images.to(dtype=torch.float32) + subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length) + subject_ref_images = subject_ref_images.to(dtype=weight_dtype, device=device) + + bs, c, f, h, w = subject_ref_images.size() + new_subject_ref_images = [] + for i in range(bs): + new_subject_ref_images.append([]) + for j in range(f): + new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1]) + 
subject_ref_images = new_subject_ref_images + + vace_latents = self.vace_encode_frames(input_video, subject_ref_images, masks=mask_condition, vae=self.vae) + mask_latents = self.vace_encode_masks(mask_condition, subject_ref_images, vae_stride=[4, self.vae.spatial_compression_ratio, self.vae.spatial_compression_ratio]) + vace_context = self.vace_latent(vace_latents, mask_latents) + + # 5. Prepare latents. + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + num_length_latents=vace_latents[0].size(1) + ) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, vace_latents[0].size(1), vace_latents[0].size(2), vace_latents[0].size(3)) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + vace_context_input = torch.stack(vace_context * 2) if do_classifier_free_guidance else vace_context + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + if self.transformer_2 is not None: + if t >= boundary * self.scheduler.config.num_train_timesteps: + local_transformer = self.transformer_2 + else: + local_transformer = self.transformer + else: + local_transformer = self.transformer + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = local_transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + vace_context=vace_context_input, + seq_len=seq_len, + vace_context_scale=vace_context_scale, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if subject_ref_images is not None: + len_subject_ref_images = len(subject_ref_images[0]) + 
latents = latents[:, :, len_subject_ref_images:, :, :] + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_fun_control.py b/videox_fun/pipeline/pipeline_wan_fun_control.py new file mode 100755 index 0000000000000000000000000000000000000000..4d4ec75e77a434f031b081d39eb4358dc59025a8 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan_fun_control.py @@ -0,0 +1,799 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
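+ **kwargs:
+ Additional keyword arguments forwarded to `scheduler.set_timesteps` (the pipelines in this repository pass e.g. `mu=1` when using `FlowMatchEulerDiscreteScheduler`).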
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunControlPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
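+ 
+ Args:
+ tokenizer (`AutoTokenizer`): tokenizer for the Wan T5 text encoder.
+ text_encoder (`WanT5EncoderModel`): text encoder that produces the prompt embeddings.
+ vae (`AutoencoderKLWan`): video VAE used to encode control videos and images to latents and to decode latents back to frames.
+ transformer (`WanTransformer3DModel`): denoising transformer.
+ clip_image_encoder (`CLIPModel`): CLIP image encoder component registered with the pipeline.
+ scheduler (`FlowMatchEulerDiscreteScheduler`): scheduler used to step the latents during denoising.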
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + clip_image_encoder: CLIPModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
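+ # NOTE: that "ignoring" is implemented by the inspect.signature checks below, which forward `eta` and `generator` only when the configured scheduler's `step()` actually accepts them.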
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + control_video: Union[torch.FloatTensor] = None, + control_camera_video: Union[torch.FloatTensor] = None, + start_image: Union[torch.FloatTensor] = None, + ref_image: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_image: Image = None, + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if control_camera_video is not None: + control_latents = None + # Rearrange dimensions + # Concatenate and transpose dimensions + control_camera_latents = torch.concat( + [ + torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2), + control_camera_video[:, :, 1:] + ], dim=2 + ).transpose(1, 2) + + # Reshape, transpose, and view into desired shape + b, f, c, h, w = control_camera_latents.shape + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3) + control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2) + elif control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + control_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + control_video_latents = self.prepare_control_latents( + None, + control_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + control_camera_latents = None + else: + control_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + control_camera_latents = None + + if start_image is not None: + video_length = start_image.shape[2] + start_image = self.image_processor.preprocess(rearrange(start_image, "b c f h w -> (b f) c h w"), height=height, width=width) + start_image = start_image.to(dtype=torch.float32) + start_image = rearrange(start_image, "(b f) c h w -> b c f h w", f=video_length) + + start_image_latentes = self.prepare_control_latents( + None, + start_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + + start_image_latentes_conv_in = torch.zeros_like(latents) + if latents.size()[2] != 1: + start_image_latentes_conv_in[:, :, :1] = start_image_latentes + else: + start_image_latentes_conv_in = torch.zeros_like(latents) + + # Prepare clip latent variables + if clip_image is not None: + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + else: + clip_image = Image.new("RGB", (512, 512), 
color=(0, 0, 0)) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = torch.zeros_like(clip_context) + + if self.transformer.config.get("add_ref_conv", False): + if ref_image is not None: + video_length = ref_image.shape[2] + ref_image = self.image_processor.preprocess(rearrange(ref_image, "b c f h w -> (b f) c h w"), height=height, width=width) + ref_image = ref_image.to(dtype=torch.float32) + ref_image = rearrange(ref_image, "(b f) c h w -> b c f h w", f=video_length) + + ref_image_latentes = self.prepare_control_latents( + None, + ref_image, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] + ref_image_latentes = ref_image_latentes[:, :, 0] + else: + ref_image_latentes = torch.zeros_like(latents)[:, :, 0] + else: + if ref_image is not None: + raise ValueError("The add_ref_conv is False, but ref_image is not None") + else: + ref_image_latentes = None + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # Prepare mask latent variables + if control_camera_video is not None: + control_latents_input = None + control_camera_latents_input = ( + torch.cat([control_camera_latents] * 2) if do_classifier_free_guidance else control_camera_latents + ).to(device, weight_dtype) + else: + control_latents_input = ( + torch.cat([control_video_latents] * 2) if do_classifier_free_guidance else control_video_latents + ).to(device, weight_dtype) + control_camera_latents_input = None + + start_image_latentes_conv_in_input = ( + torch.cat([start_image_latentes_conv_in] * 2) if do_classifier_free_guidance else start_image_latentes_conv_in + ).to(device, weight_dtype) + control_latents_input = start_image_latentes_conv_in_input if control_latents_input is None else \ + torch.cat([control_latents_input, start_image_latentes_conv_in_input], dim = 1) + + clip_context_input = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + + if ref_image_latentes is not None: + full_ref = ( + torch.cat([ref_image_latentes] * 2) if do_classifier_free_guidance else ref_image_latentes + ).to(device, weight_dtype) + else: + full_ref = None + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), 
torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=control_latents_input, + y_camera=control_camera_latents_input, + full_ref=full_ref, + clip_fea=clip_context_input, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_fun_inpaint.py b/videox_fun/pipeline/pipeline_wan_fun_inpaint.py new file mode 100755 index 0000000000000000000000000000000000000000..35f3b96def177e1ce851d369d28353a19aefb3db --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan_fun_inpaint.py @@ -0,0 +1,734 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + 
**kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. 
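+ (The summary line above is carried over from the CogVideo pipelines this class mirrors; in this file it is the output type returned by the Wan pipelines.)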
+ + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunInpaintPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + clip_image_encoder: CLIPModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, clip_image_encoder=clip_image_encoder, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, 
List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance, noise_aug_strength + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if mask is not None: + mask = mask.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask = [] + for i in range(0, mask.shape[0], bs): + mask_bs = mask[i : i + bs] + mask_bs = self.vae.encode(mask_bs)[0] + mask_bs = mask_bs.mode() + new_mask.append(mask_bs) + mask = torch.cat(new_mask, dim = 0) + # mask = mask * self.vae.config.scaling_factor + + if masked_image is not None: + masked_image = masked_image.to(device=device, dtype=self.vae.dtype) + bs = 1 + new_mask_pixel_values = [] + for i in range(0, masked_image.shape[0], bs): + mask_pixel_values_bs = masked_image[i : i + bs] + mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0] + mask_pixel_values_bs = mask_pixel_values_bs.mode() + new_mask_pixel_values.append(mask_pixel_values_bs) + masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0) + # masked_image_latents = masked_image_latents * self.vae.config.scaling_factor + else: + masked_image_latents = None + + return mask, masked_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
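+ # with the FlowMatchEulerDiscreteScheduler type-hinted in `__init__`, only `generator` is expected to survive these checks, since its `step()` takes no `eta` argument.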
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clip_image: Image = None, + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + if video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length) + else: + init_video = None + + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if init_video is not None: + if (mask_video == 255).all(): + mask_latents = torch.tile( + torch.zeros_like(latents)[:, :1].to(device, weight_dtype), [1, 4, 1, 1, 1] + ) + masked_video_latents = torch.zeros_like(latents).to(device, weight_dtype) + else: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + + masked_video = init_video * (torch.tile(mask_condition, [1, 3, 1, 1, 1]) < 0.5) + _, masked_video_latents = self.prepare_mask_latents( + None, + masked_video, + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance, + noise_aug_strength=None, + ) + + mask_condition = torch.concat( + [ + torch.repeat_interleave(mask_condition[:, :, 0:1], repeats=4, dim=2), + mask_condition[:, :, 1:] + ], dim=2 + ) + mask_condition = mask_condition.view(bs, mask_condition.shape[2] // 4, 4, height, width) + mask_condition = mask_condition.transpose(1, 2) + mask_latents = resize_mask(1 - mask_condition, masked_video_latents, True).to(device, weight_dtype) + + # Prepare clip latent variables + if clip_image is not None: + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + else: + clip_image = Image.new("RGB", (512, 512), color=(0, 0, 0)) + clip_image = TF.to_tensor(clip_image).sub_(0.5).div_(0.5).to(device, weight_dtype) + clip_context = self.clip_image_encoder([clip_image[:, None, :, :]]) + clip_context = torch.zeros_like(clip_context) + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if init_video is not None: + mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents + masked_video_latents_input = ( + torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents + ) + y = torch.cat([mask_input, masked_video_latents_input], dim=1).to(device, weight_dtype) + + clip_context_input = ( + torch.cat([clip_context] * 2) if do_classifier_free_guidance else clip_context + ) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + y=y, + clip_fea=clip_context_input, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_phantom.py b/videox_fun/pipeline/pipeline_wan_phantom.py new file mode 100644 index 0000000000000000000000000000000000000000..fd993b001a6c2d3f4b3ed9f71c5df2202d523b14 --- /dev/null +++ 
b/videox_fun/pipeline/pipeline_wan_phantom.py @@ -0,0 +1,695 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanFunPhantomPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
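+ Unlike `WanFunInpaintPipeline` above, this pipeline is constructed without a `clip_image_encoder`; only the tokenizer, text encoder, VAE, transformer and scheduler are registered in `__init__`.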
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->clip_image_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: WanTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + subject_ref_images: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. 
Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + # 5. Prepare latents. + latent_channels = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + ) + if comfyui_progressbar: + pbar.update(1) + + if subject_ref_images is not None: + video_length = subject_ref_images.shape[2] + subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width) + subject_ref_images = subject_ref_images.to(dtype=torch.float32) + subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length) + + subject_ref_images_latentes = torch.cat( + [ + self.prepare_control_latents( + None, + subject_ref_images[:, :, i:i+1], + batch_size, + height, + width, + weight_dtype, + device, + generator, + do_classifier_free_guidance + )[1] for i in range(video_length) + ], dim = 2 + ) + + if comfyui_progressbar: + pbar.update(1) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, (num_frames - 1) // self.vae.temporal_compression_ratio + 1, width // self.vae.spatial_compression_ratio, height // self.vae.spatial_compression_ratio) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + if subject_ref_images is not None: + subject_ref = ( + torch.cat( + [torch.zeros_like(subject_ref_images_latentes), subject_ref_images_latentes] + ) if do_classifier_free_guidance else subject_ref_images_latentes + ).to(device, weight_dtype) + else: + subject_ref = None + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + seq_len=seq_len, + subject_ref=subject_ref, + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + video = torch.from_numpy(video) + + return WanPipelineOutput(videos=video) diff --git a/videox_fun/pipeline/pipeline_wan_vace.py b/videox_fun/pipeline/pipeline_wan_vace.py new file mode 100644 index 0000000000000000000000000000000000000000..7c0ded274bb5ebe29fb7840a23a210b2917ec071 --- /dev/null +++ b/videox_fun/pipeline/pipeline_wan_vace.py @@ -0,0 +1,787 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from diffusers import FlowMatchEulerDiscreteScheduler +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.embeddings import get_1d_rotary_pos_embed +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 
+from diffusers.utils import BaseOutput, logging, replace_example_docstring +from diffusers.utils.torch_utils import randn_tensor +from diffusers.video_processor import VideoProcessor +from einops import rearrange +from PIL import Image +from transformers import T5Tokenizer + +from ..models import (AutoencoderKLWan, AutoTokenizer, + WanT5EncoderModel, VaceWanTransformer3DModel) +from ..utils.fm_solvers import (FlowDPMSolverMultistepScheduler, + get_sampling_sigmas) +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```python + pass + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def resize_mask(mask, latent, process_first_frame_only=True): + latent_size = latent.size() + batch_size, channels, num_frames, height, width = mask.shape + + if process_first_frame_only: + target_size = list(latent_size[2:]) + target_size[0] = 1 + first_frame_resized = F.interpolate( + mask[:, :, 0:1, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + + target_size = list(latent_size[2:]) + target_size[0] = target_size[0] - 1 + if target_size[0] != 0: + remaining_frames_resized = F.interpolate( + mask[:, :, 1:, :, :], + size=target_size, + mode='trilinear', + align_corners=False + ) + resized_mask = torch.cat([first_frame_resized, remaining_frames_resized], dim=2) + else: + resized_mask = first_frame_resized + else: + target_size = list(latent_size[2:]) + resized_mask = F.interpolate( + mask, + size=target_size, + mode='trilinear', + align_corners=False + ) + return resized_mask + + +@dataclass +class WanPipelineOutput(BaseOutput): + r""" + Output class for CogVideo pipelines. + + Args: + video (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + videos: torch.Tensor + + +class WanVacePipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Wan. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
+ """ + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: AutoTokenizer, + text_encoder: WanT5EncoderModel, + vae: AutoencoderKLWan, + transformer: VaceWanTransformer3DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae.spatial_compression_ratio) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae.spatial_compression_ratio, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))[0] + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ + device = device or self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds = self._get_t5_prompt_embeds( + prompt=prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + + return prompt_embeds, negative_prompt_embeds + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None, num_length_latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + shape = ( + batch_size, + num_channels_latents, + (num_frames - 1) // self.vae.temporal_compression_ratio + 1 if num_length_latents is None else num_length_latents, + height // self.vae.spatial_compression_ratio, + width // self.vae.spatial_compression_ratio, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(self.scheduler, "init_noise_sigma"): + latents = latents * self.scheduler.init_noise_sigma + return latents + + def vace_encode_frames(self, frames, ref_images, masks=None, vae=None): + vae = self.vae if vae is None else vae + weight_dtype = frames.dtype + if ref_images is None: + ref_images = [None] * len(frames) + else: + assert len(frames) == len(ref_images) + + if masks is None: + latents = vae.encode(frames)[0].mode() + else: + masks = [torch.where(m > 0.5, 1.0, 0.0).to(weight_dtype) for m in masks] + inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] + reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] + inactive = vae.encode(inactive)[0].mode() + reactive = vae.encode(reactive)[0].mode() + latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] + + cat_latents = [] + for latent, refs in zip(latents, ref_images): + if refs is not None: + if masks is None: + ref_latent = vae.encode(refs)[0].mode() + else: + ref_latent = vae.encode(refs)[0].mode() + ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] + assert all([x.shape[1] == 1 for x in ref_latent]) + latent = torch.cat([*ref_latent, latent], dim=1) + cat_latents.append(latent) + return cat_latents + + def vace_encode_masks(self, masks, ref_images=None, vae_stride=[4, 8, 8]): + if ref_images is None: + ref_images = [None] * len(masks) + else: + assert len(masks) == len(ref_images) + + result_masks = [] + for mask, refs in zip(masks, ref_images): + c, depth, height, width = mask.shape + new_depth = int((depth + 3) // vae_stride[0]) + height = 2 * (int(height) // (vae_stride[1] * 2)) + width = 2 * (int(width) // (vae_stride[2] * 2)) + + # reshape + mask = mask[0, :, :, :] + mask = mask.view( + depth, height, vae_stride[1], width, vae_stride[1] + ) # depth, height, 8, width, 8 + mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width + mask = mask.reshape( + vae_stride[1] * vae_stride[2], depth, height, width + ) # 8*8, depth, height, width + + # interpolation + mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) + + if refs is not None: + length = len(refs) + mask_pad = torch.zeros_like(mask[:, :length, :, :]) + mask = torch.cat((mask_pad, mask), dim=1) + result_masks.append(mask) + return result_masks + + def vace_latent(self, z, m): + return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] + + def prepare_control_latents( + self, control, control_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the control to latents shape as we concatenate the control to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + + if control is not None: + control = control.to(device=device, dtype=dtype) + bs = 1 + new_control = [] + for i in range(0, control.shape[0], bs): + control_bs = control[i : i + bs] + control_bs = self.vae.encode(control_bs)[0] + control_bs = control_bs.mode() + 
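# each chunk of bs = 1 is encoded separately to bound peak VAE memory; `.mode()` takes the
+                # mode of the latent distribution rather than drawing a random sample
+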
new_control.append(control_bs) + control = torch.cat(new_control, dim = 0) + + if control_image is not None: + control_image = control_image.to(device=device, dtype=dtype) + bs = 1 + new_control_pixel_values = [] + for i in range(0, control_image.shape[0], bs): + control_pixel_values_bs = control_image[i : i + bs] + control_pixel_values_bs = self.vae.encode(control_pixel_values_bs)[0] + control_pixel_values_bs = control_pixel_values_bs.mode() + new_control_pixel_values.append(control_pixel_values_bs) + control_image_latents = torch.cat(new_control_pixel_values, dim = 0) + else: + control_image_latents = None + + return control, control_image_latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + frames = self.vae.decode(latents.to(self.vae.dtype)).sample + frames = (frames / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 + frames = frames.cpu().float().numpy() + return frames + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 480, + width: int = 720, + video: Union[torch.FloatTensor] = None, + mask_video: Union[torch.FloatTensor] = None, + control_video: Union[torch.FloatTensor] = None, + subject_ref_images: Union[torch.FloatTensor] = None, + num_frames: int = 49, + num_inference_steps: int = 50, + timesteps: Optional[List[int]] = None, + guidance_scale: float = 6, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: str = "numpy", + return_dict: bool = False, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + comfyui_progressbar: bool = False, + shift: int = 5, + vace_context_scale: float = 1.0 + ) -> Union[WanPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + Args: + + Examples: + + Returns: + + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + callback_on_step_end_tensor_inputs, + prompt_embeds, + negative_prompt_embeds, + ) + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + negative_prompt, + do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + if do_classifier_free_guidance: + in_prompt_embeds = negative_prompt_embeds + prompt_embeds + else: + in_prompt_embeds = prompt_embeds + + # 4. Prepare timesteps + if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler): + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, mu=1) + elif isinstance(self.scheduler, FlowUniPCMultistepScheduler): + self.scheduler.set_timesteps(num_inference_steps, device=device, shift=shift) + timesteps = self.scheduler.timesteps + elif isinstance(self.scheduler, FlowDPMSolverMultistepScheduler): + sampling_sigmas = get_sampling_sigmas(num_inference_steps, shift) + timesteps, _ = retrieve_timesteps( + self.scheduler, + device=device, + sigmas=sampling_sigmas) + else: + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + if comfyui_progressbar: + from comfy.utils import ProgressBar + pbar = ProgressBar(num_inference_steps + 2) + + latent_channels = self.vae.config.latent_channels + + if comfyui_progressbar: + pbar.update(1) + + # Prepare mask latent variables + if mask_video is not None: + bs, _, video_length, height, width = video.size() + mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width) + mask_condition = mask_condition.to(dtype=torch.float32) + mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length) + mask_condition = torch.tile(mask_condition, [1, 3, 1, 1, 1]).to(dtype=weight_dtype, device=device) + + + if control_video is not None: + video_length = control_video.shape[2] + control_video = self.image_processor.preprocess(rearrange(control_video, "b c f h w -> (b f) c h w"), height=height, width=width) + control_video = control_video.to(dtype=torch.float32) + input_video = rearrange(control_video, "(b f) c h w -> b c f h w", f=video_length) + + input_video = input_video.to(dtype=weight_dtype, device=device) + + elif video is not None: + video_length = video.shape[2] + init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width) + init_video = init_video.to(dtype=torch.float32) + init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length).to(dtype=weight_dtype, device=device) + + input_video = init_video * (mask_condition < 0.5) + input_video = input_video.to(dtype=weight_dtype, device=device) + + if subject_ref_images is not None: + video_length = subject_ref_images.shape[2] + subject_ref_images = self.image_processor.preprocess(rearrange(subject_ref_images, "b c f h w -> (b f) c h w"), height=height, width=width) + subject_ref_images = subject_ref_images.to(dtype=torch.float32) + subject_ref_images = rearrange(subject_ref_images, "(b f) c h w -> b c f h w", f=video_length) + subject_ref_images = subject_ref_images.to(dtype=weight_dtype, device=device) + + bs, c, f, h, w = subject_ref_images.size() + new_subject_ref_images = [] + for i in range(bs): + new_subject_ref_images.append([]) + for j in range(f): + new_subject_ref_images[i].append(subject_ref_images[i, :, j:j+1]) + 
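# repack the (b, c, f, h, w) reference tensor into a per-sample list of single-frame
+            # (c, 1, h, w) tensors, the reference layout `vace_encode_frames` expects
+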
subject_ref_images = new_subject_ref_images + + vace_latents = self.vace_encode_frames(input_video, subject_ref_images, masks=mask_condition, vae=self.vae) + mask_latents = self.vace_encode_masks(mask_condition, subject_ref_images) + vace_context = self.vace_latent(vace_latents, mask_latents) + + # 5. Prepare latents. + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + weight_dtype, + device, + generator, + latents, + num_length_latents=vace_latents[0].size(1) + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + target_shape = (self.vae.latent_channels, vace_latents[0].size(1), vace_latents[0].size(2), vace_latents[0].size(3)) + seq_len = math.ceil((target_shape[2] * target_shape[3]) / (self.transformer.config.patch_size[1] * self.transformer.config.patch_size[2]) * target_shape[1]) + # 7. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self.transformer.num_inference_steps = num_inference_steps + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + self.transformer.current_steps = i + + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + if hasattr(self.scheduler, "scale_model_input"): + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + vace_context_input = torch.stack(vace_context * 2) if do_classifier_free_guidance else vace_context + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + with torch.cuda.amp.autocast(dtype=weight_dtype), torch.cuda.device(device=device): + noise_pred = self.transformer( + x=latent_model_input, + context=in_prompt_embeds, + t=timestep, + vace_context=vace_context_input, + seq_len=seq_len, + vace_context_scale=vace_context_scale + ) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if comfyui_progressbar: + pbar.update(1) + + if subject_ref_images is not None: + len_subject_ref_images = len(subject_ref_images[0]) + latents = latents[:, :, len_subject_ref_images:, :, :] + + if output_type == "numpy": + video = self.decode_latents(latents) + elif not output_type == "latent": + video = self.decode_latents(latents) + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + 
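# with the default `output_type="numpy"`, `decode_latents` returns a numpy array; when
+        # `return_dict` is False it is converted back to a torch tensor before being wrapped
+        # in `WanPipelineOutput`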
+        if not return_dict:
+            video = torch.from_numpy(video)
+
+        return WanPipelineOutput(videos=video)
diff --git a/videox_fun/pipeline/pipeline_z_image.py b/videox_fun/pipeline/pipeline_z_image.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b59b88e7843b94658d21102a05cf6114f2f7573
--- /dev/null
+++ b/videox_fun/pipeline/pipeline_z_image.py
@@ -0,0 +1,613 @@
+# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import numpy as np
+import PIL
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (BaseOutput, is_torch_xla_available, logging,
+                             replace_example_docstring)
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import AutoTokenizer, PreTrainedModel
+
+from ..models import AutoencoderKL, ZImageTransformer2DModel
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import ZImagePipeline
+
+        >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
+        >>> pipe.to("cuda")
+
+        >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch.
+        >>> # (1) Use flash attention 2
+        >>> # pipe.transformer.set_attention_backend("flash")
+        >>> # (2) Use flash attention 3
+        >>> # pipe.transformer.set_attention_backend("_flash_3")
+
+        >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。"
+        >>> image = pipe(
+        ...     prompt,
+        ...     height=1024,
+        ...     width=1024,
+        ...     num_inference_steps=9,
+        ...     guidance_scale=0.0,
+        ...     generator=torch.Generator("cuda").manual_seed(42),
+        ... ).images[0]
+        >>> image.save("zimage.png")
+        ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.15,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    r"""
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps.
Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class ZImagePipelineOutput(BaseOutput): + """ + Output class for Z-Image image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = 
randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." + ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." 
+ ) + + device = self._execution_device + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.in_channels + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) \ No newline 
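# The denoising loop above applies classifier-free guidance with two optional tweaks:
# guidance is dropped entirely once the normalized time passes `cfg_truncation`, and the
# guided prediction's norm can be clamped relative to the conditional prediction
# (`cfg_normalization`). The sketch below restates that update for a single sample;
# the names `pos`, `neg`, `w`, `t_norm` and `norm_cap` are illustrative only, not
# pipeline attributes.
import torch

def guided_prediction(pos: torch.Tensor, neg: torch.Tensor, w: float, t_norm: float,
                      cfg_truncation: float = 1.0, norm_cap: float = 0.0) -> torch.Tensor:
    # Late in sampling (t_norm runs from 0 at the first step to 1 at the last),
    # truncation switches guidance off and the conditional prediction is used as-is.
    if cfg_truncation is not None and cfg_truncation <= 1 and t_norm > cfg_truncation:
        return pos
    pred = pos + w * (pos - neg)
    # Optional renormalization: keep ||pred|| within norm_cap * ||pos||.
    if norm_cap and norm_cap > 0.0:
        max_norm = torch.linalg.vector_norm(pos) * norm_cap
        pred_norm = torch.linalg.vector_norm(pred)
        if pred_norm > max_norm:
            pred = pred * (max_norm / pred_norm)
    return pred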
at end of file diff --git a/videox_fun/pipeline/pipeline_z_image_control.py b/videox_fun/pipeline/pipeline_z_image_control.py new file mode 100644 index 0000000000000000000000000000000000000000..5bfb87b9af25521c369de6251ce3d69be5b9cb5f --- /dev/null +++ b/videox_fun/pipeline/pipeline_z_image_control.py @@ -0,0 +1,633 @@ +# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import numpy as np +import PIL +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.utils import (BaseOutput, is_torch_xla_available, logging, + replace_example_docstring) +from diffusers.utils.torch_utils import randn_tensor +from transformers import AutoTokenizer, PreTrainedModel + +from ..models import AutoencoderKL, ZImageTransformer2DModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import ZImagePipeline + + >>> pipe = ZImagePipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16) + >>> pipe.to("cuda") + + >>> # Optionally, set the attention backend to flash-attn 2 or 3, default is SDPA in PyTorch. + >>> # (1) Use flash attention 2 + >>> # pipe.transformer.set_attention_backend("flash") + >>> # (2) Use flash attention 3 + >>> # pipe.transformer.set_attention_backend("_flash_3") + + >>> prompt = "一幅为名为“造相「Z-IMAGE-TURBO」”的项目设计的创意海报。画面巧妙地将文字概念视觉化:一辆复古蒸汽小火车化身为巨大的拉链头,正拉开厚厚的冬日积雪,展露出一个生机盎然的春天。" + >>> image = pipe( + ... prompt, + ... height=1024, + ... width=1024, + ... num_inference_steps=9, + ... guidance_scale=0.0, + ... generator=torch.Generator("cuda").manual_seed(42), + ... ).images[0] + >>> image.save("zimage.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps.
Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +@dataclass +class ZImagePipelineOutput(BaseOutput): + """ + Output class for Z-Image image generation pipelines. + + Args: + images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size, + height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion + pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be + passed to the decoder. 
+ """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + +class ZImageControlPipeline(DiffusionPipeline, FromSingleFileMixin): + model_cpu_offload_seq = "text_encoder->transformer->vae" + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: PreTrainedModel, + tokenizer: AutoTokenizer, + transformer: ZImageTransformer2DModel, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True + ) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[Union[str, List[str]]] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + max_sequence_length: int = 512, + ): + prompt = [prompt] if isinstance(prompt, str) else prompt + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + prompt_embeds=prompt_embeds, + max_sequence_length=max_sequence_length, + ) + + if do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = ["" for _ in prompt] + else: + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + assert len(prompt) == len(negative_prompt) + negative_prompt_embeds = self._encode_prompt( + prompt=negative_prompt, + device=device, + prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + ) + else: + negative_prompt_embeds = [] + return prompt_embeds, negative_prompt_embeds + + def _encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + max_sequence_length: int = 512, + ) -> List[torch.FloatTensor]: + device = device or self._execution_device + + if prompt_embeds is not None: + return prompt_embeds + + if isinstance(prompt, str): + prompt = [prompt] + + for i, prompt_item in enumerate(prompt): + messages = [ + {"role": "user", "content": prompt_item}, + ] + prompt_item = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True, + ) + prompt[i] = prompt_item + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids.to(device) + prompt_masks = text_inputs.attention_mask.to(device).bool() + + prompt_embeds = self.text_encoder( + input_ids=text_input_ids, + attention_mask=prompt_masks, + output_hidden_states=True, + ).hidden_states[-2] + + embeddings_list = [] + + for i in range(len(prompt_embeds)): + embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) + + return embeddings_list + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width 
= 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + + control_image: Union[torch.FloatTensor] = None, + control_context_scale: float = 1.0, + + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + cfg_normalization: bool = False, + cfg_truncation: float = 1.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[List[torch.FloatTensor]] = None, + negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to 1024): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 1024): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + cfg_normalization (`bool`, *optional*, defaults to False): + Whether to apply configuration normalization. + cfg_truncation (`float`, *optional*, defaults to 1.0): + The truncation value for configuration. 
+ negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain + tuple. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + height = height or 1024 + width = width or 1024 + + vae_scale = self.vae_scale_factor * 2 + if height % vae_scale != 0: + raise ValueError( + f"Height must be divisible by {vae_scale} (got {height}). " + f"Please adjust the height to a multiple of {vae_scale}." 
+ ) + if width % vae_scale != 0: + raise ValueError( + f"Width must be divisible by {vae_scale} (got {width}). " + f"Please adjust the width to a multiple of {vae_scale}." + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._interrupt = False + self._cfg_normalization = cfg_normalization + self._cfg_truncation = cfg_truncation + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = len(prompt_embeds) + + device = self._execution_device + weight_dtype = self.text_encoder.dtype + num_channels_latents = self.transformer.in_channels + + if control_image is not None: + control_image = self.image_processor.preprocess(control_image, height=height, width=width) + control_image = control_image.to(dtype=weight_dtype, device=device) + control_latents = self.vae.encode(control_image)[0].mode() + control_latents = (control_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + # No `control_image` was given; fall back to an all-zero control latent with the shape a VAE-encoded control image would have. + latent_height = 2 * (int(height) // (self.vae_scale_factor * 2)) + latent_width = 2 * (int(width) // (self.vae_scale_factor * 2)) + control_latents = torch.zeros( + (batch_size * num_images_per_prompt, self.vae.config.latent_channels, latent_height, latent_width), + dtype=weight_dtype, + device=device, + ) + + control_context = control_latents.unsqueeze(2) + + # If prompt_embeds is provided and prompt is None, skip encoding + if prompt_embeds is not None and prompt is None: + if self.do_classifier_free_guidance and negative_prompt_embeds is None: + raise ValueError( + "When `prompt_embeds` is provided without `prompt`, " + "`negative_prompt_embeds` must also be provided for classifier-free guidance." + ) + else: + ( + prompt_embeds, + negative_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + device=device, + max_sequence_length=max_sequence_length, + ) + + # 4. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # Repeat prompt_embeds for num_images_per_prompt + if num_images_per_prompt > 1: + prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)] + if self.do_classifier_free_guidance and negative_prompt_embeds: + negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)] + + actual_batch_size = batch_size * num_images_per_prompt + image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2) + + # 5. Prepare timesteps + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + self.scheduler.sigma_min = 0.0 + scheduler_kwargs = {"mu": mu} + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + **scheduler_kwargs, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6.
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]) + timestep = (1000 - timestep) / 1000 + # Normalized time for time-aware config (0 at start, 1 at end) + t_norm = timestep[0].item() + + # Handle cfg truncation + current_guidance_scale = self.guidance_scale + if ( + self.do_classifier_free_guidance + and self._cfg_truncation is not None + and float(self._cfg_truncation) <= 1 + ): + if t_norm > self._cfg_truncation: + current_guidance_scale = 0.0 + + # Run CFG only if configured AND scale is non-zero + apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0 + + if apply_cfg: + latents_typed = latents.to(self.transformer.dtype) + latent_model_input = latents_typed.repeat(2, 1, 1, 1) + prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds + timestep_model_input = timestep.repeat(2) + else: + latent_model_input = latents.to(self.transformer.dtype) + prompt_embeds_model_input = prompt_embeds + timestep_model_input = timestep + + latent_model_input = latent_model_input.unsqueeze(2) + latent_model_input_list = list(latent_model_input.unbind(dim=0)) + + model_out_list = self.transformer( + latent_model_input_list, + timestep_model_input, + prompt_embeds_model_input, + control_context=control_context, + control_context_scale=control_context_scale, + )[0] + + if apply_cfg: + # Perform CFG + pos_out = model_out_list[:actual_batch_size] + neg_out = model_out_list[actual_batch_size:] + + noise_pred = [] + for j in range(actual_batch_size): + pos = pos_out[j].float() + neg = neg_out[j].float() + + pred = pos + current_guidance_scale * (pos - neg) + + # Renormalization + if self._cfg_normalization and float(self._cfg_normalization) > 0.0: + ori_pos_norm = torch.linalg.vector_norm(pos) + new_pos_norm = torch.linalg.vector_norm(pred) + max_new_norm = ori_pos_norm * float(self._cfg_normalization) + if new_pos_norm > max_new_norm: + pred = pred * (max_new_norm / new_pos_norm) + + noise_pred.append(pred) + + noise_pred = torch.stack(noise_pred, dim=0) + else: + noise_pred = torch.stack([t.float() for t in model_out_list], dim=0) + + noise_pred = noise_pred.squeeze(2) + noise_pred = -noise_pred + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0] + assert latents.dtype == torch.float32 + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if output_type == "latent": + image = latents + + else: + latents = latents.to(self.vae.dtype) + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not 
return_dict: + return (image,) + + return ZImagePipelineOutput(images=image) \ No newline at end of file diff --git a/videox_fun/reward/MPS/README.md b/videox_fun/reward/MPS/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d66d2ee73284e44bca36c5ba86dbc5911d64defa --- /dev/null +++ b/videox_fun/reward/MPS/README.md @@ -0,0 +1 @@ +This folder is modified from the official [MPS](https://github.com/Kwai-Kolors/MPS/tree/main) repository. \ No newline at end of file diff --git a/videox_fun/reward/aesthetic_predictor_v2_5/__init__.py b/videox_fun/reward/aesthetic_predictor_v2_5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d3d8f197ab036c531c5ba7efb65a0800f6d4fff --- /dev/null +++ b/videox_fun/reward/aesthetic_predictor_v2_5/__init__.py @@ -0,0 +1,13 @@ +from .siglip_v2_5 import ( + AestheticPredictorV2_5Head, + AestheticPredictorV2_5Model, + AestheticPredictorV2_5Processor, + convert_v2_5_from_siglip, +) + +__all__ = [ + "AestheticPredictorV2_5Head", + "AestheticPredictorV2_5Model", + "AestheticPredictorV2_5Processor", + "convert_v2_5_from_siglip", +] \ No newline at end of file diff --git a/videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py b/videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py new file mode 100644 index 0000000000000000000000000000000000000000..867f4295eb20ff2d1acff02c9303e9fd0a7c00de --- /dev/null +++ b/videox_fun/reward/aesthetic_predictor_v2_5/siglip_v2_5.py @@ -0,0 +1,133 @@ +# Borrowed from https://github.com/discus0434/aesthetic-predictor-v2-5/blob/3125a9e/src/aesthetic_predictor_v2_5/siglip_v2_5.py +import os +from collections import OrderedDict +from os import PathLike +from typing import Final + +import torch +import torch.nn as nn +import torchvision.transforms as transforms +from transformers import ( + SiglipImageProcessor, + SiglipVisionConfig, + SiglipVisionModel, + logging, +) +from transformers.image_processing_utils import BatchFeature +from transformers.modeling_outputs import ImageClassifierOutputWithNoAttention + +logging.set_verbosity_error() + +URL: Final[str] = ( + "https://github.com/discus0434/aesthetic-predictor-v2-5/raw/main/models/aesthetic_predictor_v2_5.pth" +) + + +class AestheticPredictorV2_5Head(nn.Module): + def __init__(self, config: SiglipVisionConfig) -> None: + super().__init__() + self.scoring_head = nn.Sequential( + nn.Linear(config.hidden_size, 1024), + nn.Dropout(0.5), + nn.Linear(1024, 128), + nn.Dropout(0.5), + nn.Linear(128, 64), + nn.Dropout(0.5), + nn.Linear(64, 16), + nn.Dropout(0.2), + nn.Linear(16, 1), + ) + + def forward(self, image_embeds: torch.Tensor) -> torch.Tensor: + return self.scoring_head(image_embeds) + + +class AestheticPredictorV2_5Model(SiglipVisionModel): + PATCH_SIZE = 14 + + def __init__(self, config: SiglipVisionConfig, *args, **kwargs) -> None: + super().__init__(config, *args, **kwargs) + self.layers = AestheticPredictorV2_5Head(config) + self.post_init() + self.transforms = transforms.Compose([ + transforms.Resize((384, 384)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + def forward( + self, + pixel_values: torch.FloatTensor | None = None, + labels: torch.Tensor | None = None, + return_dict: bool | None = None, + ) -> tuple | ImageClassifierOutputWithNoAttention: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = super().forward( + pixel_values=pixel_values, + return_dict=return_dict, + ) + image_embeds = 
outputs.pooler_output + image_embeds_norm = image_embeds / image_embeds.norm(dim=-1, keepdim=True) + prediction = self.layers(image_embeds_norm) + + loss = None + if labels is not None: + loss_fct = nn.MSELoss() + loss = loss_fct() + + if not return_dict: + return (loss, prediction, image_embeds) + + return ImageClassifierOutputWithNoAttention( + loss=loss, + logits=prediction, + hidden_states=image_embeds, + ) + + +class AestheticPredictorV2_5Processor(SiglipImageProcessor): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def __call__(self, *args, **kwargs) -> BatchFeature: + return super().__call__(*args, **kwargs) + + @classmethod + def from_pretrained( + self, + pretrained_model_name_or_path: str + | PathLike = "google/siglip-so400m-patch14-384", + *args, + **kwargs, + ) -> "AestheticPredictorV2_5Processor": + return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + +def convert_v2_5_from_siglip( + predictor_name_or_path: str | PathLike | None = None, + encoder_model_name: str = "google/siglip-so400m-patch14-384", + *args, + **kwargs, +) -> tuple[AestheticPredictorV2_5Model, AestheticPredictorV2_5Processor]: + model = AestheticPredictorV2_5Model.from_pretrained( + encoder_model_name, *args, **kwargs + ) + + processor = AestheticPredictorV2_5Processor.from_pretrained( + encoder_model_name, *args, **kwargs + ) + + if predictor_name_or_path is None or not os.path.exists(predictor_name_or_path): + state_dict = torch.hub.load_state_dict_from_url(URL, map_location="cpu") + else: + state_dict = torch.load(predictor_name_or_path, map_location="cpu") + + assert isinstance(state_dict, OrderedDict) + + model.layers.load_state_dict(state_dict) + model.eval() + + return model, processor \ No newline at end of file diff --git a/videox_fun/reward/improved_aesthetic_predictor.py b/videox_fun/reward/improved_aesthetic_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..43037b9fd5d80dc1402707c74f6b6a9584f26ce7 --- /dev/null +++ b/videox_fun/reward/improved_aesthetic_predictor.py @@ -0,0 +1,49 @@ +import os + +import torch +import torch.nn as nn +from transformers import CLIPModel +from torchvision.datasets.utils import download_url + +URL = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/sac%2Blogos%2Bava1-l14-linearMSE.pth" +FILENAME = "sac+logos+ava1-l14-linearMSE.pth" +MD5 = "b1047fd767a00134b8fd6529bf19521a" + + +class MLP(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.Sequential( + nn.Linear(768, 1024), + nn.Dropout(0.2), + nn.Linear(1024, 128), + nn.Dropout(0.2), + nn.Linear(128, 64), + nn.Dropout(0.1), + nn.Linear(64, 16), + nn.Linear(16, 1), + ) + + + def forward(self, embed): + return self.layers(embed) + + +class ImprovedAestheticPredictor(nn.Module): + def __init__(self, encoder_path="openai/clip-vit-large-patch14", predictor_path=None): + super().__init__() + self.encoder = CLIPModel.from_pretrained(encoder_path) + self.predictor = MLP() + if predictor_path is None or not os.path.exists(predictor_path): + download_url(URL, torch.hub.get_dir(), FILENAME, md5=MD5) + predictor_path = os.path.join(torch.hub.get_dir(), FILENAME) + state_dict = torch.load(predictor_path, map_location="cpu") + self.predictor.load_state_dict(state_dict) + self.eval() + + + def forward(self, pixel_values): + embed = self.encoder.get_image_features(pixel_values=pixel_values) + embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) + + return 
self.predictor(embed).squeeze(1) diff --git a/videox_fun/reward/reward_fn.py b/videox_fun/reward/reward_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..6526919cbea00cb859a542d7758326c89d49ecf9 --- /dev/null +++ b/videox_fun/reward/reward_fn.py @@ -0,0 +1,385 @@ +import os +from abc import ABC, abstractmethod + +import torch +import torchvision.transforms as transforms +from einops import rearrange +from torchvision.datasets.utils import download_url +from typing import Optional, Tuple + + +# All reward models. +__all__ = ["AestheticReward", "HPSReward", "PickScoreReward", "MPSReward"] + + +class BaseReward(ABC): + """An base class for reward models. A custom Reward class must implement two functions below. + """ + def __init__(self): + """Define your reward model and image transformations (optional) here. + """ + pass + + @abstractmethod + def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]: + """Given batch frames with shape `[B, C, T, H, W]` extracted from a list of videos and a list of prompts + (optional) correspondingly, return the loss and reward computed by your reward model (reduction by mean). + """ + pass + +class AestheticReward(BaseReward): + """Aesthetic Predictor [V2](https://github.com/christophschuhmann/improved-aesthetic-predictor) + and [V2.5](https://github.com/discus0434/aesthetic-predictor-v2-5) reward model. + """ + def __init__( + self, + encoder_path="openai/clip-vit-large-patch14", + predictor_path=None, + version="v2", + device="cpu", + dtype=torch.float16, + max_reward=10, + loss_scale=0.1, + ): + from .improved_aesthetic_predictor import ImprovedAestheticPredictor + from ..video_caption.utils.siglip_v2_5 import convert_v2_5_from_siglip + + self.encoder_path = encoder_path + self.predictor_path = predictor_path + self.version = version + self.device = device + self.dtype = dtype + self.max_reward = max_reward + self.loss_scale = loss_scale + + if self.version != "v2" and self.version != "v2.5": + raise ValueError("Only v2 and v2.5 are supported.") + if self.version == "v2": + assert "clip-vit-large-patch14" in encoder_path.lower() + self.model = ImprovedAestheticPredictor(encoder_path=self.encoder_path, predictor_path=self.predictor_path) + # https://huggingface.co/openai/clip-vit-large-patch14/blob/main/preprocessor_config.json + # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio. 
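+ # Normalization below uses the CLIP preprocessing mean/std from the config linked above; frames are assumed to already be float tensors in [0, 1] (e.g. ToTensor output), so only resizing and normalization are applied here.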
+ self.transform = transforms.Compose([ + transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), + ]) + elif self.version == "v2.5": + assert "siglip-so400m-patch14-384" in encoder_path.lower() + self.model, _ = convert_v2_5_from_siglip(encoder_model_name=self.encoder_path) + # https://huggingface.co/google/siglip-so400m-patch14-384/blob/main/preprocessor_config.json + self.transform = transforms.Compose([ + transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + self.model.to(device=self.device, dtype=self.dtype) + self.model.requires_grad_(False) + + + def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]: + batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w") + batch_loss, batch_reward = 0, 0 + for frames in batch_frames: + pixel_values = torch.stack([self.transform(frame) for frame in frames]) + pixel_values = pixel_values.to(self.device, dtype=self.dtype) + if self.version == "v2": + reward = self.model(pixel_values) + elif self.version == "v2.5": + reward = self.model(pixel_values).logits.squeeze() + # Convert reward to loss in [0, 1]. + if self.max_reward is None: + loss = (-1 * reward) * self.loss_scale + else: + loss = abs(reward - self.max_reward) * self.loss_scale + batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean() + + return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0] + + +class HPSReward(BaseReward): + """[HPS](https://github.com/tgxs002/HPSv2) v2 and v2.1 reward model. + """ + def __init__( + self, + model_path=None, + version="v2.0", + device="cpu", + dtype=torch.float16, + max_reward=1, + loss_scale=1, + ): + from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer + + self.model_path = model_path + self.version = version + self.device = device + self.dtype = dtype + self.max_reward = max_reward + self.loss_scale = loss_scale + + self.model, _, _ = create_model_and_transforms( + "ViT-H-14", + "laion2B-s32B-b79K", + precision=self.dtype, + device=self.device, + jit=False, + force_quick_gelu=False, + force_custom_text=False, + force_patch_dropout=False, + force_image_size=None, + pretrained_image=False, + image_mean=None, + image_std=None, + light_augmentation=True, + aug_cfg={}, + output_dict=True, + with_score_predictor=False, + with_region_predictor=False, + ) + self.tokenizer = get_tokenizer("ViT-H-14") + + # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json + # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio. 
+ self.transform = transforms.Compose([ + transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), + ]) + + if version == "v2.0": + url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2_compressed.pt" + filename = "HPS_v2_compressed.pt" + md5 = "fd9180de357abf01fdb4eaad64631db4" + elif version == "v2.1": + url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2.1_compressed.pt" + filename = "HPS_v2.1_compressed.pt" + md5 = "4067542e34ba2553a738c5ac6c1d75c0" + else: + raise ValueError("Only v2.0 and v2.1 are supported.") + if self.model_path is None or not os.path.exists(self.model_path): + download_url(url, torch.hub.get_dir(), md5=md5) + model_path = os.path.join(torch.hub.get_dir(), filename) + + state_dict = torch.load(model_path, map_location="cpu")["state_dict"] + self.model.load_state_dict(state_dict) + self.model.to(device=self.device, dtype=self.dtype) + self.model.requires_grad_(False) + self.model.eval() + + def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]: + assert batch_frames.shape[0] == len(batch_prompt) + # Compute batch reward and loss in frame-wise. + batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w") + batch_loss, batch_reward = 0, 0 + for frames in batch_frames: + image_inputs = torch.stack([self.transform(frame) for frame in frames]) + image_inputs = image_inputs.to(device=self.device, dtype=self.dtype) + text_inputs = self.tokenizer(batch_prompt).to(device=self.device) + outputs = self.model(image_inputs, text_inputs) + + image_features, text_features = outputs["image_features"], outputs["text_features"] + logits = image_features @ text_features.T + reward = torch.diagonal(logits) + # Convert reward to loss in [0, 1]. + if self.max_reward is None: + loss = (-1 * reward) * self.loss_scale + else: + loss = abs(reward - self.max_reward) * self.loss_scale + + batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean() + + return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0] + + +class PickScoreReward(BaseReward): + """[PickScore](https://github.com/yuvalkirstain/PickScore) reward model. + """ + def __init__( + self, + model_path="yuvalkirstain/PickScore_v1", + device="cpu", + dtype=torch.float16, + max_reward=1, + loss_scale=1, + ): + from transformers import AutoProcessor, AutoModel + + self.model_path = model_path + self.device = device + self.dtype = dtype + self.max_reward = max_reward + self.loss_scale = loss_scale + + # https://huggingface.co/yuvalkirstain/PickScore_v1/blob/main/preprocessor_config.json + self.transform = transforms.Compose([ + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), + ]) + self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype) + self.model = AutoModel.from_pretrained(model_path, torch_dtype=self.dtype).eval().to(device) + self.model.requires_grad_(False) + self.model.eval() + + def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]: + assert batch_frames.shape[0] == len(batch_prompt) + # Compute batch reward and loss in frame-wise. 
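+ # The per-frame HPS reward below is the diagonal of image_features @ text_features.T, i.e. the similarity between frame i and prompt i; the open_clip forward pass is expected to return L2-normalized features, so this is effectively a cosine similarity.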
+ batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w") + batch_loss, batch_reward = 0, 0 + for frames in batch_frames: + image_inputs = torch.stack([self.transform(frame) for frame in frames]) + image_inputs = image_inputs.to(device=self.device, dtype=self.dtype) + text_inputs = self.processor( + text=batch_prompt, + padding=True, + truncation=True, + max_length=77, + return_tensors="pt", + ).to(self.device) + image_features = self.model.get_image_features(pixel_values=image_inputs) + text_features = self.model.get_text_features(**text_inputs) + image_features = image_features / torch.norm(image_features, dim=-1, keepdim=True) + text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True) + + logits = image_features @ text_features.T + reward = torch.diagonal(logits) + # Convert reward to loss in [0, 1]. + if self.max_reward is None: + loss = (-1 * reward) * self.loss_scale + else: + loss = abs(reward - self.max_reward) * self.loss_scale + + batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean() + + return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0] + + +class MPSReward(BaseReward): + """[MPS](https://github.com/Kwai-Kolors/MPS) reward model. + """ + def __init__( + self, + model_path=None, + device="cpu", + dtype=torch.float16, + max_reward=1, + loss_scale=1, + ): + from transformers import AutoTokenizer, AutoConfig + from .MPS.trainer.models.clip_model import CLIPModel + + self.model_path = model_path + self.device = device + self.dtype = dtype + self.condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things." + self.max_reward = max_reward + self.loss_scale = loss_scale + + processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" + # https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json + # TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio. + self.transform = transforms.Compose([ + transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC), + transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]), + ]) + + # We convert the original [ckpt](http://drive.google.com/file/d/17qrK_aJkVNM75ZEvMEePpLj6L867MLkN/view?usp=sharing) + # (contains the entire model) to a `state_dict`. 
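+ # The re-hosted checkpoint below is that converted state_dict; when no local model_path is given it is downloaded into torch.hub's cache directory and verified against the md5 hash.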
+ url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/MPS_overall.pth" + filename = "MPS_overall.pth" + md5 = "1491cbbbd20565747fe07e7572e2ac56" + if self.model_path is None or not os.path.exists(self.model_path): + download_url(url, torch.hub.get_dir(), md5=md5) + model_path = os.path.join(torch.hub.get_dir(), filename) + + self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True) + config = AutoConfig.from_pretrained(processor_name_or_path) + self.model = CLIPModel(config) + state_dict = torch.load(model_path, map_location="cpu") + self.model.load_state_dict(state_dict, strict=False) + self.model.to(device=self.device, dtype=self.dtype) + self.model.requires_grad_(False) + self.model.eval() + + def _tokenize(self, caption): + input_ids = self.tokenizer( + caption, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt" + ).input_ids + + return input_ids + + def __call__( + self, + batch_frames: torch.Tensor, + batch_prompt: list[str], + batch_condition: Optional[list[str]] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + if batch_condition is None: + batch_condition = [self.condition] * len(batch_prompt) + batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w") + batch_loss, batch_reward = 0, 0 + for frames in batch_frames: + image_inputs = torch.stack([self.transform(frame) for frame in frames]) + image_inputs = image_inputs.to(device=self.device, dtype=self.dtype) + text_inputs = self._tokenize(batch_prompt).to(self.device) + condition_inputs = self._tokenize(batch_condition).to(device=self.device) + text_features, image_features = self.model(text_inputs, image_inputs, condition_inputs) + + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + # reward = self.model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_features)) + logits = image_features @ text_features.T + reward = torch.diagonal(logits) + # Convert reward to loss in [0, 1]. 
+ if self.max_reward is None: + loss = (-1 * reward) * self.loss_scale + else: + loss = abs(reward - self.max_reward) * self.loss_scale + + batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean() + + return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0] + + +if __name__ == "__main__": + import numpy as np + from decord import VideoReader + + video_path_list = ["your_video_path_1.mp4", "your_video_path_2.mp4"] + prompt_list = ["your_prompt_1", "your_prompt_2"] + num_sampled_frames = 8 + + to_tensor = transforms.ToTensor() + + sampled_frames_list = [] + for video_path in video_path_list: + vr = VideoReader(video_path) + sampled_frame_indices = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int) + sampled_frames = vr.get_batch(sampled_frame_indices).asnumpy() + sampled_frames = torch.stack([to_tensor(frame) for frame in sampled_frames]) + sampled_frames_list.append(sampled_frames) + sampled_frames = torch.stack(sampled_frames_list) + sampled_frames = rearrange(sampled_frames, "b t c h w -> b c t h w") + + aesthetic_reward_v2 = AestheticReward(device="cuda", dtype=torch.bfloat16) + print(f"aesthetic_reward_v2: {aesthetic_reward_v2(sampled_frames)}") + + aesthetic_reward_v2_5 = AestheticReward( + encoder_path="google/siglip-so400m-patch14-384", version="v2.5", device="cuda", dtype=torch.bfloat16 + ) + print(f"aesthetic_reward_v2_5: {aesthetic_reward_v2_5(sampled_frames)}") + + hps_reward_v2 = HPSReward(device="cuda", dtype=torch.bfloat16) + print(f"hps_reward_v2: {hps_reward_v2(sampled_frames, prompt_list)}") + + hps_reward_v2_1 = HPSReward(version="v2.1", device="cuda", dtype=torch.bfloat16) + print(f"hps_reward_v2_1: {hps_reward_v2_1(sampled_frames, prompt_list)}") + + pick_score = PickScoreReward(device="cuda", dtype=torch.bfloat16) + print(f"pick_score_reward: {pick_score(sampled_frames, prompt_list)}") + + mps_score = MPSReward(device="cuda", dtype=torch.bfloat16) + print(f"mps_reward: {mps_score(sampled_frames, prompt_list)}") \ No newline at end of file diff --git a/videox_fun/ui/cogvideox_fun_ui.py b/videox_fun/ui/cogvideox_fun_ui.py new file mode 100755 index 0000000000000000000000000000000000000000..748ca9605e95f31dcf6aecc769ac7c8ce37c7fd5 --- /dev/null +++ b/videox_fun/ui/cogvideox_fun_ui.py @@ -0,0 +1,722 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, + T5EncoderModel, T5Tokenizer) +from ..pipeline import (CogVideoXFunControlPipeline, + CogVideoXFunInpaintPipeline, CogVideoXFunPipeline) +from ..utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name, + convert_weight_dtype_wrapper) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer, + get_video_to_video_latent, save_videos_grid) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, 
create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_ui_outputs) +from ..dist import set_multi_gpus_devices, shard_model + + +class CogVideoXFunController(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + self.vae = AutoencoderKLCogVideoX.from_pretrained( + diffusion_transformer_dropdown, + subfolder="vae", + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = CogVideoXTransformer3DModel.from_pretrained( + diffusion_transformer_dropdown, + subfolder="transformer", + low_cpu_mem_usage=True, + ).to(self.weight_dtype) + + # Get tokenizer and text_encoder + tokenizer = T5Tokenizer.from_pretrained( + diffusion_transformer_dropdown, subfolder="tokenizer" + ) + text_encoder = T5EncoderModel.from_pretrained( + diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype + ) + + # Get pipeline + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = CogVideoXFunInpaintPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=self.vae, + transformer=self.transformer, + scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"), + ) + else: + self.pipeline = CogVideoXFunPipeline( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=self.vae, + transformer=self.transformer, + scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"), + ) + else: + self.pipeline = CogVideoXFunControlPipeline( + diffusion_transformer_dropdown, + vae=self.vae, + transformer=self.transformer, + scheduler=self.scheduler_dict[list(self.scheduler_dict.keys())[0]].from_pretrained(diffusion_transformer_dropdown, subfolder="scheduler"), + torch_dtype=self.weight_dtype + ) + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.transformer_blocks)): + self.pipeline.transformer.transformer_blocks[i] = torch.compile(self.pipeline.transformer.transformer_blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.pipeline.transformer, exclude_module_name=[], device=self.device) + convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + 
self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.pipeline.transformer, exclude_module_name=[], device=self.device) + convert_weight_dtype_wrapper(self.pipeline.transformer, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + + print(f"Load scheduler.") + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(self.pipeline.scheduler.config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Merge Lora done.") + + if fps is None: + fps = 8 + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if generation_method == "Long Video Generation": + if validation_video is not None: + raise gr.Error(f"Video to Video is not Support Long Video Generation now.") + init_frames = 0 + last_frames = init_frames + partial_video_length + while init_frames < length_slider: + if last_frames >= length_slider: + _partial_video_length = length_slider - init_frames + _partial_video_length = int((_partial_video_length - 1) // self.vae.config.temporal_compression_ratio * self.vae.config.temporal_compression_ratio) + 1 + + if _partial_video_length <= 0: + break + else: + _partial_video_length = partial_video_length + + if last_frames >= length_slider: + input_video, input_video_mask, 
clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider)) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider)) + + with torch.no_grad(): + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = _partial_video_length, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + strength = 1, + ).videos + + if init_frames != 0: + mix_ratio = torch.from_numpy( + np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32) + ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + + new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \ + sample[:, :, :overlap_video_length] * mix_ratio + new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2) + + sample = new_sample + else: + new_sample = sample + + if last_frames >= length_slider: + break + + start_image = [ + Image.fromarray( + (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8) + ) for _index in range(-overlap_video_length, 0) + ] + + init_frames = init_frames + _partial_video_length - overlap_video_length + last_frames = init_frames + _partial_video_length + else: + if validation_video is not None: + input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + strength = denoise_strength + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + strength = 1 + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + strength = strength, + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator + ).videos + else: + input_video, input_video_mask, ref_image, clip_image = get_video_to_video_latent(control_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + control_video = input_video, + ).videos + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + 
self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +CogVideoXFunController_Host = CogVideoXFunController +CogVideoXFunController_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, compile_dit, weight_dtype, savedir_sample=None): + controller = CogVideoXFunController( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # CogVideoX-Fun: + + A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_model_type(visible=True) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts() + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 384, default_width = 672, maximum_height = 1344, + maximum_width = 1344, + ) + gr.Markdown( + """ + V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames. 
+ (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。) + """ + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation", "Long Video Generation"], + default_video_length=49, + maximum_video_length=85, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=85, value=49, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, 
+ overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, compile_dit, weight_dtype, savedir_sample=None): + controller = CogVideoXFunController_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # CogVideoX-Fun + + A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts() + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 384, default_width = 672, maximum_height = 1344, + maximum_width = 1344, + ) + gr.Markdown( + """ + V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames. + (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。) + """ + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=49, + maximum_video_length=85, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)", "Video Control (视频控制)"], prompt_textbox + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=8, maximum=85, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), 
gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = CogVideoXFunController_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # CogVideoX-Fun + + A CogVideoX with more flexible generation conditions, capable of producing videos of different resolutions, around 6 seconds, and fps 8 (frames 1 to 49), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts() + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 384, default_width = 672, maximum_height = 1344, + maximum_width = 1344, + ) + gr.Markdown( + """ + V1.0 and V1.1 support up to 49 frames of video generation, while V1.5 supports up to 85 frames. 
+ (V1.0和V1.1支持最大49帧视频生成,V1.5支持最大85帧视频生成。) + """ + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=49, + maximum_video_length=85, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video to Video (视频到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=85, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/ui/controller.py b/videox_fun/ui/controller.py new file mode 100755 index 0000000000000000000000000000000000000000..31a5e9d0e76ecc32e4d92273ae7fb55a874c0dfe --- /dev/null +++ b/videox_fun/ui/controller.py @@ -0,0 +1,514 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import base64 +import gc +import json +import os +import hashlib +import random +from datetime import datetime +from glob import glob + +import cv2 +import gradio as gr +import numpy as np +import pkg_resources +import requests +import torch +from 
diffusers import (CogVideoXDDIMScheduler, DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, + FlowMatchEulerDiscreteScheduler, PNDMScheduler) +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..utils.utils import save_videos_grid +from ..utils.fm_solvers import FlowDPMSolverMultistepScheduler +from ..utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from ..dist import set_multi_gpus_devices + +gradio_version = pkg_resources.get_distribution("gradio").version +gradio_version_is_above_4 = True if int(gradio_version.split('.')[0]) >= 4 else False + +css = """ +.toolbutton { + margin-buttom: 0em 0em 0em 0em; + max-width: 2.5em; + min-width: 2.5em !important; + height: 2.5em; +} +""" + +ddpm_scheduler_dict = { + "Euler": EulerDiscreteScheduler, + "Euler A": EulerAncestralDiscreteScheduler, + "DPM++": DPMSolverMultistepScheduler, + "PNDM": PNDMScheduler, + "DDIM": DDIMScheduler, + "DDIM_Origin": DDIMScheduler, + "DDIM_Cog": CogVideoXDDIMScheduler, +} +flow_scheduler_dict = { + "Flow": FlowMatchEulerDiscreteScheduler, + "Flow_Unipc": FlowUniPCMultistepScheduler, + "Flow_DPM++": FlowDPMSolverMultistepScheduler, +} +all_cheduler_dict = {**ddpm_scheduler_dict, **flow_scheduler_dict} + +class Fun_Controller: + def __init__( + self, GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=None, ulysses_degree=1, ring_degree=1, + fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False, + weight_dtype=None, savedir_sample=None, + ): + # config dirs + self.basedir = os.getcwd() + self.config_dir = os.path.join(self.basedir, "config") + self.diffusion_transformer_dir = os.path.join(self.basedir, "models", "Diffusion_Transformer") + self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module") + self.personalized_model_dir = os.path.join(self.basedir, "models", "Personalized_Model") + if savedir_sample is None: + self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) + else: + self.savedir_sample = savedir_sample + os.makedirs(self.savedir_sample, exist_ok=True) + + self.GPU_memory_mode = GPU_memory_mode + self.model_name = model_name + self.diffusion_transformer_dropdown = model_name + self.scheduler_dict = scheduler_dict + self.model_type = model_type + if config_path is not None: + self.config_path = os.path.realpath(config_path) + self.config = OmegaConf.load(config_path) + else: + self.config_path = None + self.ulysses_degree = ulysses_degree + self.ring_degree = ring_degree + self.fsdp_dit = fsdp_dit + self.fsdp_text_encoder = fsdp_text_encoder + self.compile_dit = compile_dit + self.weight_dtype = weight_dtype + self.device = set_multi_gpus_devices(self.ulysses_degree, self.ring_degree) + + self.diffusion_transformer_list = [] + self.motion_module_list = [] + self.personalized_model_list = [] + self.config_list = [] + + # config models + self.tokenizer = None + self.text_encoder = None + self.vae = None + self.transformer = None + self.transformer_2 = None + self.pipeline = None + self.base_model_path = "none" + self.base_model_2_path = "none" + self.lora_model_path = "none" + self.lora_model_2_path = "none" + + self.refresh_config() + self.refresh_diffusion_transformer() + self.refresh_personalized_model() + if model_name != None: + self.update_diffusion_transformer(model_name) + + def refresh_config(self): + 
config_list = [] + for root, dirs, files in os.walk(self.config_dir): + for file in files: + if file.endswith(('.yaml', '.yml')): + full_path = os.path.join(root, file) + config_list.append(full_path) + self.config_list = config_list + + def refresh_diffusion_transformer(self): + self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/"))) + + def refresh_personalized_model(self): + personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors"))) + self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list] + + def update_model_type(self, model_type): + self.model_type = model_type + + def update_config(self, config_dropdown): + self.config_path = config_dropdown + self.config = OmegaConf.load(config_dropdown) + print(f"Update config: {config_dropdown}") + + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + pass + + def update_base_model(self, base_model_dropdown, is_checkpoint_2=False): + if not is_checkpoint_2: + self.base_model_path = base_model_dropdown + else: + self.base_model_2_path = base_model_dropdown + print(f"Update base model: {base_model_dropdown}") + if base_model_dropdown == "none": + return gr.update() + if self.transformer is None and not is_checkpoint_2: + gr.Info(f"Please select a pretrained model path.") + print(f"Please select a pretrained model path.") + return gr.update(value=None) + elif self.transformer_2 is None and is_checkpoint_2: + gr.Info(f"Please select a pretrained model path.") + print(f"Please select a pretrained model path.") + return gr.update(value=None) + else: + base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown) + base_model_state_dict = {} + with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: + for key in f.keys(): + base_model_state_dict[key] = f.get_tensor(key) + if not is_checkpoint_2: + self.transformer.load_state_dict(base_model_state_dict, strict=False) + else: + self.transformer_2.load_state_dict(base_model_state_dict, strict=False) + print("Update base model done") + return gr.update() + + def update_lora_model(self, lora_model_dropdown, is_checkpoint_2=False): + print(f"Update lora model: {lora_model_dropdown}") + if lora_model_dropdown == "none": + self.lora_model_path = "none" + return gr.update() + lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown) + if not is_checkpoint_2: + self.lora_model_path = lora_model_dropdown + else: + self.lora_model_2_path = lora_model_dropdown + return gr.update() + + def clear_cache(self,): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + def auto_model_clear_cache(self, model): + origin_device = model.device + model = model.to("cpu") + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + model = model.to(origin_device) + + def input_check(self, + resize_method, + generation_method, + start_image, + end_image, + validation_video, + control_video, + is_api = False, + ): + if self.transformer is None: + if is_api: + return "", f"Please select a pretrained model path." + else: + raise gr.Error(f"Please select a pretrained model path.") + + if control_video is not None and self.model_type == "Inpaint": + if is_api: + return "", f"If specifying the control video, please set the model_type == \"Control\". " + else: + raise gr.Error(f"If specifying the control video, please set the model_type == \"Control\". 
") + + if control_video is None and self.model_type == "Control": + if is_api: + return "", f"If set the model_type == \"Control\", please specifying the control video. " + else: + raise gr.Error(f"If set the model_type == \"Control\", please specifying the control video. ") + + if resize_method == "Resize according to Reference": + if start_image is None and validation_video is None and control_video is None: + if is_api: + return "", f"Please upload an image when using \"Resize according to Reference\"." + else: + raise gr.Error(f"Please upload an image when using \"Resize according to Reference\".") + + if self.transformer.config.in_channels == self.vae.config.latent_channels and start_image is not None: + if is_api: + return "", f"Please select an image to video pretrained model while using image to video." + else: + raise gr.Error(f"Please select an image to video pretrained model while using image to video.") + + if self.transformer.config.in_channels == self.vae.config.latent_channels and generation_method == "Long Video Generation": + if is_api: + return "", f"Please select an image to video pretrained model while using long video generation." + else: + raise gr.Error(f"Please select an image to video pretrained model while using long video generation.") + + if start_image is None and end_image is not None: + if is_api: + return "", f"If specifying the ending image of the video, please specify a starting image of the video." + else: + raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.") + return "", "OK" + + def get_height_width_from_reference( + self, + base_resolution, + start_image, + validation_video, + control_video, + ): + spatial_compression_ratio = self.vae.config.spatial_compression_ratio if hasattr(self.vae.config, "spatial_compression_ratio") else 8 + aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()} + if self.model_type == "Inpaint": + if validation_video is not None: + original_width, original_height = Image.fromarray(cv2.VideoCapture(validation_video).read()[1]).size + else: + original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size + else: + original_width, original_height = Image.fromarray(cv2.VideoCapture(control_video).read()[1]).size + closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size) + height_slider, width_slider = [int(x / spatial_compression_ratio / 2) * spatial_compression_ratio * 2 for x in closest_size] + return height_slider, width_slider + + def save_outputs(self, is_image, length_slider, sample, fps): + def save_results(): + if not os.path.exists(self.savedir_sample): + os.makedirs(self.savedir_sample, exist_ok=True) + index = len([path for path in os.listdir(self.savedir_sample)]) + 1 + prefix = str(index).zfill(8) + + md5_hash = hashlib.md5(sample.cpu().numpy().tobytes()).hexdigest() + + if is_image or length_slider == 1: + save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png") + print(f"Saving to {save_sample_path}") + image = sample[0, :, 0] + image = image.transpose(0, 1).transpose(1, 2) + image = (image * 255).numpy().astype(np.uint8) + image = Image.fromarray(image) + image.save(save_sample_path) + + else: + save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4") + print(f"Saving to {save_sample_path}") + save_videos_grid(sample, 
save_sample_path, fps=fps) + return save_sample_path + + if self.ulysses_degree * self.ring_degree > 1: + import torch.distributed as dist + if dist.get_rank() == 0: + save_sample_path = save_results() + else: + save_sample_path = None + else: + save_sample_path = save_results() + return save_sample_path + + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + is_api = False, + ): + pass + +def post_to_host( + diffusion_transformer_dropdown, + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, + prompt_textbox, negative_prompt_textbox, + sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider, + base_resolution, generation_method, length_slider, cfg_scale_slider, + start_image, end_image, validation_video, validation_video_mask, denoise_strength, seed_textbox, + ref_image = None, enable_teacache = None, teacache_threshold = None, num_skip_start_steps = None, + teacache_offload = None, cfg_skip_ratio = None,enable_riflex = None, riflex_k = None, +): + if start_image is not None: + with open(start_image, 'rb') as file: + file_content = file.read() + start_image_encoded_content = base64.b64encode(file_content) + start_image = start_image_encoded_content.decode('utf-8') + + if end_image is not None: + with open(end_image, 'rb') as file: + file_content = file.read() + end_image_encoded_content = base64.b64encode(file_content) + end_image = end_image_encoded_content.decode('utf-8') + + if validation_video is not None: + with open(validation_video, 'rb') as file: + file_content = file.read() + validation_video_encoded_content = base64.b64encode(file_content) + validation_video = validation_video_encoded_content.decode('utf-8') + + if validation_video_mask is not None: + with open(validation_video_mask, 'rb') as file: + file_content = file.read() + validation_video_mask_encoded_content = base64.b64encode(file_content) + validation_video_mask = validation_video_mask_encoded_content.decode('utf-8') + + if ref_image is not None: + with open(ref_image, 'rb') as file: + file_content = file.read() + ref_image_encoded_content = base64.b64encode(file_content) + ref_image = ref_image_encoded_content.decode('utf-8') + + datas = { + "base_model_path": base_model_dropdown, + "lora_model_path": lora_model_dropdown, + "lora_alpha_slider": lora_alpha_slider, + "prompt_textbox": prompt_textbox, + "negative_prompt_textbox": negative_prompt_textbox, + "sampler_dropdown": sampler_dropdown, + "sample_step_slider": sample_step_slider, + "resize_method": resize_method, + "width_slider": width_slider, + "height_slider": height_slider, + "base_resolution": base_resolution, + "generation_method": generation_method, + "length_slider": length_slider, + "cfg_scale_slider": cfg_scale_slider, + "start_image": start_image, + "end_image": end_image, + "validation_video": validation_video, + "validation_video_mask": validation_video_mask, + "denoise_strength": 
denoise_strength, + "seed_textbox": seed_textbox, + + "ref_image": ref_image, + "enable_teacache": enable_teacache, + "teacache_threshold": teacache_threshold, + "num_skip_start_steps": num_skip_start_steps, + "teacache_offload": teacache_offload, + "cfg_skip_ratio": cfg_skip_ratio, + "enable_riflex": enable_riflex, + "riflex_k": riflex_k, + } + + session = requests.session() + session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")}) + + response = session.post(url=f'{os.environ.get("EAS_URL")}/videox_fun/infer_forward', json=datas, timeout=300) + + outputs = response.json() + return outputs + + +class Fun_Controller_Client: + def __init__(self, scheduler_dict, savedir_sample): + self.basedir = os.getcwd() + if savedir_sample is None: + self.savedir_sample = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S")) + else: + self.savedir_sample = savedir_sample + os.makedirs(self.savedir_sample, exist_ok=True) + + self.scheduler_dict = scheduler_dict + + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + ): + is_image = True if generation_method == "Image Generation" else False + + outputs = post_to_host( + diffusion_transformer_dropdown, + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, + prompt_textbox, negative_prompt_textbox, + sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider, + base_resolution, generation_method, length_slider, cfg_scale_slider, + start_image, end_image, validation_video, validation_video_mask, denoise_strength, + seed_textbox, ref_image = ref_image, enable_teacache = enable_teacache, teacache_threshold = teacache_threshold, + num_skip_start_steps = num_skip_start_steps, teacache_offload = teacache_offload, + cfg_skip_ratio = cfg_skip_ratio, enable_riflex = enable_riflex, riflex_k = riflex_k, + ) + + try: + base64_encoding = outputs["base64_encoding"] + except: + return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"] + + decoded_data = base64.b64decode(base64_encoding) + + if not os.path.exists(self.savedir_sample): + os.makedirs(self.savedir_sample, exist_ok=True) + md5_hash = hashlib.md5(decoded_data).hexdigest() + + index = len([path for path in os.listdir(self.savedir_sample)]) + 1 + prefix = str(index).zfill(8) + + if is_image or length_slider == 1: + save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.png") + print(f"Saving to {save_sample_path}") + with open(save_sample_path, "wb") as file: + file.write(decoded_data) + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + save_sample_path = os.path.join(self.savedir_sample, prefix + f"-{md5_hash}.mp4") + print(f"Saving to {save_sample_path}") + with 
open(save_sample_path, "wb") as file: + file.write(decoded_data) + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" diff --git a/videox_fun/ui/ui.py b/videox_fun/ui/ui.py new file mode 100755 index 0000000000000000000000000000000000000000..ee7d02458458bdd287a8fbbe812f663d94b2f308 --- /dev/null +++ b/videox_fun/ui/ui.py @@ -0,0 +1,358 @@ +import random + +import gradio as gr + + +def create_model_type(visible): + gr.Markdown( + """ + ### Model Type (模型的种类,正常模型还是控制模型). + """, + visible=visible, + ) + with gr.Row(): + model_type = gr.Dropdown( + label="The model type of the model (模型的种类,正常模型还是控制模型)", + choices=["Inpaint", "Control"], + value="Inpaint", + visible=visible, + interactive=True, + ) + return model_type + +def create_fake_model_type(visible): + gr.Markdown( + """ + ### Model Type (模型的种类,正常模型还是控制模型). + """, + visible=visible, + ) + with gr.Row(): + model_type = gr.Dropdown( + label="The model type of the model (模型的种类,正常模型还是控制模型)", + choices=["Inpaint", "Control"], + value="Inpaint", + interactive=False, + visible=visible, + ) + return model_type + +def create_model_checkpoints(controller, visible): + gr.Markdown( + """ + ### Model checkpoints (模型路径). + """ + ) + with gr.Row(visible=visible): + diffusion_transformer_dropdown = gr.Dropdown( + label="Pretrained Model Path (预训练模型路径)", + choices=controller.diffusion_transformer_list, + value="none", + interactive=True, + ) + diffusion_transformer_dropdown.change( + fn=controller.update_diffusion_transformer, + inputs=[diffusion_transformer_dropdown], + outputs=[diffusion_transformer_dropdown] + ) + + diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") + def refresh_diffusion_transformer(): + controller.refresh_diffusion_transformer() + return gr.update(choices=controller.diffusion_transformer_list) + diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown]) + + return diffusion_transformer_dropdown, diffusion_transformer_refresh_button + +def create_fake_model_checkpoints(model_name, visible): + gr.Markdown( + """ + ### Model checkpoints (模型路径). 
+ """ + ) + with gr.Row(visible=visible): + diffusion_transformer_dropdown = gr.Dropdown( + label="Pretrained Model Path (预训练模型路径)", + choices=[model_name], + value=model_name, + interactive=False, + ) + return diffusion_transformer_dropdown + +def create_finetune_models_checkpoints(controller, visible, add_checkpoint_2=False): + with gr.Row(visible=visible): + base_model_dropdown = gr.Dropdown( + label="Select base Dreambooth model (选择基模型[非必需])", + choices=["none"] + controller.personalized_model_list, + value="none", + interactive=True, + ) + if add_checkpoint_2: + base_model_2_dropdown = gr.Dropdown( + label="Select base Dreambooth model (选择第二个基模型[非必需])", + choices=["none"] + controller.personalized_model_list, + value="none", + interactive=True, + ) + + lora_model_dropdown = gr.Dropdown( + label="Select LoRA model (选择LoRA模型[非必需])", + choices=["none"] + controller.personalized_model_list, + value="none", + interactive=True, + ) + if add_checkpoint_2: + lora_model_2_dropdown = gr.Dropdown( + label="Select LoRA model (选择LoRA模型[非必需])", + choices=["none"] + controller.personalized_model_list, + value="none", + interactive=True, + ) + + lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True) + + personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") + def update_personalized_model(): + controller.refresh_personalized_model() + return [ + gr.update(choices=controller.personalized_model_list), + gr.update(choices=["none"] + controller.personalized_model_list) + ] + personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown]) + + if not add_checkpoint_2: + return base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button + else: + return [base_model_dropdown, base_model_2_dropdown], [lora_model_dropdown, lora_model_2_dropdown], \ + lora_alpha_slider, personalized_refresh_button + +def create_fake_finetune_models_checkpoints(visible): + with gr.Row(): + base_model_dropdown = gr.Dropdown( + label="Select base Dreambooth model (选择基模型[非必需])", + choices=["none"], + value="none", + interactive=False, + visible=False + ) + with gr.Column(visible=False): + gr.Markdown( + """ + ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/CogVideoX-Fun/wiki/Training-Lora). 
+ """ + ) + with gr.Row(): + lora_model_dropdown = gr.Dropdown( + label="Select LoRA model", + choices=["none"], + value="none", + interactive=True, + ) + + lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True) + + return base_model_dropdown, lora_model_dropdown, lora_alpha_slider + +def create_teacache_params( + enable_teacache = True, + teacache_threshold = 0.10, + num_skip_start_steps = 1, + teacache_offload = False, +): + enable_teacache = gr.Checkbox(label="Enable TeaCache", value=enable_teacache) + teacache_threshold = gr.Slider(0.00, 0.25, value=teacache_threshold, step=0.01, label="TeaCache Threshold") + num_skip_start_steps = gr.Slider(0, 10, value=num_skip_start_steps, step=5, label="Number of Skip Start Steps") + teacache_offload = gr.Checkbox(label="Offload TeaCache to CPU", value=teacache_offload) + return enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload + +def create_cfg_skip_params( + cfg_skip_ratio = 0 +): + cfg_skip_ratio = gr.Slider(0.00, 0.50, value=cfg_skip_ratio, step=0.01, label="CFG Skip Ratio") + return cfg_skip_ratio + +def create_cfg_riflex_k( + enable_riflex = False, + riflex_k = 6 +): + enable_riflex = gr.Checkbox(label="Enable Riflex", value=enable_riflex) + riflex_k = gr.Slider(0, 10, value=riflex_k, step=1, label="Riflex Intrinsic Frequency Index") + return enable_riflex, riflex_k + +def create_prompts( + prompt="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.", + negative_prompt="The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. " +): + gr.Markdown( + """ + ### Configs for Generation (生成参数配置). 
+ """ + ) + + prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value=prompt) + negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value=negative_prompt) + return prompt_textbox, negative_prompt_textbox + +def create_samplers(controller, maximum_step=100): + with gr.Row(): + sampler_dropdown = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(controller.scheduler_dict.keys()), value=list(controller.scheduler_dict.keys())[0]) + sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=50, minimum=10, maximum=maximum_step, step=1) + + return sampler_dropdown, sample_step_slider + +def create_height_width(default_height, default_width, maximum_height, maximum_width): + resize_method = gr.Radio( + ["Generate by", "Resize according to Reference"], + value="Generate by", + show_label=False, + ) + width_slider = gr.Slider(label="Width (视频宽度)", value=default_width, minimum=128, maximum=maximum_width, step=16) + height_slider = gr.Slider(label="Height (视频高度)", value=default_height, minimum=128, maximum=maximum_height, step=16) + base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 640, 768, 896, 960, 1024], visible=False) + + return resize_method, width_slider, height_slider, base_resolution + +def create_fake_height_width(default_height, default_width, maximum_height, maximum_width): + resize_method = gr.Radio( + ["Generate by", "Resize according to Reference"], + value="Generate by", + show_label=False, + ) + width_slider = gr.Slider(label="Width (视频宽度)", value=default_width, minimum=128, maximum=maximum_width, step=16, interactive=False) + height_slider = gr.Slider(label="Height (视频高度)", value=default_height, minimum=128, maximum=maximum_height, step=16, interactive=False) + base_resolution = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 640, 768, 896, 960, 1024], interactive=False, visible=False) + + return resize_method, width_slider, height_slider, base_resolution + +def create_generation_methods_and_video_length( + generation_method_options, + default_video_length, + maximum_video_length +): + with gr.Group(): + generation_method = gr.Radio( + generation_method_options, + value="Video Generation", + show_label=False, + ) + with gr.Row(): + length_slider = gr.Slider(label="Animation length (视频帧数)", value=default_video_length, minimum=1, maximum=maximum_video_length, step=4) + overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False) + partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=25, minimum=5, maximum=maximum_video_length, step=4, visible=False) + + return generation_method, length_slider, overlap_video_length, partial_video_length + +def create_generation_method(source_method_options, prompt_textbox, support_end_image=True, support_ref_image=False): + source_method = gr.Radio( + source_method_options, + value="Text to Video (文本到视频)", + show_label=False, + ) + with gr.Column(visible = False) as image_to_video_col: + start_image = gr.Image( + label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, + elem_id="i2v_start", sources="upload", type="filepath", + ) + + template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"] + def select_template(evt: gr.SelectData): + text = { + "asset/1.png": "A brown dog is shaking its head and sitting on a light colored sofa in a comfortable room. 
Behind the dog, there is a framed painting on the shelf surrounded by pink flowers. The soft and warm lighting in the room creates a comfortable atmosphere.", + "asset/2.png": "A sailboat navigates through moderately rough seas, with waves and ocean spray visible. The sailboat features a white hull and sails, accompanied by an orange sail catching the wind. The sky above shows dramatic, cloudy formations with a sunset or sunrise backdrop, casting warm colors across the scene. The water reflects the golden light, enhancing the visual contrast between the dark ocean and the bright horizon. The camera captures the scene with a dynamic and immersive angle, showcasing the movement of the boat and the energy of the ocean.", + "asset/3.png": "A stunningly beautiful woman with flowing long hair stands gracefully, her elegant dress rippling and billowing in the gentle wind. Petals falling off. Her serene expression and the natural movement of her attire create an enchanting and captivating scene, full of ethereal charm.", + "asset/4.png": "An astronaut, clad in a full space suit with a helmet, plays an electric guitar while floating in a cosmic environment filled with glowing particles and rocky textures. The scene is illuminated by a warm light source, creating dramatic shadows and contrasts. The background features a complex geometry, similar to a space station or an alien landscape, indicating a futuristic or otherworldly setting.", + "asset/5.png": "Fireworks light up the evening sky over a sprawling cityscape with gothic-style buildings featuring pointed towers and clock faces. The city is lit by both artificial lights from the buildings and the colorful bursts of the fireworks. The scene is viewed from an elevated angle, showcasing a vibrant urban environment set against a backdrop of a dramatic, partially cloudy sky at dusk.", + }[template_gallery_path[evt.index]] + return template_gallery_path[evt.index], text + + template_gallery = gr.Gallery( + template_gallery_path, + columns=5, rows=1, + height=140, + allow_preview=False, + container=False, + label="Template Examples", + ) + template_gallery.select(select_template, None, [start_image, prompt_textbox]) + + with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False, visible=support_end_image): + end_image = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath") + + with gr.Column(visible = False) as video_to_video_col: + with gr.Row(): + validation_video = gr.Video( + label="The video to convert (视频转视频的参考视频)", show_label=True, + elem_id="v2v", sources="upload", + ) + with gr.Accordion("The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", open=False): + gr.Markdown( + """ + - Please set a larger denoise_strength when using validation_video_mask, such as 1.00 instead of 0.70 + (请设置更大的denoise_strength,当使用validation_video_mask的时候,比如1而不是0.70) + """ + ) + validation_video_mask = gr.Image( + label="The mask of the video to inpaint (视频重新绘制的mask[非必需, Optional])", + show_label=False, elem_id="v2v_mask", sources="upload", type="filepath" + ) + denoise_strength = gr.Slider(label="Denoise strength (重绘系数)", value=0.70, minimum=0.10, maximum=1.00, step=0.01) + + with gr.Column(visible = False) as control_video_col: + gr.Markdown( + """ + Demo pose control video can be downloaded here [URL](https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/cogvideox_fun/asset/v1.1/pose.mp4). 
+ """ + ) + control_video = gr.Video( + label="The control video (用于提供控制信号的video)", show_label=True, + elem_id="v2v_control", sources="upload", + ) + ref_image = gr.Image( + label="The reference image for control video (控制视频的参考图片)", show_label=True, + elem_id="ref_image", sources="upload", type="filepath", visible=support_ref_image + ) + return image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image + +def create_cfg_and_seedbox(gradio_version_is_above_4): + cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20) + + with gr.Row(): + seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43) + seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton") + seed_button.click( + fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)), + inputs=[], + outputs=[seed_textbox] + ) + return cfg_scale_slider, seed_textbox, seed_button + +def create_ui_outputs(): + with gr.Column(): + result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False) + result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False) + infer_progress = gr.Textbox( + label="Generation Info (生成信息)", + value="No task currently", + interactive=False + ) + return result_image, result_video, infer_progress + +def create_config(controller): + gr.Markdown( + """ + ### Config Path (配置文件路径) + """ + ) + with gr.Row(): + config_dropdown = gr.Dropdown( + label="Config Path (配置文件路径)", + choices=controller.config_list, + value=controller.config_path, + interactive=True, + ) + config_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton") + def refresh_config(): + controller.refresh_config() + return gr.update(choices=controller.config_list) + config_refresh_button.click(fn=refresh_config, inputs=[], outputs=[config_dropdown]) + return config_dropdown, config_refresh_button \ No newline at end of file diff --git a/videox_fun/ui/wan2_2_fun_ui.py b/videox_fun/ui/wan2_2_fun_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..0e4a07aa63aed9d987ad91487d8fdbdb00359bdd --- /dev/null +++ b/videox_fun/ui/wan2_2_fun_ui.py @@ -0,0 +1,803 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..dist import set_multi_gpus_devices, shard_model +from ..models import (AutoencoderKLWan, AutoencoderKLWan3_8, AutoTokenizer, + CLIPModel, Wan2_2Transformer3DModel, WanT5EncoderModel) +from ..models.cache_utils import get_teacache_coefficients +from ..pipeline import Wan2_2FunControlPipeline, Wan2_2FunPipeline, Wan2_2FunInpaintPipeline +from ..utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_latent, + get_image_to_video_latent, + get_video_to_video_latent, save_videos_grid, timer) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + 
gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k, + create_cfg_skip_params, create_config, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_teacache_params, create_ui_outputs) + + +class Wan2_2_Fun_Controller(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.model_name = diffusion_transformer_dropdown + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + Chosen_AutoencoderKL = { + "AutoencoderKLWan": AutoencoderKLWan, + "AutoencoderKLWan3_8": AutoencoderKLWan3_8 + }[self.config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')] + self.vae = Chosen_AutoencoderKL.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']), + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = Wan2_2Transformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + if self.config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe": + self.transformer_2 = Wan2_2Transformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + else: + self.transformer_2 = None + + # Get Tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + # Get Text encoder + self.text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + self.text_encoder = self.text_encoder.eval() + + Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]] + self.scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs'])) + ) + + # Get pipeline + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = Wan2_2FunInpaintPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + self.pipeline = Wan2_2FunPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + 
transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + self.pipeline = Wan2_2FunControlPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.transformer_2 is not None: + self.transformer_2.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + if self.transformer_2 is not None: + self.pipeline.transformer_2 = shard_fn(self.pipeline.transformer_2) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.blocks)): + self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i]) + if self.transformer_2 is not None: + for i in range(len(self.pipeline.transformer_2.blocks)): + self.pipeline.transformer_2.blocks[i] = torch.compile(self.pipeline.transformer_2.blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(self.transformer, ["modulation",], device=self.device) + self.transformer.freqs = self.transformer.freqs.to(device=self.device) + if self.transformer_2 is not None: + replace_parameters_by_name(self.transformer_2, ["modulation",], device=self.device) + self.transformer_2.freqs = self.transformer_2.freqs.to(device=self.device) + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + if self.transformer_2 is not None: + convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + if self.transformer_2 is not None: + convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + 
cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + if self.base_model_2_path != base_model_2_dropdown: + self.update_lora_model(base_model_2_dropdown, is_checkpoint_2=True) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + if self.lora_model_2_path != lora_model_2_dropdown: + self.update_lora_model(lora_model_2_dropdown, is_checkpoint_2=True) + + print(f"Load scheduler.") + scheduler_config = self.pipeline.scheduler.config + if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++": + scheduler_config['shift'] = 1 + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if self.transformer_2 is not None: + self.pipeline = merge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2") + print(f"Merge Lora done.") + + coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None + if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + self.pipeline.transformer.enable_teacache( + coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + if self.transformer_2 is not None: + self.pipeline.transformer_2.share_teacache(self.pipeline.transformer) + else: + print(f"Disable TeaCache.") + self.pipeline.transformer.disable_teacache() + if self.transformer_2 is not None: + self.pipeline.transformer_2.disable_teacache() + + if cfg_skip_ratio is not None and cfg_skip_ratio >= 0: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider) + if self.transformer_2 is not None: + self.pipeline.transformer_2.share_cfg_skip(self.pipeline.transformer) + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + if fps is None: + fps = 16 + boundary = self.config['transformer_additional_kwargs'].get('boundary', 
0.875) + + if enable_riflex: + print(f"Enable riflex") + latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1 + self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + if self.transformer_2 is not None: + self.pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + boundary = boundary + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + boundary = boundary + ).videos + else: + inpaint_video, inpaint_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + if ref_image is not None: + ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider)) + + input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = inpaint_video, + mask_video = inpaint_video_mask, + control_video = input_video, + ref_image = ref_image, + boundary = boundary, + ).videos + print(f"Generation done.") + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. 
error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +Wan2_2_Fun_Controller_Host = Wan2_2_Fun_Controller +Wan2_2_Fun_Controller_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan2_2_Fun_Controller( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2.2-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. 
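The header text just above promises clips of roughly 5 seconds at fps 16 with 1 to 81 frames. Those defaults line up with the latent-frame formula used by enable_riflex in generate(); the sketch below only makes the arithmetic explicit, and the temporal compression ratio of 4 is an assumed value (the real code reads self.vae.config.temporal_compression_ratio).

# Relates the advertised defaults (81 frames at 16 fps) to clip duration and to the latent
# frame count that generate() passes to enable_riflex. temporal_compression_ratio=4 is assumed.
def clip_stats(num_frames: int = 81, fps: int = 16, temporal_compression_ratio: int = 4):
    duration_s = num_frames / fps                                       # 81 / 16 = 5.0625 s, i.e. "around 5 seconds"
    latent_frames = (num_frames - 1) // temporal_compression_ratio + 1  # same formula as generate(): 21 latent frames
    return duration_s, latent_frames

print(clip_stats())  # (5.0625, 21)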
+ + [Github](https://github.com/aigc-apps/VideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + config_dropdown, config_refresh_button = create_config(controller) + model_type = create_model_type(visible=True) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + config_dropdown.change( + fn=controller.update_config, + inputs=[config_dropdown], + outputs=[] + ) + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), 
gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan2_2_Fun_Controller_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2.2-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. 
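One detail worth flagging in generate() above (the same branch is repeated in wan2_2_ui.py later in this patch): the seed logic evaluates int(seed_textbox) before checking for an empty string, so a blank seed box raises ValueError instead of falling through to the random branch. A defensive parse is sketched below with a hypothetical helper name; it is not part of this patch but mirrors the intended behaviour, where blank or -1 means "randomize".

import random

# parse_seed is a hypothetical helper, not part of the patch: it checks the string before
# calling int(), so an empty textbox selects a random seed rather than raising ValueError.
def parse_seed(seed_textbox: str) -> int:
    text = str(seed_textbox).strip()
    if text == "" or int(text) == -1:
        return random.randint(0, 10**10)
    return int(text)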
+ + [Github](https://github.com/aigc-apps/VideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \ + create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, 
source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = Wan2_2_Fun_Controller_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2.2-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. 
+ + [Github](https://github.com/aigc-apps/VideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \ + create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, 
height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/ui/wan2_2_ui.py b/videox_fun/ui/wan2_2_ui.py new file mode 100644 index 0000000000000000000000000000000000000000..4fcf81c7960a2cdd2e1b1f5c8605fea8a58dfc9d --- /dev/null +++ b/videox_fun/ui/wan2_2_ui.py @@ -0,0 +1,797 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (AutoencoderKLWan, AutoencoderKLWan3_8, AutoTokenizer, CLIPModel, + WanT5EncoderModel, Wan2_2Transformer3DModel) +from ..models.cache_utils import get_teacache_coefficients +from ..pipeline import Wan2_2I2VPipeline, Wan2_2Pipeline, Wan2_2TI2VPipeline +from ..utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer, + get_video_to_video_latent, save_videos_grid) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k, + create_cfg_skip_params, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_teacache_params, create_ui_outputs, create_config) +from ..dist import set_multi_gpus_devices, shard_model + + +class Wan2_2_Controller(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.model_name = diffusion_transformer_dropdown + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + Chosen_AutoencoderKL = { + "AutoencoderKLWan": AutoencoderKLWan, + "AutoencoderKLWan3_8": AutoencoderKLWan3_8 + }[self.config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')] + self.vae = Chosen_AutoencoderKL.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']), + 
).to(self.weight_dtype) + + # Get Transformer + self.transformer = Wan2_2Transformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + if self.config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe": + self.transformer_2 = Wan2_2Transformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + else: + self.transformer_2 = None + + # Get Tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + # Get Text encoder + self.text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + self.text_encoder = self.text_encoder.eval() + + Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]] + self.scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs'])) + ) + + # Get pipeline + if self.model_type == "Inpaint": + if "wan_civitai_5b" in self.config_path: + self.pipeline = Wan2_2TI2VPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = Wan2_2I2VPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + self.pipeline = Wan2_2Pipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + transformer_2=self.transformer_2, + scheduler=self.scheduler, + ) + else: + raise ValueError("Not support now") + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.transformer_2 is not None: + self.transformer_2.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + if self.transformer_2 is not None: + self.pipeline.transformer_2 = shard_fn(self.pipeline.transformer_2) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.blocks)): + self.pipeline.transformer.blocks[i] = 
torch.compile(self.pipeline.transformer.blocks[i]) + if self.transformer_2 is not None: + for i in range(len(self.pipeline.transformer_2.blocks)): + self.pipeline.transformer_2.blocks[i] = torch.compile(self.pipeline.transformer_2.blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(self.transformer, ["modulation",], device=self.device) + self.transformer.freqs = self.transformer.freqs.to(device=self.device) + if self.transformer_2 is not None: + replace_parameters_by_name(self.transformer_2, ["modulation",], device=self.device) + self.transformer_2.freqs = self.transformer_2.freqs.to(device=self.device) + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + if self.transformer_2 is not None: + convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + if self.transformer_2 is not None: + convert_model_weight_to_float8(self.transformer_2, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer_2, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + if self.base_model_2_path != base_model_2_dropdown: + self.update_lora_model(base_model_2_dropdown, is_checkpoint_2=True) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + if self.lora_model_2_path != lora_model_2_dropdown: + self.update_lora_model(lora_model_2_dropdown, 
is_checkpoint_2=True) + + print(f"Load scheduler.") + scheduler_config = self.pipeline.scheduler.config + if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++": + scheduler_config['shift'] = 1 + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if self.transformer_2 is not None: + self.pipeline = merge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2") + print(f"Merge Lora done.") + + coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None + if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + self.pipeline.transformer.enable_teacache( + coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + if self.transformer_2 is not None: + self.pipeline.transformer_2.share_teacache(self.pipeline.transformer) + else: + print(f"Disable TeaCache.") + self.pipeline.transformer.disable_teacache() + if self.transformer_2 is not None: + self.pipeline.transformer_2.disable_teacache() + + if cfg_skip_ratio is not None and cfg_skip_ratio >= 0: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider) + if self.transformer_2 is not None: + self.pipeline.transformer_2.share_cfg_skip(self.pipeline.transformer) + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + if fps is None: + fps = 16 + boundary = self.config['transformer_additional_kwargs'].get('boundary', 0.875) + + if enable_riflex: + print(f"Enable riflex") + latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1 + self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + if self.transformer_2 is not None: + self.pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = 
height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + boundary = boundary + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + boundary = boundary + ).videos + else: + if ref_image is not None: + ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider)) + + if start_image is not None: + start_image = get_image_latent(start_image, sample_size=(height_slider, width_slider)) + + input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + control_video = input_video, + ref_image = ref_image, + start_image = start_image, + boundary = boundary + ).videos + print(f"Generation done.") + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if self.transformer_2 is not None: + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2") + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. 
error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if self.transformer_2 is not None: + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_2_path, multiplier=lora_alpha_slider, sub_transformer_name="transformer_2") + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +Wan2_2_Controller_Host = Wan2_2_Controller +Wan2_2_Controller_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan2_2_Controller( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2_2: + """ + ) + with gr.Column(variant="panel"): + config_dropdown, config_refresh_button = create_config(controller) + model_type = create_model_type(visible=False) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, 
validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox, support_end_image=False + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + config_dropdown.change( + fn=controller.update_config, + inputs=[config_dropdown], + outputs=[] + ) + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, 
+ lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan2_2_Controller_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2_2: + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \ + create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to 
Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = Wan2_2_Controller_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan2_2: + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = \ + create_fake_finetune_models_checkpoints(visible=True, add_checkpoint_2=True) + base_model_dropdown, base_model_2_dropdown = base_model_dropdown + lora_model_dropdown, lora_model_2_dropdown = lora_model_dropdown + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, 
control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + base_model_2_dropdown, + lora_model_2_dropdown + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/ui/wan_fun_ui.py b/videox_fun/ui/wan_fun_ui.py new file mode 100755 index 0000000000000000000000000000000000000000..315bde87c347e46a3e9a2cb674cec1d153e1bc04 --- /dev/null +++ b/videox_fun/ui/wan_fun_ui.py @@ -0,0 +1,752 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..models.cache_utils import 
get_teacache_coefficients +from ..pipeline import (WanFunControlPipeline, WanFunInpaintPipeline, + WanFunPipeline) +from ..utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer, + get_video_to_video_latent, save_videos_grid) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, create_cfg_riflex_k, + create_cfg_skip_params, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_teacache_params, create_ui_outputs) +from ..dist import set_multi_gpus_devices, shard_model + + +class Wan_Fun_Controller(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.model_name = diffusion_transformer_dropdown + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + self.vae = AutoencoderKLWan.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']), + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = WanTransformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + + # Get Tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + # Get Text encoder + self.text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + self.text_encoder = self.text_encoder.eval() + + if self.transformer.config.in_channels != self.vae.config.latent_channels: + # Get Clip Image Encoder + self.clip_image_encoder = CLIPModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), + ).to(self.weight_dtype) + self.clip_image_encoder = self.clip_image_encoder.eval() + else: + self.clip_image_encoder = None + + Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]] + self.scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs'])) + ) + + # Get pipeline + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = 
WanFunInpaintPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + clip_image_encoder=self.clip_image_encoder, + ) + else: + self.pipeline = WanFunPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + ) + else: + self.pipeline = WanFunControlPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + clip_image_encoder=self.clip_image_encoder, + ) + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.transformer = shard_fn(self.pipeline.transformer) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.blocks)): + self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(self.transformer, ["modulation",], device=self.device) + self.transformer.freqs = self.transformer.freqs.to(device=self.device) + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True 
if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + + print(f"Load scheduler.") + scheduler_config = self.pipeline.scheduler.config + if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++": + scheduler_config['shift'] = 1 + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Merge Lora done.") + + coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None + if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + self.pipeline.transformer.enable_teacache( + coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + else: + print(f"Disable TeaCache.") + self.pipeline.transformer.disable_teacache() + + if cfg_skip_ratio is not None and cfg_skip_ratio >= 0: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider) + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + if fps is None: + fps = 16 + + if enable_riflex: + print(f"Enable riflex") + latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1 + self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + clip_image = clip_image, + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + 
generator = generator, + ).videos + else: + if ref_image is not None: + clip_image = Image.open(ref_image).convert("RGB") + elif start_image is not None: + clip_image = Image.open(start_image).convert("RGB") + else: + clip_image = None + + if ref_image is not None: + ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider)) + + if start_image is not None: + start_image = get_image_latent(start_image, sample_size=(height_slider, width_slider)) + + input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + control_video = input_video, + ref_image = ref_image, + start_image = start_image, + clip_image = clip_image, + ).videos + print(f"Generation done.") + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +Wan_Fun_Controller_Host = Wan_Fun_Controller +Wan_Fun_Controller_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan_Fun_Controller( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. 
+ + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_model_type(visible=True) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), 
gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan_Fun_Controller_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. 
+ + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)", "Video Control (视频控制)"], prompt_textbox, support_ref_image=True + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) 
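+        # The handlers above follow Gradio's standard event pattern: a callback returns
+        # one gr.update(...) per component listed in `outputs`, toggling visibility of
+        # the input columns and clearing stale values when the source selection changes.
+        # A minimal, self-contained sketch of the same pattern (the component names
+        # below are illustrative only, not components of this app):
+        #
+        #     import gradio as gr
+        #
+        #     with gr.Blocks() as toy_demo:
+        #         mode = gr.Radio(["Image to Video", "Video to Video"],
+        #                         value="Image to Video", label="Source")
+        #         with gr.Column(visible=True) as image_col:
+        #             toy_image = gr.Image(label="Start image")
+        #         with gr.Column(visible=False) as video_col:
+        #             toy_video = gr.Video(label="Source video")
+        #
+        #         def toggle(choice):
+        #             show_image = choice == "Image to Video"
+        #             # One gr.update per output, in the same order as `outputs`;
+        #             # gr.update() with no kwargs leaves a component unchanged.
+        #             return [gr.update(visible=show_image),
+        #                     gr.update(visible=not show_image)]
+        #
+        #         mode.change(toggle, inputs=mode, outputs=[image_col, video_col])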
+ + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = Wan_Fun_Controller_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan-Fun: + + A Wan with more flexible generation conditions, capable of producing videos of different resolutions, around 5 seconds, and fps 16 (frames 1 to 81), as well as image generated videos. + + [Github](https://github.com/aigc-apps/CogVideoX-Fun/) + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == 
"Video Generation": + return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/ui/wan_ui.py b/videox_fun/ui/wan_ui.py new file mode 100755 index 0000000000000000000000000000000000000000..5e1e89d06af066dc777f75791d7ccf1142d1a019 --- /dev/null +++ b/videox_fun/ui/wan_ui.py @@ -0,0 +1,732 @@ +"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py +""" +import os +import random + +import cv2 +import gradio as gr +import numpy as np +import torch +from omegaconf import OmegaConf +from PIL import Image +from safetensors import safe_open + +from ..data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio +from ..models import (AutoencoderKLWan, AutoTokenizer, CLIPModel, + WanT5EncoderModel, WanTransformer3DModel) +from ..models.cache_utils import get_teacache_coefficients +from ..pipeline import WanI2VPipeline, WanPipeline +from ..utils.fp8_optimization import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from ..utils.lora_utils import merge_lora, unmerge_lora +from ..utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, timer, + get_video_to_video_latent, save_videos_grid) +from .controller import (Fun_Controller, Fun_Controller_Client, + all_cheduler_dict, css, ddpm_scheduler_dict, + flow_scheduler_dict, gradio_version, + gradio_version_is_above_4) +from .ui import (create_cfg_and_seedbox, 
create_cfg_riflex_k, + create_cfg_skip_params, + create_fake_finetune_models_checkpoints, + create_fake_height_width, create_fake_model_checkpoints, + create_fake_model_type, create_finetune_models_checkpoints, + create_generation_method, + create_generation_methods_and_video_length, + create_height_width, create_model_checkpoints, + create_model_type, create_prompts, create_samplers, + create_teacache_params, create_ui_outputs) +from ..dist import set_multi_gpus_devices, shard_model + + +class Wan_Controller(Fun_Controller): + def update_diffusion_transformer(self, diffusion_transformer_dropdown): + print(f"Update diffusion transformer: {diffusion_transformer_dropdown}") + self.model_name = diffusion_transformer_dropdown + self.diffusion_transformer_dropdown = diffusion_transformer_dropdown + if diffusion_transformer_dropdown == "none": + return gr.update() + self.vae = AutoencoderKLWan.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['vae_kwargs'].get('vae_subpath', 'vae')), + additional_kwargs=OmegaConf.to_container(self.config['vae_kwargs']), + ).to(self.weight_dtype) + + # Get Transformer + self.transformer = WanTransformer3DModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')), + transformer_additional_kwargs=OmegaConf.to_container(self.config['transformer_additional_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + + # Get Tokenizer + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')), + ) + + # Get Text encoder + self.text_encoder = WanT5EncoderModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')), + additional_kwargs=OmegaConf.to_container(self.config['text_encoder_kwargs']), + low_cpu_mem_usage=True, + torch_dtype=self.weight_dtype, + ) + self.text_encoder = self.text_encoder.eval() + + if self.transformer.config.in_channels != self.vae.config.latent_channels: + # Get Clip Image Encoder + self.clip_image_encoder = CLIPModel.from_pretrained( + os.path.join(diffusion_transformer_dropdown, self.config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder')), + ).to(self.weight_dtype) + self.clip_image_encoder = self.clip_image_encoder.eval() + else: + self.clip_image_encoder = None + + Chosen_Scheduler = self.scheduler_dict[list(self.scheduler_dict.keys())[0]] + self.scheduler = Chosen_Scheduler( + **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(self.config['scheduler_kwargs'])) + ) + + # Get pipeline + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + self.pipeline = WanI2VPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + clip_image_encoder=self.clip_image_encoder, + ) + else: + self.pipeline = WanPipeline( + vae=self.vae, + tokenizer=self.tokenizer, + text_encoder=self.text_encoder, + transformer=self.transformer, + scheduler=self.scheduler, + ) + else: + raise ValueError("Not support now") + + if self.ulysses_degree > 1 or self.ring_degree > 1: + from functools import partial + self.transformer.enable_multi_gpus_inference() + if self.fsdp_dit: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + 
self.pipeline.transformer = shard_fn(self.pipeline.transformer) + print("Add FSDP DIT") + if self.fsdp_text_encoder: + shard_fn = partial(shard_model, device_id=self.device, param_dtype=self.weight_dtype) + self.pipeline.text_encoder = shard_fn(self.pipeline.text_encoder) + print("Add FSDP TEXT ENCODER") + + if self.compile_dit: + for i in range(len(self.pipeline.transformer.blocks)): + self.pipeline.transformer.blocks[i] = torch.compile(self.pipeline.transformer.blocks[i]) + print("Add Compile") + + if self.GPU_memory_mode == "sequential_cpu_offload": + replace_parameters_by_name(self.transformer, ["modulation",], device=self.device) + self.transformer.freqs = self.transformer.freqs.to(device=self.device) + self.pipeline.enable_sequential_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_cpu_offload": + self.pipeline.enable_model_cpu_offload(device=self.device) + elif self.GPU_memory_mode == "model_full_load_and_qfloat8": + convert_model_weight_to_float8(self.transformer, exclude_module_name=["modulation",], device=self.device) + convert_weight_dtype_wrapper(self.transformer, self.weight_dtype) + self.pipeline.to(self.device) + else: + self.pipeline.to(self.device) + print("Update diffusion transformer done") + return gr.update() + + @timer + def generate( + self, + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image = None, + enable_teacache = None, + teacache_threshold = None, + num_skip_start_steps = None, + teacache_offload = None, + cfg_skip_ratio = None, + enable_riflex = None, + riflex_k = None, + base_model_2_dropdown=None, + lora_model_2_dropdown=None, + fps = None, + is_api = False, + ): + self.clear_cache() + + print(f"Input checking.") + _, comment = self.input_check( + resize_method, generation_method, start_image, end_image, validation_video,control_video, is_api + ) + print(f"Input checking down") + if comment != "OK": + return "", comment + is_image = True if generation_method == "Image Generation" else False + + if self.base_model_path != base_model_dropdown: + self.update_base_model(base_model_dropdown) + + if self.lora_model_path != lora_model_dropdown: + self.update_lora_model(lora_model_dropdown) + + print(f"Load scheduler.") + scheduler_config = self.pipeline.scheduler.config + if sampler_dropdown == "Flow_Unipc" or sampler_dropdown == "Flow_DPM++": + scheduler_config['shift'] = 1 + self.pipeline.scheduler = self.scheduler_dict[sampler_dropdown].from_config(scheduler_config) + print(f"Load scheduler down.") + + if resize_method == "Resize according to Reference": + print(f"Calculate height and width according to Reference.") + height_slider, width_slider = self.get_height_width_from_reference( + base_resolution, start_image, validation_video, control_video, + ) + + if self.lora_model_path != "none": + 
print(f"Merge Lora.") + self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Merge Lora done.") + + coefficients = get_teacache_coefficients(self.diffusion_transformer_dropdown) if enable_teacache else None + if coefficients is not None: + print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.") + self.pipeline.transformer.enable_teacache( + coefficients, sample_step_slider, teacache_threshold, num_skip_start_steps=num_skip_start_steps, offload=teacache_offload + ) + else: + print(f"Disable TeaCache.") + self.pipeline.transformer.disable_teacache() + + if cfg_skip_ratio is not None and cfg_skip_ratio >= 0: + print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.") + self.pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, sample_step_slider) + + print(f"Generate seed.") + if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: seed_textbox = np.random.randint(0, 1e10) + generator = torch.Generator(device=self.device).manual_seed(int(seed_textbox)) + print(f"Generate seed done.") + + if fps is None: + fps = 16 + + if enable_riflex: + print(f"Enable riflex") + latent_frames = (int(length_slider) - 1) // self.vae.config.temporal_compression_ratio + 1 + self.pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames if not is_image else 1) + + try: + print(f"Generation.") + if self.model_type == "Inpaint": + if self.transformer.config.in_channels != self.vae.config.latent_channels: + if validation_video is not None: + input_video, input_video_mask, _, clip_image = get_video_to_video_latent(validation_video, length_slider if not is_image else 1, sample_size=(height_slider, width_slider), validation_video_mask=validation_video_mask, fps=fps) + else: + input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider)) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + video = input_video, + mask_video = input_video_mask, + clip_image = clip_image, + ).videos + else: + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + ).videos + else: + if ref_image is not None: + clip_image = Image.open(ref_image).convert("RGB") + elif start_image is not None: + clip_image = Image.open(start_image).convert("RGB") + else: + clip_image = None + + if ref_image is not None: + ref_image = get_image_latent(ref_image, sample_size=(height_slider, width_slider)) + + if start_image is not None: + start_image = get_image_latent(start_image, sample_size=(height_slider, width_slider)) + + input_video, input_video_mask, _, _ = get_video_to_video_latent(control_video, video_length=length_slider if not is_image else 1, sample_size=(height_slider, width_slider), fps=fps, ref_image=None) + + sample = self.pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + 
height = height_slider, + num_frames = length_slider if not is_image else 1, + generator = generator, + + control_video = input_video, + ref_image = ref_image, + start_image = start_image, + clip_image = clip_image, + ).videos + print(f"Generation done.") + except Exception as e: + self.auto_model_clear_cache(self.pipeline.transformer) + self.auto_model_clear_cache(self.pipeline.text_encoder) + self.auto_model_clear_cache(self.pipeline.vae) + self.clear_cache() + + print(f"Error. error information is {str(e)}") + if self.lora_model_path != "none": + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + if is_api: + return "", f"Error. error information is {str(e)}" + else: + return gr.update(), gr.update(), f"Error. error information is {str(e)}" + + self.clear_cache() + # lora part + if self.lora_model_path != "none": + print(f"Unmerge Lora.") + self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider) + print(f"Unmerge Lora done.") + + print(f"Saving outputs.") + save_sample_path = self.save_outputs( + is_image, length_slider, sample, fps=fps + ) + print(f"Saving outputs done.") + + if is_image or length_slider == 1: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success" + else: + return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success" + else: + if is_api: + return save_sample_path, "Success" + else: + if gradio_version_is_above_4: + return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success" + else: + return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success" + +Wan_Controller_Host = Wan_Controller +Wan_Controller_Client = Fun_Controller_Client + +def ui(GPU_memory_mode, scheduler_dict, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan_Controller( + GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint", + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan: + """ + ) + with gr.Column(variant="panel"): + model_type = create_model_type(visible=False) + diffusion_transformer_dropdown, diffusion_transformer_refresh_button = \ + create_model_checkpoints(controller, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider, personalized_refresh_button = \ + create_finetune_models_checkpoints(controller, visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, 
length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox, support_end_image=False + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + model_type.change( + fn=controller.update_model_type, + inputs=[model_type], + outputs=[] + ) + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return [gr.update(visible=True, maximum=161, value=81, interactive=True), gr.update(visible=False), gr.update(visible=False)] + elif generation_method == "Image Generation": + return [gr.update(minimum=1, maximum=1, value=1, interactive=False), gr.update(visible=False), gr.update(visible=False)] + else: + return [gr.update(visible=True, maximum=1344), gr.update(visible=True), gr.update(visible=True)] + generation_method.change( + upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + 
validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_host(GPU_memory_mode, scheduler_dict, model_name, model_type, config_path, compile_dit, weight_dtype, savedir_sample=None): + controller = Wan_Controller_Host( + GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, + config_path=config_path, compile_dit=compile_dit, + weight_dtype=weight_dtype, savedir_sample=savedir_sample, + ) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan: + """ + ) + with gr.Column(variant="panel"): + model_type = create_fake_model_type(visible=False) + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller) + + resize_method, width_slider, height_slider, base_resolution = create_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=1, maximum=161, value=81, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif 
source_method == "Video to Video (视频到视频)": + return [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(), gr.update(), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update()] + source_method.change( + upload_source_method, source_method, [ + image_to_video_col, video_to_video_col, control_video_col, start_image, end_image, + validation_video, validation_video_mask, control_video + ] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + overlap_video_length, + partial_video_length, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + control_video, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller + +def ui_client(scheduler_dict, model_name, savedir_sample=None): + controller = Wan_Controller_Client(scheduler_dict, savedir_sample) + + with gr.Blocks(css=css) as demo: + gr.Markdown( + """ + # Wan: + """ + ) + with gr.Column(variant="panel"): + diffusion_transformer_dropdown = create_fake_model_checkpoints(model_name, visible=True) + base_model_dropdown, lora_model_dropdown, lora_alpha_slider = create_fake_finetune_models_checkpoints(visible=True) + + with gr.Row(): + enable_teacache, teacache_threshold, num_skip_start_steps, teacache_offload = \ + create_teacache_params(True, 0.10, 1, False) + cfg_skip_ratio = create_cfg_skip_params(0) + enable_riflex, riflex_k = create_cfg_riflex_k(False, 6) + + with gr.Column(variant="panel"): + prompt_textbox, negative_prompt_textbox = create_prompts(negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走") + + with gr.Row(): + with gr.Column(): + sampler_dropdown, sample_step_slider = create_samplers(controller, maximum_step=50) + + resize_method, width_slider, height_slider, base_resolution = create_fake_height_width( + default_height = 480, default_width = 832, maximum_height = 1344, + maximum_width = 1344, + ) + generation_method, length_slider, overlap_video_length, partial_video_length = \ + create_generation_methods_and_video_length( + ["Video Generation", "Image Generation"], + default_video_length=81, + maximum_video_length=161, + ) + image_to_video_col, video_to_video_col, control_video_col, source_method, start_image, template_gallery, end_image, validation_video, validation_video_mask, denoise_strength, control_video, ref_image = create_generation_method( + ["Text to 
Video (文本到视频)", "Image to Video (图片到视频)"], prompt_textbox + ) + + cfg_scale_slider, seed_textbox, seed_button = create_cfg_and_seedbox(gradio_version_is_above_4) + + generate_button = gr.Button(value="Generate (生成)", variant='primary') + + result_image, result_video, infer_progress = create_ui_outputs() + + def upload_generation_method(generation_method): + if generation_method == "Video Generation": + return gr.update(visible=True, minimum=5, maximum=161, value=49, interactive=True) + elif generation_method == "Image Generation": + return gr.update(minimum=1, maximum=1, value=1, interactive=False) + generation_method.change( + upload_generation_method, generation_method, [length_slider] + ) + + def upload_source_method(source_method): + if source_method == "Text to Video (文本到视频)": + return [gr.update(visible=False), gr.update(visible=False), gr.update(value=None), gr.update(value=None), gr.update(value=None), gr.update(value=None)] + elif source_method == "Image to Video (图片到视频)": + return [gr.update(visible=True), gr.update(visible=False), gr.update(), gr.update(), gr.update(value=None), gr.update(value=None)] + else: + return [gr.update(visible=False), gr.update(visible=True), gr.update(value=None), gr.update(value=None), gr.update(), gr.update()] + source_method.change( + upload_source_method, source_method, [image_to_video_col, video_to_video_col, start_image, end_image, validation_video, validation_video_mask] + ) + + def upload_resize_method(resize_method): + if resize_method == "Generate by": + return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)] + else: + return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)] + resize_method.change( + upload_resize_method, resize_method, [width_slider, height_slider, base_resolution] + ) + + generate_button.click( + fn=controller.generate, + inputs=[ + diffusion_transformer_dropdown, + base_model_dropdown, + lora_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + resize_method, + width_slider, + height_slider, + base_resolution, + generation_method, + length_slider, + cfg_scale_slider, + start_image, + end_image, + validation_video, + validation_video_mask, + denoise_strength, + seed_textbox, + ref_image, + enable_teacache, + teacache_threshold, + num_skip_start_steps, + teacache_offload, + cfg_skip_ratio, + enable_riflex, + riflex_k, + ], + outputs=[result_image, result_video, infer_progress] + ) + return demo, controller \ No newline at end of file diff --git a/videox_fun/utils/__init__.py b/videox_fun/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..009df372a0a1eeded219789d663baedb34371fca --- /dev/null +++ b/videox_fun/utils/__init__.py @@ -0,0 +1,42 @@ +import importlib.util + +from .fm_solvers import FlowDPMSolverMultistepScheduler +from .fm_solvers_unipc import FlowUniPCMultistepScheduler +from .fp8_optimization import (autocast_model_forward, + convert_model_weight_to_float8, + convert_weight_dtype_wrapper, + replace_parameters_by_name) +from .lora_utils import merge_lora, unmerge_lora +from .utils import (filter_kwargs, get_image_latent, get_image_to_video_latent, get_autocast_dtype, + get_video_to_video_latent, save_videos_grid) +from .cfg_optimization import cfg_skip +from .discrete_sampler import DiscreteSampling + + +# The pai_fuser is an internally developed acceleration package, which can be used on PAI. 
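+# The package name probed below is `paifuser`. When it is importable, its FP8
+# weight-conversion helpers (convert_model_weight_to_float8,
+# convert_weight_dtype_wrapper) and its CFG-skip implementation (cfg_skip_turbo)
+# replace the pure-PyTorch fallbacks defined in fp8_optimization.py and
+# cfg_optimization.py; when it is absent, the local implementations are used
+# unchanged.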
+if importlib.util.find_spec("paifuser") is not None: + # --------------------------------------------------------------- # + # FP8 Linear Kernel + # --------------------------------------------------------------- # + from paifuser.ops import (convert_model_weight_to_float8, + convert_weight_dtype_wrapper) + from . import fp8_optimization + fp8_optimization.convert_model_weight_to_float8 = convert_model_weight_to_float8 + fp8_optimization.convert_weight_dtype_wrapper = convert_weight_dtype_wrapper + convert_model_weight_to_float8 = fp8_optimization.convert_model_weight_to_float8 + convert_weight_dtype_wrapper = fp8_optimization.convert_weight_dtype_wrapper + print("Import PAI Quantization Turbo") + + # --------------------------------------------------------------- # + # CFG Skip Turbo + # --------------------------------------------------------------- # + if importlib.util.find_spec("paifuser.accelerator") is not None: + from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip, + enable_cfg_skip, share_cfg_skip) + else: + from paifuser import (cfg_skip_turbo, disable_cfg_skip, + enable_cfg_skip, share_cfg_skip) + from . import cfg_optimization + cfg_optimization.cfg_skip = cfg_skip_turbo + cfg_skip = cfg_skip_turbo + print("Import CFG Skip Turbo") \ No newline at end of file diff --git a/videox_fun/utils/ac_handle.py b/videox_fun/utils/ac_handle.py new file mode 100644 index 0000000000000000000000000000000000000000..91df98a1e53a79ac18a0b9579656854e0f4f6042 --- /dev/null +++ b/videox_fun/utils/ac_handle.py @@ -0,0 +1,64 @@ +from functools import partial + +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) + + +non_reentrant_wrapper = partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, +) + + +def apply_checkpointing(model, block, p): + """ + Apply selective activation checkpointing. + + Selectivity is defined as a percentage p, which means we apply ac + on p of the total blocks. p is a floating number in the range of + [0, 1]. + + Some examples: + p = 0: no ac for all blocks. same as `fsdp_activation_checkpointing=False` + p = 1: apply ac on every block. i.e. "full ac". + p = 1/2: [ac, no-ac, ac, no-ac, ...] + p = 1/3: [no-ac, ac, no-ac, no-ac, ac, no-ac, ...] + p = 2/3: [ac, no-ac, ac, ac, no-ac, ac, ...] + Since blocks are homogeneous, we make ac blocks evenly spaced among + all blocks. + + Implementation: + For a given ac ratio p, we should essentially apply ac on every "1/p" + blocks. The first ac block can be as early as the 0th block, or as + late as the "1/p"th block, and we pick the middle one: (0.5p)th block. + Therefore, we are essentially to apply ac on: + (0.5/p)th block, (1.5/p)th block, (2.5/p)th block, etc., and of course, + with these values rounding to integers. + Since ac is applied recursively, we can simply use the following math + in the code to apply ac on corresponding blocks. + """ + block_idx = 0 + cut_off = 1 / 2 + # when passing p as a fraction number (e.g. 1/3), it will be interpreted + # as a string in argv, thus we need eval("1/3") here for fractions. 
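+    # Worked example: with p = 1/3, the check in selective_checkpointing below
+    # wraps the 2nd, 5th, 8th, ... blocks (1-indexed), i.e. the
+    # [no-ac, ac, no-ac, no-ac, ac, ...] pattern described in the docstring.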
+ p = eval(p) if isinstance(p, str) else p + + def selective_checkpointing(submodule): + nonlocal block_idx + nonlocal cut_off + + if isinstance(submodule, block): + block_idx += 1 + if block_idx * p >= cut_off: + cut_off += 1 + return True + return False + + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=non_reentrant_wrapper, + check_fn=selective_checkpointing, + ) diff --git a/videox_fun/utils/cfg_optimization.py b/videox_fun/utils/cfg_optimization.py new file mode 100755 index 0000000000000000000000000000000000000000..344a2ef87e0f8d6e3a7a57cfe9719ba1784f7608 --- /dev/null +++ b/videox_fun/utils/cfg_optimization.py @@ -0,0 +1,39 @@ +import numpy as np +import torch + + +def cfg_skip(): + def decorator(func): + def wrapper(self, x, *args, **kwargs): + bs = len(x) + if bs >= 2 and self.cfg_skip_ratio is not None and self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio): + bs_half = int(bs // 2) + + new_x = x[bs_half:] + + new_args = [] + for arg in args: + if isinstance(arg, (torch.Tensor, list, tuple, np.ndarray)): + new_args.append(arg[bs_half:]) + else: + new_args.append(arg) + + new_kwargs = {} + for key, content in kwargs.items(): + if isinstance(content, (torch.Tensor, list, tuple, np.ndarray)): + new_kwargs[key] = content[bs_half:] + else: + new_kwargs[key] = content + else: + new_x = x + new_args = args + new_kwargs = kwargs + + result = func(self, new_x, *new_args, **new_kwargs) + + if bs >= 2 and self.cfg_skip_ratio is not None and self.current_steps >= self.num_inference_steps * (1 - self.cfg_skip_ratio): + result = torch.cat([result, result], dim=0) + + return result + return wrapper + return decorator \ No newline at end of file diff --git a/videox_fun/utils/discrete_sampler.py b/videox_fun/utils/discrete_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..40b8316f1ed9dac69071dc8f10d012b5d97c5c90 --- /dev/null +++ b/videox_fun/utils/discrete_sampler.py @@ -0,0 +1,52 @@ +"""Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py +""" +import torch + +class DiscreteSampling: + def __init__(self, num_idx, uniform_sampling=False, start_num_idx=0, sp_size=1): + self.num_idx = num_idx + self.start_num_idx = start_num_idx + self.uniform_sampling = uniform_sampling + self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + + if self.is_distributed and self.uniform_sampling: + world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + + i = 1 + while True: + if world_size % i != 0 or num_idx % (world_size // i) != 0: + i += 1 + else: + if i >= sp_size: + self.group_num = world_size // i + elif sp_size > world_size: + self.group_num = 1 + else: + self.group_num = world_size // sp_size + break + assert self.group_num > 0 + assert world_size % self.group_num == 0 + # the number of rank in one group + self.group_width = world_size // self.group_num + self.sigma_interval = self.num_idx // self.group_num + print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % ( + self.rank, world_size, self.group_num, + self.group_width, self.sigma_interval)) + + def __call__(self, n_samples, generator=None, device=None): + if self.is_distributed and self.uniform_sampling: + group_index = self.rank // self.group_width + idx = torch.randint( + self.start_num_idx + group_index * self.sigma_interval, + self.start_num_idx + (group_index + 1) * 
self.sigma_interval, + (n_samples,), + generator=generator, device=device, + ) + print('proc[%d] idx=%s' % (self.rank, idx)) + else: + idx = torch.randint( + self.start_num_idx, self.start_num_idx + self.num_idx, (n_samples,), + generator=generator, device=device, + ) + return idx \ No newline at end of file diff --git a/videox_fun/utils/fm_solvers.py b/videox_fun/utils/fm_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..2d516ca97c98cd8b5e0703a531577ca553d22a22 --- /dev/null +++ b/videox_fun/utils/fm_solvers.py @@ -0,0 +1,857 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + pass + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = (shift * sigma / (1 + (shift - 1) * sigma)) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. 
It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. 
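+
+    Example (illustrative sketch only; the step count and shift value are
+    arbitrary, and `get_sampling_sigmas` / `retrieve_timesteps` are the helper
+    functions defined in this module):
+
+        scheduler = FlowDPMSolverMultistepScheduler(num_train_timesteps=1000)
+        sampling_sigmas = get_sampling_sigmas(sampling_steps=30, shift=5.0)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            scheduler, device="cuda", sigmas=sampling_sigmas)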
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", + deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" + ] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. 
+ """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. 
+ elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / + sigma_s) * sample - (alpha_t * + (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / + alpha_s) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ((alpha_t / alpha_s) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One 
step for the second-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * + (alpha_t * (torch.exp(-h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * + (sigma_t * (torch.exp(h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * + (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / + (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + 
(sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[ + -2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - + (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) 
+ return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
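+
+        Example (illustrative denoising loop; `model` stands in for whatever
+        network produces the flow prediction):
+
+            for t in scheduler.timesteps:
+                model_output = model(sample, t)
+                sample = scheduler.step(model_output, t, sample).prev_sample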
+ """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or + (self.config.lower_order_final and len(self.timesteps) < 15) or + self.config.final_sigmas_type == "zero") + lower_order_second = ((self.step_index == len(self.timesteps) - 2) and + self.config.lower_order_final and + len(self.timesteps) < 15) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" + ] and variance_noise is None: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to( + device=model_output.device, + dtype=torch.float32) # pyright: ignore + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.Tensor`): + The input sample. + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/videox_fun/utils/fm_solvers_unipc.py b/videox_fun/utils/fm_solvers_unipc.py new file mode 100644 index 0000000000000000000000000000000000000000..b8b347b7eca72e961adc2f1fdbbef33ceb1eb009 --- /dev/null +++ b/videox_fun/utils/fm_solvers_unipc.py @@ -0,0 +1,800 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import (KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. 
+ prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. 
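+
+    Example (illustrative sketch; the step count and shift are arbitrary):
+
+        scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=1000, solver_order=2)
+        scheduler.set_timesteps(num_inference_steps=50, device="cuda", shift=5.0)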
+ """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. 
+ """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." 
+ + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." 
+ ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = 
torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a 
simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. 
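+
+        Example (illustrative sketch; `model` and `sample` are placeholders for a flow
+        transformer call and the current latent, not part of this scheduler):
+
+        ```python
+        scheduler.set_timesteps(num_inference_steps=8, device="cuda")
+        for t in scheduler.timesteps:
+            model_output = model(sample, t)
+            sample = scheduler.step(model_output, t, sample).prev_sample
+        ```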
+ + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. 
+ """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps \ No newline at end of file diff --git a/videox_fun/utils/fp8_optimization.py b/videox_fun/utils/fp8_optimization.py new file mode 100755 index 0000000000000000000000000000000000000000..b8b00387b1423f5fdbeaa57dccadb4a38e14576e --- /dev/null +++ b/videox_fun/utils/fp8_optimization.py @@ -0,0 +1,64 @@ +"""Modified from https://github.com/kijai/ComfyUI-MochiWrapper +""" +import importlib.util + +import torch +import torch.nn as nn + +def replace_parameters_by_name(module, name_keywords, device): + from torch import nn + for name, param in list(module.named_parameters(recurse=False)): + if any(keyword in name for keyword in name_keywords): + if isinstance(param, nn.Parameter): + tensor = param.data + delattr(module, name) + setattr(module, name, tensor.to(device=device)) + for child_name, child_module in module.named_children(): + replace_parameters_by_name(child_module, name_keywords, device) + +def convert_model_weight_to_float8(model, exclude_module_name=['embed_tokens'], device=None): + for name, module in model.named_modules(): + flag = False + for _exclude_module_name in exclude_module_name: + if _exclude_module_name in name: + flag = True + if flag: + continue + for param_name, param in module.named_parameters(): + flag = False + for _exclude_module_name in exclude_module_name: + if _exclude_module_name in param_name: + flag = True + if flag: + continue + param.data = param.data.to(torch.float8_e4m3fn) + +def autocast_model_forward(cls, origin_dtype, *inputs, **kwargs): + weight_dtype = cls.weight.dtype + cls.to(origin_dtype) + + # Convert all inputs to the original dtype + inputs = [input.to(origin_dtype) for input in inputs] + out = cls.original_forward(*inputs, **kwargs) + + cls.to(weight_dtype) + return out + +def convert_weight_dtype_wrapper(module, origin_dtype): + for name, 
module in module.named_modules(): + if name == "" or "embed_tokens" in name: + continue + original_forward = module.forward + if hasattr(module, "weight") and module.weight is not None: + setattr(module, "original_forward", original_forward) + setattr( + module, + "forward", + lambda *inputs, m=module, **kwargs: autocast_model_forward(m, origin_dtype, *inputs, **kwargs) + ) + +def undo_convert_weight_dtype_wrapper(module): + for name, module in module.named_modules(): + if hasattr(module, "original_forward") and module.weight is not None: + setattr(module, "forward", module.original_forward) + delattr(module, "original_forward") \ No newline at end of file diff --git a/videox_fun/utils/lora_utils.py b/videox_fun/utils/lora_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..d303cb47e7fc540feafcbc52fca807c85dc2f5e3 --- /dev/null +++ b/videox_fun/utils/lora_utils.py @@ -0,0 +1,634 @@ +# LoRA network module +# reference: +# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py +# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py +# https://github.com/bmaltais/kohya_ss + +import hashlib +import math +import os +from collections import defaultdict +from io import BytesIO +from typing import List, Optional, Type, Union + +import safetensors.torch +import torch +import torch.utils.checkpoint +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from safetensors.torch import load_file +from transformers import T5EncoderModel + + +class LoRAModule(torch.nn.Module): + """ + replaces forward method of the original Linear, instead of replacing the original Linear module. + """ + + def __init__( + self, + lora_name, + org_module: torch.nn.Module, + multiplier=1.0, + lora_dim=4, + alpha=1, + dropout=None, + rank_dropout=None, + module_dropout=None, + ): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + self.lora_name = lora_name + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + alpha = self.lora_dim if alpha is None or alpha == 0 else alpha + self.scale = alpha / self.lora_dim + self.register_buffer("alpha", torch.tensor(alpha)) + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + self.org_module = org_module # remove in applying + self.dropout = dropout + self.rank_dropout = rank_dropout + self.module_dropout = module_dropout + + def apply_to(self): + self.org_forward = self.org_module.forward + self.org_module.forward = self.forward + del self.org_module + + def forward(self, x, *args, **kwargs): + weight_dtype = x.dtype + org_forwarded = self.org_forward(x) + + # module dropout + if self.module_dropout is not None 
and self.training: + if torch.rand(1) < self.module_dropout: + return org_forwarded + + lx = self.lora_down(x.to(self.lora_down.weight.dtype)) + + # normal dropout + if self.dropout is not None and self.training: + lx = torch.nn.functional.dropout(lx, p=self.dropout) + + # rank dropout + if self.rank_dropout is not None and self.training: + mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout + if len(lx.size()) == 3: + mask = mask.unsqueeze(1) # for Text Encoder + elif len(lx.size()) == 4: + mask = mask.unsqueeze(-1).unsqueeze(-1) # for Conv2d + lx = lx * mask + + # scaling for rank dropout: treat as if the rank is changed + scale = self.scale * (1.0 / (1.0 - self.rank_dropout)) # redundant for readability + else: + scale = self.scale + + lx = self.lora_up(lx) + + return org_forwarded.to(weight_dtype) + lx.to(weight_dtype) * self.multiplier * scale + +def addnet_hash_legacy(b): + """Old model hash used by sd-webui-additional-networks for .safetensors format files""" + m = hashlib.sha256() + + b.seek(0x100000) + m.update(b.read(0x10000)) + return m.hexdigest()[0:8] + +def addnet_hash_safetensors(b): + """New model hash used by sd-webui-additional-networks for .safetensors format files""" + hash_sha256 = hashlib.sha256() + blksize = 1024 * 1024 + + b.seek(0) + header = b.read(8) + n = int.from_bytes(header, "little") + + offset = n + 8 + b.seek(offset) + for chunk in iter(lambda: b.read(blksize), b""): + hash_sha256.update(chunk) + + return hash_sha256.hexdigest() + +def precalculate_safetensors_hashes(tensors, metadata): + """Precalculate the model hashes needed by sd-webui-additional-networks to + save time on indexing the model later.""" + + # Because writing user metadata to the file can change the result of + # sd_models.model_hash(), only retain the training metadata for purposes of + # calculating the hash, as they are meant to be immutable + metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")} + + bytes = safetensors.torch.save(tensors, metadata) + b = BytesIO(bytes) + + model_hash = addnet_hash_safetensors(b) + legacy_hash = addnet_hash_legacy(b) + return model_hash, legacy_hash + +class LoRANetwork(torch.nn.Module): + TRANSFORMER_TARGET_REPLACE_MODULE = [ + "CogVideoXTransformer3DModel", "WanTransformer3DModel", \ + "Wan2_2Transformer3DModel", "FluxTransformer2DModel", "QwenImageTransformer2DModel", \ + "Wan2_2Transformer3DModel_Animate", "Wan2_2Transformer3DModel_S2V", "FantasyTalkingTransformer3DModel", \ + "HunyuanVideoTransformer3DModel", "Flux2Transformer2DModel", "ZImageTransformer2DModel", + ] + TEXT_ENCODER_TARGET_REPLACE_MODULE = ["T5LayerSelfAttention", "T5LayerFF", "BertEncoder", "T5SelfAttention", "T5CrossAttention"] + LORA_PREFIX_TRANSFORMER = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + def __init__( + self, + text_encoder: Union[List[T5EncoderModel], T5EncoderModel], + unet, + multiplier: float = 1.0, + lora_dim: int = 4, + alpha: float = 1, + dropout: Optional[float] = None, + module_class: Type[object] = LoRAModule, + skip_name: str = None, + target_name: str = None, + varbose: Optional[bool] = False, + ) -> None: + super().__init__() + self.multiplier = multiplier + + self.lora_dim = lora_dim + self.alpha = alpha + self.dropout = dropout + + print(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}") + print(f"neuron dropout: p={self.dropout}") + + # create module instances + def create_modules( + is_unet: bool, + root_module: torch.nn.Module, + target_replace_modules: List[torch.nn.Module], + ) -> List[LoRAModule]: + prefix = ( + self.LORA_PREFIX_TRANSFORMER + if is_unet + else self.LORA_PREFIX_TEXT_ENCODER + ) + loras = [] + skipped = [] + for name, module in root_module.named_modules(): + if module.__class__.__name__ in target_replace_modules: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "LoRACompatibleLinear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" or child_module.__class__.__name__ == "LoRACompatibleConv" + is_conv2d_1x1 = is_conv2d and child_module.kernel_size == (1, 1) + + skip_names = skip_name.split(',') if skip_name is not None else [] + target_names = target_name.split(',') if target_name is not None else [] + + skip_names = [name.strip() for name in skip_names if name.strip()] + target_names = [name.strip() for name in target_names if name.strip()] + + if skip_names and any(skip_n in child_name for skip_n in skip_names): + continue + + if target_names and not any(target_n in child_name for target_n in target_names): + continue + + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + + dim = None + alpha = None + + if is_linear or is_conv2d_1x1: + dim = self.lora_dim + alpha = self.alpha + + if dim is None or dim == 0: + if is_linear or is_conv2d_1x1: + skipped.append(lora_name) + continue + + lora = module_class( + lora_name, + child_module, + self.multiplier, + dim, + alpha, + dropout=dropout, + ) + loras.append(lora) + return loras, skipped + + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] + + self.text_encoder_loras = [] + skipped_te = [] + for i, text_encoder in enumerate(text_encoders): + if text_encoder is not None: + text_encoder_loras, skipped = create_modules(False, text_encoder, LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) + self.text_encoder_loras.extend(text_encoder_loras) + skipped_te += skipped + print(f"create LoRA for Text Encoder: {len(self.text_encoder_loras)} modules.") + + self.unet_loras, skipped_un = create_modules(True, unet, LoRANetwork.TRANSFORMER_TARGET_REPLACE_MODULE) + print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + + # assertion + names = set() + for lora in self.text_encoder_loras + self.unet_loras: + assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" + names.add(lora.lora_name) + + def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): + if apply_text_encoder: + print("enable LoRA for text encoder") + else: + self.text_encoder_loras = [] + + if apply_unet: + print("enable LoRA for U-Net") + else: + self.unet_loras = [] + + for lora in self.text_encoder_loras + self.unet_loras: + lora.apply_to() + self.add_module(lora.lora_name, lora) + + def set_multiplier(self, multiplier): + self.multiplier = multiplier + for lora in self.text_encoder_loras + self.unet_loras: + lora.multiplier = self.multiplier + + def load_weights(self, file): + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import load_file + + weights_sd = load_file(file) + else: + weights_sd = torch.load(file, map_location="cpu") + info = self.load_state_dict(weights_sd, False) + return info + + def 
prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + self.requires_grad_(True) + all_params = [] + + def enumerate_params(loras): + params = [] + for lora in loras: + params.extend(lora.parameters()) + return params + + if self.text_encoder_loras: + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) + + if self.unet_loras: + param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) + + return all_params + + def enable_gradient_checkpointing(self): + pass + + def get_trainable_params(self): + return self.parameters() + + def save_weights(self, file, dtype, metadata): + if metadata is not None and len(metadata) == 0: + metadata = None + + state_dict = self.state_dict() + + if dtype is not None: + for key in list(state_dict.keys()): + v = state_dict[key] + v = v.detach().clone().to("cpu").to(dtype) + state_dict[key] = v + + if os.path.splitext(file)[1] == ".safetensors": + from safetensors.torch import save_file + + # Precalculate model hashes to save time on indexing + if metadata is None: + metadata = {} + model_hash, legacy_hash = precalculate_safetensors_hashes(state_dict, metadata) + metadata["sshs_model_hash"] = model_hash + metadata["sshs_legacy_hash"] = legacy_hash + + save_file(state_dict, file, metadata) + else: + torch.save(state_dict, file) + +def create_network( + multiplier: float, + network_dim: Optional[int], + network_alpha: Optional[float], + text_encoder: Union[T5EncoderModel, List[T5EncoderModel]], + transformer, + neuron_dropout: Optional[float] = None, + skip_name: str = None, + target_name: str = None, + **kwargs, +): + if network_dim is None: + network_dim = 4 # default + if network_alpha is None: + network_alpha = 1.0 + + network = LoRANetwork( + text_encoder, + transformer, + multiplier=multiplier, + lora_dim=network_dim, + alpha=network_alpha, + dropout=neuron_dropout, + skip_name=skip_name, + target_name=target_name, + varbose=True, + ) + return network + +def merge_lora(pipeline, lora_path, multiplier, device='cpu', dtype=torch.float32, state_dict=None, transformer_only=False, sub_transformer_name="transformer"): + if lora_path is None: + return pipeline + + LORA_PREFIX_TRANSFORMER = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + if state_dict is None: + state_dict = load_file(lora_path) + else: + state_dict = state_dict + updates = defaultdict(dict) + for key, value in state_dict.items(): + if "lora_A" in key or "lora_B" in key: + key = "lora_unet__" + key + key = key.replace(".", "_") + if key.endswith("_lora_up_weight"): + key = key[:-15] + ".lora_up.weight" + if key.endswith("_lora_down_weight"): + key = key[:-17] + ".lora_down.weight" + if key.endswith("_lora_A_default_weight"): + key = key[:-21] + ".lora_A.weight" + if key.endswith("_lora_B_default_weight"): + key = key[:-21] + ".lora_B.weight" + if key.endswith("_lora_A_weight"): + key = key[:-14] + ".lora_A.weight" + if key.endswith("_lora_B_weight"): + key = key[:-14] + ".lora_B.weight" + if key.endswith("_alpha"): + key = key[:-6] + ".alpha" + key = key.replace(".lora_A.default.", ".lora_down.") + key = key.replace(".lora_B.default.", ".lora_up.") + key = key.replace(".lora_A.", ".lora_down.") + key = key.replace(".lora_B.", ".lora_up.") + layer, elem = key.split('.', 1) + updates[layer][elem] = value + + sequential_cpu_offload_flag = False + if pipeline.transformer.device == 
torch.device(type="meta"): + pipeline.remove_all_hooks() + sequential_cpu_offload_flag = True + offload_device = pipeline._offload_device + + for layer, elems in updates.items(): + + if "lora_te" in layer: + if transformer_only: + continue + else: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_") + curr_layer = getattr(pipeline, sub_transformer_name) + + try: + curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:])) + except Exception: + temp_name = layer_infos.pop(0) + try: + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos)) + break + except Exception: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + print(f'Error loading layer in front search: {layer}. Try it in back search.') + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + except Exception: + if "lora_te" in layer: + if transformer_only: + continue + else: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_TRANSFORMER + "_")[-1].split("_") + curr_layer = getattr(pipeline, sub_transformer_name) + + len_layer_infos = len(layer_infos) + start_index = 0 if len_layer_infos >= 1 and len(layer_infos[0]) > 0 else 1 + end_indx = len_layer_infos + + error_flag = False if len_layer_infos >= 1 else True + while start_index < len_layer_infos: + try: + if start_index >= end_indx: + print(f'Error loading layer in back search: {layer}') + error_flag = True + break + curr_layer = curr_layer.__getattr__("_".join(layer_infos[start_index:end_indx])) + start_index = end_indx + end_indx = len_layer_infos + except Exception: + end_indx -= 1 + if error_flag: + continue + + origin_dtype = curr_layer.weight.data.dtype + origin_device = curr_layer.weight.data.device + + curr_layer = curr_layer.to(device, dtype) + weight_up = elems['lora_up.weight'].to(device, dtype) + weight_down = elems['lora_down.weight'].to(device, dtype) + + if 'alpha' in elems.keys(): + alpha = elems['alpha'].item() / weight_up.shape[1] + else: + alpha = 1.0 + + if len(weight_up.shape) == 4: + curr_layer.weight.data += multiplier * alpha * torch.mm( + weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) + else: + curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) + curr_layer = curr_layer.to(origin_device, origin_dtype) + + if sequential_cpu_offload_flag: + pipeline.enable_sequential_cpu_offload(device=offload_device) + return pipeline + +# TODO: Refactor with merge_lora. 
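+# Illustrative usage sketch (not part of this module): a LoRA is typically merged into
+# the pipeline weights before inference and unmerged afterwards so the base weights are
+# restored. The LoRA path and multiplier below are placeholders.
+#
+#   pipeline = merge_lora(pipeline, "loras/example_lora.safetensors", 0.8,
+#                         device="cuda", dtype=torch.bfloat16)
+#   image = pipeline(prompt="...").images[0]
+#   pipeline = unmerge_lora(pipeline, "loras/example_lora.safetensors", multiplier=0.8,
+#                           device="cuda", dtype=torch.bfloat16)
+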
+def unmerge_lora(pipeline, lora_path, multiplier=1, device="cpu", dtype=torch.float32, sub_transformer_name="transformer"): + if lora_path is None: + return pipeline + + """Unmerge state_dict in LoRANetwork from the pipeline in diffusers.""" + LORA_PREFIX_UNET = "lora_unet" + LORA_PREFIX_TEXT_ENCODER = "lora_te" + state_dict = load_file(lora_path) + + updates = defaultdict(dict) + for key, value in state_dict.items(): + if "lora_A" in key or "lora_B" in key: + key = "lora_unet__" + key + key = key.replace(".", "_") + if key.endswith("_lora_up_weight"): + key = key[:-15] + ".lora_up.weight" + if key.endswith("_lora_down_weight"): + key = key[:-17] + ".lora_down.weight" + if key.endswith("_lora_A_default_weight"): + key = key[:-21] + ".lora_A.weight" + if key.endswith("_lora_B_default_weight"): + key = key[:-21] + ".lora_B.weight" + if key.endswith("_lora_A_weight"): + key = key[:-14] + ".lora_A.weight" + if key.endswith("_lora_B_weight"): + key = key[:-14] + ".lora_B.weight" + if key.endswith("_alpha"): + key = key[:-6] + ".alpha" + key = key.replace(".lora_A.default.", ".lora_down.") + key = key.replace(".lora_B.default.", ".lora_up.") + key = key.replace(".lora_A.", ".lora_down.") + key = key.replace(".lora_B.", ".lora_up.") + layer, elem = key.split('.', 1) + updates[layer][elem] = value + + sequential_cpu_offload_flag = False + if pipeline.transformer.device == torch.device(type="meta"): + pipeline.remove_all_hooks() + sequential_cpu_offload_flag = True + + for layer, elems in updates.items(): + + if "lora_te" in layer: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = getattr(pipeline, sub_transformer_name) + + try: + curr_layer = curr_layer.__getattr__("_".join(layer_infos[1:])) + except Exception: + temp_name = layer_infos.pop(0) + try: + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name + "_" + "_".join(layer_infos)) + break + except Exception: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(layer_infos) == 0: + print(f'Error loading layer in front search: {layer}. 
Try it in back search.') + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + except Exception: + if "lora_te" in layer: + layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = getattr(pipeline, sub_transformer_name) + len_layer_infos = len(layer_infos) + + start_index = 0 if len_layer_infos >= 1 and len(layer_infos[0]) > 0 else 1 + end_indx = len_layer_infos + + error_flag = False if len_layer_infos >= 1 else True + while start_index < len_layer_infos: + try: + if start_index >= end_indx: + print(f'Error loading layer in back search: {layer}') + error_flag = True + break + curr_layer = curr_layer.__getattr__("_".join(layer_infos[start_index:end_indx])) + start_index = end_indx + end_indx = len_layer_infos + except Exception: + end_indx -= 1 + if error_flag: + continue + + origin_dtype = curr_layer.weight.data.dtype + origin_device = curr_layer.weight.data.device + + curr_layer = curr_layer.to(device, dtype) + weight_up = elems['lora_up.weight'].to(device, dtype) + weight_down = elems['lora_down.weight'].to(device, dtype) + + if 'alpha' in elems.keys(): + alpha = elems['alpha'].item() / weight_up.shape[1] + else: + alpha = 1.0 + + if len(weight_up.shape) == 4: + curr_layer.weight.data -= multiplier * alpha * torch.mm( + weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2) + ).unsqueeze(2).unsqueeze(3) + else: + curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down) + curr_layer = curr_layer.to(origin_device, origin_dtype) + + if sequential_cpu_offload_flag: + pipeline.enable_sequential_cpu_offload(device=device) + return pipeline diff --git a/videox_fun/utils/utils.py b/videox_fun/utils/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..730d5af33d7203a6a023d6855761fccc39134b25 --- /dev/null +++ b/videox_fun/utils/utils.py @@ -0,0 +1,447 @@ +import gc +import inspect +import os +import shutil +import subprocess +import time + +import cv2 +import imageio +import numpy as np +import torch +import torchvision +from einops import rearrange +from PIL import Image + + +def filter_kwargs(cls, kwargs): + sig = inspect.signature(cls.__init__) + valid_params = set(sig.parameters.keys()) - {'self', 'cls'} + filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + return filtered_kwargs + +def get_width_and_height_from_image_and_base_resolution(image, base_resolution): + target_pixels = int(base_resolution) * int(base_resolution) + original_width, original_height = Image.open(image).size + ratio = (target_pixels / (original_width * original_height)) ** 0.5 + width_slider = round(original_width * ratio) + height_slider = round(original_height * ratio) + return height_slider, width_slider + +def color_transfer(sc, dc): + """ + Transfer color distribution from of sc, referred to dc. + + Args: + sc (numpy.ndarray): input image to be transfered. + dc (numpy.ndarray): reference image + + Returns: + numpy.ndarray: Transferred color distribution on the sc. 
+ """ + + def get_mean_and_std(img): + x_mean, x_std = cv2.meanStdDev(img) + x_mean = np.hstack(np.around(x_mean, 2)) + x_std = np.hstack(np.around(x_std, 2)) + return x_mean, x_std + + sc = cv2.cvtColor(sc, cv2.COLOR_RGB2LAB) + s_mean, s_std = get_mean_and_std(sc) + dc = cv2.cvtColor(dc, cv2.COLOR_RGB2LAB) + t_mean, t_std = get_mean_and_std(dc) + img_n = ((sc - s_mean) * (t_std / s_std)) + t_mean + np.putmask(img_n, img_n > 255, 255) + np.putmask(img_n, img_n < 0, 0) + dst = cv2.cvtColor(cv2.convertScaleAbs(img_n), cv2.COLOR_LAB2RGB) + return dst + +def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=12, imageio_backend=True, color_transfer_post_process=False): + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).numpy().astype(np.uint8) + outputs.append(Image.fromarray(x)) + + if color_transfer_post_process: + for i in range(1, len(outputs)): + outputs[i] = Image.fromarray(color_transfer(np.uint8(outputs[i]), np.uint8(outputs[0]))) + + os.makedirs(os.path.dirname(path), exist_ok=True) + if imageio_backend: + if path.endswith("mp4"): + imageio.mimsave(path, outputs, fps=fps) + else: + imageio.mimsave(path, outputs, duration=(1000 * 1/fps)) + else: + if path.endswith("mp4"): + path = path.replace('.mp4', '.gif') + outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0) + +def merge_video_audio(video_path: str, audio_path: str): + """ + Merge the video and audio into a new video, with the duration set to the shorter of the two, + and overwrite the original video file. + + Parameters: + video_path (str): Path to the original video file + audio_path (str): Path to the audio file + """ + # check + if not os.path.exists(video_path): + raise FileNotFoundError(f"video file {video_path} does not exist") + if not os.path.exists(audio_path): + raise FileNotFoundError(f"audio file {audio_path} does not exist") + + base, ext = os.path.splitext(video_path) + temp_output = f"{base}_temp{ext}" + + try: + # create ffmpeg command + command = [ + 'ffmpeg', + '-y', # overwrite + '-i', + video_path, + '-i', + audio_path, + '-c:v', + 'copy', # copy video stream + '-c:a', + 'aac', # use AAC audio encoder + '-b:a', + '192k', # set audio bitrate (optional) + '-map', + '0:v:0', # select the first video stream + '-map', + '1:a:0', # select the first audio stream + '-shortest', # choose the shortest duration + temp_output + ] + + # execute the command + print("Start merging video and audio...") + result = subprocess.run( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + + # check result + if result.returncode != 0: + error_msg = f"FFmpeg execute failed: {result.stderr}" + print(error_msg) + raise RuntimeError(error_msg) + + shutil.move(temp_output, video_path) + print(f"Merge completed, saved to {video_path}") + + except Exception as e: + if os.path.exists(temp_output): + os.remove(temp_output) + print(f"merge_video_audio failed with error: {e}") + +def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size): + if validation_image_start is not None and validation_image_end is not None: + if type(validation_image_start) is str and os.path.isfile(validation_image_start): + image_start = clip_image = Image.open(validation_image_start).convert("RGB") + image_start = 
image_start.resize([sample_size[1], sample_size[0]]) + clip_image = clip_image.resize([sample_size[1], sample_size[0]]) + else: + image_start = clip_image = validation_image_start + image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] + clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] + + if type(validation_image_end) is str and os.path.isfile(validation_image_end): + image_end = Image.open(validation_image_end).convert("RGB") + image_end = image_end.resize([sample_size[1], sample_size[0]]) + else: + image_end = validation_image_end + image_end = [_image_end.resize([sample_size[1], sample_size[0]]) for _image_end in image_end] + + if type(image_start) is list: + clip_image = clip_image[0] + start_video = torch.cat( + [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], + dim=2 + ) + input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) + input_video[:, :, :len(image_start)] = start_video + + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, len(image_start):] = 255 + else: + input_video = torch.tile( + torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), + [1, 1, video_length, 1, 1] + ) + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, 1:] = 255 + + if type(image_end) is list: + image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end] + end_video = torch.cat( + [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end], + dim=2 + ) + input_video[:, :, -len(end_video):] = end_video + + input_video_mask[:, :, -len(image_end):] = 0 + else: + image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) + input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) + input_video_mask[:, :, -1:] = 0 + + input_video = input_video / 255 + + elif validation_image_start is not None: + if type(validation_image_start) is str and os.path.isfile(validation_image_start): + image_start = clip_image = Image.open(validation_image_start).convert("RGB") + image_start = image_start.resize([sample_size[1], sample_size[0]]) + clip_image = clip_image.resize([sample_size[1], sample_size[0]]) + else: + image_start = clip_image = validation_image_start + image_start = [_image_start.resize([sample_size[1], sample_size[0]]) for _image_start in image_start] + clip_image = [_clip_image.resize([sample_size[1], sample_size[0]]) for _clip_image in clip_image] + image_end = None + + if type(image_start) is list: + clip_image = clip_image[0] + start_video = torch.cat( + [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start], + dim=2 + ) + input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1]) + input_video[:, :, :len(image_start)] = start_video + input_video = input_video / 255 + + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, len(image_start):] = 255 + else: + input_video = torch.tile( + torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0), + [1, 1, video_length, 1, 1] + ) / 255 + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, 1:, ] = 255 + else: + image_start 
= None + image_end = None + input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]]) + input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255 + clip_image = None + + del image_start + del image_end + gc.collect() + + return input_video, input_video_mask, clip_image + +def get_video_to_video_latent(input_video_path, video_length, sample_size, fps=None, validation_video_mask=None, ref_image=None): + if input_video_path is not None: + if isinstance(input_video_path, str): + cap = cv2.VideoCapture(input_video_path) + input_video = [] + + original_fps = cap.get(cv2.CAP_PROP_FPS) + frame_skip = 1 if fps is None else max(1,int(original_fps // fps)) + + frame_count = 0 + + while True: + ret, frame = cap.read() + if not ret: + break + + if frame_count % frame_skip == 0: + frame = cv2.resize(frame, (sample_size[1], sample_size[0])) + input_video.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + + frame_count += 1 + + cap.release() + else: + input_video = input_video_path + + input_video = torch.from_numpy(np.array(input_video))[:video_length] + input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255 + + if validation_video_mask is not None: + validation_video_mask = Image.open(validation_video_mask).convert('L').resize((sample_size[1], sample_size[0])) + input_video_mask = np.where(np.array(validation_video_mask) < 240, 0, 255) + + input_video_mask = torch.from_numpy(np.array(input_video_mask)).unsqueeze(0).unsqueeze(-1).permute([3, 0, 1, 2]).unsqueeze(0) + input_video_mask = torch.tile(input_video_mask, [1, 1, input_video.size()[2], 1, 1]) + input_video_mask = input_video_mask.to(input_video.device, input_video.dtype) + else: + input_video_mask = torch.zeros_like(input_video[:, :1]) + input_video_mask[:, :, :] = 255 + else: + input_video, input_video_mask = None, None + + if ref_image is not None: + if isinstance(ref_image, str): + clip_image = Image.open(ref_image).convert("RGB") + else: + clip_image = Image.fromarray(np.array(ref_image, np.uint8)) + else: + clip_image = None + + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + return input_video, input_video_mask, ref_image, clip_image + +def get_image_latent(ref_image=None, sample_size=None, padding=False): + if ref_image is not None: + if isinstance(ref_image, str): + ref_image = Image.open(ref_image).convert("RGB") + if padding: + ref_image = padding_image(ref_image, sample_size[1], sample_size[0]) + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + elif isinstance(ref_image, Image.Image): + ref_image = ref_image.convert("RGB") + if padding: + ref_image = padding_image(ref_image, sample_size[1], sample_size[0]) + ref_image = ref_image.resize((sample_size[1], sample_size[0])) + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + else: + ref_image = torch.from_numpy(np.array(ref_image)) + ref_image = ref_image.unsqueeze(0).permute([3, 0, 1, 2]).unsqueeze(0) / 255 + + 
return ref_image
+
+def get_image(ref_image=None):
+    if ref_image is not None:
+        if isinstance(ref_image, str):
+            ref_image = Image.open(ref_image).convert("RGB")
+        elif isinstance(ref_image, Image.Image):
+            ref_image = ref_image.convert("RGB")
+
+    return ref_image
+
+def padding_image(images, new_width, new_height):
+    new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
+
+    aspect_ratio = images.width / images.height
+    if aspect_ratio > new_width / new_height:
+        new_img_width = new_width
+        new_img_height = int(new_img_width / aspect_ratio)
+    else:
+        new_img_height = new_height
+        new_img_width = int(new_img_height * aspect_ratio)
+
+    resized_img = images.resize((new_img_width, new_img_height))
+
+    paste_x = (new_width - new_img_width) // 2
+    paste_y = (new_height - new_img_height) // 2
+
+    new_image.paste(resized_img, (paste_x, paste_y))
+
+    return new_image
+
+def timer(func):
+    def wrapper(*args, **kwargs):
+        start_time = time.time()
+        result = func(*args, **kwargs)
+        end_time = time.time()
+        print(f"function {func.__name__} ran for {end_time - start_time} seconds")
+        return result
+    return wrapper
+
+def timer_record(model_name=""):
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            torch.cuda.synchronize()
+            start_time = time.time()
+            result = func(*args, **kwargs)
+            torch.cuda.synchronize()
+            end_time = time.time()
+            import torch.distributed as dist
+            if dist.is_initialized():
+                if dist.get_rank() == 0:
+                    time_sum = end_time - start_time
+                    print('# --------------------------------------------------------- #')
+                    print(f'# {model_name} time: {time_sum}s')
+                    print('# --------------------------------------------------------- #')
+                    _write_to_excel(model_name, time_sum)
+            else:
+                time_sum = end_time - start_time
+                print('# --------------------------------------------------------- #')
+                print(f'# {model_name} time: {time_sum}s')
+                print('# --------------------------------------------------------- #')
+                _write_to_excel(model_name, time_sum)
+            return result
+        return wrapper
+    return decorator
+
+def _write_to_excel(model_name, time_sum):
+    import os
+
+    import pandas as pd
+
+    row_env = os.environ.get(f"{model_name}_EXCEL_ROW", "1")  # default: row 1
+    col_env = os.environ.get(f"{model_name}_EXCEL_COL", "1")  # default: column 1
+    file_path = os.environ.get("EXCEL_FILE", "timing_records.xlsx")  # default output file
+
+    try:
+        df = pd.read_excel(file_path, sheet_name="Sheet1", header=None)
+    except FileNotFoundError:
+        df = pd.DataFrame()
+
+    row_idx = int(row_env)
+    col_idx = int(col_env)
+
+    if row_idx >= len(df):
+        df = pd.concat([df, pd.DataFrame([[None] * (len(df.columns) if not df.empty else 0)] * (row_idx - len(df) + 1))], ignore_index=True)
+
+    if col_idx >= len(df.columns):
+        df = pd.concat([df, pd.DataFrame(columns=range(len(df.columns), col_idx + 1))], axis=1)
+
+    df.iloc[row_idx, col_idx] = time_sum
+
+    df.to_excel(file_path, index=False, header=False, sheet_name="Sheet1")
+
+def get_autocast_dtype():
+    try:
+        if not torch.cuda.is_available():
+            print("CUDA not available, using float16 by default.")
+            return torch.float16
+
+        device = torch.cuda.current_device()
+        prop = torch.cuda.get_device_properties(device)
+
+        print(f"GPU: {prop.name}, Compute Capability: {prop.major}.{prop.minor}")
+
+        if prop.major >= 8:
+            if 
torch.cuda.is_bf16_supported(): + print("Using bfloat16.") + return torch.bfloat16 + else: + print("Compute capability >= 8.0 but bfloat16 not supported, falling back to float16.") + return torch.float16 + else: + print("GPU does not support bfloat16 natively, using float16.") + return torch.float16 + + except Exception as e: + print(f"Error detecting GPU capability: {e}, falling back to float16.") + return torch.float16 \ No newline at end of file