Commit 26893dc by Alexander Bagus
Parent(s): be751d2
Files changed:
- .gitignore +1 -1
- videox_fun/dist/__init__.py +72 -0
- videox_fun/dist/cogvideox_xfuser.py +93 -0
- videox_fun/dist/flux2_xfuser.py +194 -0
- videox_fun/dist/flux_xfuser.py +165 -0
- videox_fun/dist/fsdp.py +44 -0
- videox_fun/dist/fuser.py +87 -0
- videox_fun/dist/hunyuanvideo_xfuser.py +166 -0
- videox_fun/dist/qwen_xfuser.py +176 -0
- videox_fun/dist/wan_xfuser.py +180 -0
- videox_fun/dist/z_image_xfuser.py +85 -0
.gitignore
CHANGED
@@ -1,9 +1,9 @@
 
+/models/
 
 # Packages
 *.egg
 *.egg-info
-dist
 build
 eggs
 parts
videox_fun/dist/__init__.py
ADDED
@@ -0,0 +1,72 @@
import importlib.util

from .cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
from .flux2_xfuser import Flux2MultiGPUsAttnProcessor2_0
from .flux_xfuser import FluxMultiGPUsAttnProcessor2_0
from .fsdp import shard_model
from .fuser import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    get_world_group, init_distributed_environment,
                    initialize_model_parallel, sequence_parallel_all_gather,
                    sequence_parallel_chunk, set_multi_gpus_devices,
                    xFuserLongContextAttention)
from .hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0
from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward
from .z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor

# The paifuser package is an internally developed acceleration package, which can be used on PAI.
if importlib.util.find_spec("paifuser") is not None:
    # --------------------------------------------------------------- #
    # simple_wrapper works around conflicts between
    # cython and torch.compile
    # --------------------------------------------------------------- #
    def simple_wrapper(func):
        def inner(*args, **kwargs):
            return func(*args, **kwargs)
        return inner

    # --------------------------------------------------------------- #
    # Sparse Attention Kernel
    # --------------------------------------------------------------- #
    from paifuser.models import parallel_magvit_vae
    from paifuser.ops import wan_usp_sparse_attention_wrapper

    from . import wan_xfuser

    # --------------------------------------------------------------- #
    # Sparse Attention
    # --------------------------------------------------------------- #
    usp_sparse_attn_wrap_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
    wan_xfuser.usp_attn_forward = usp_sparse_attn_wrap_forward
    usp_attn_forward = usp_sparse_attn_wrap_forward
    print("Import PAI VAE Turbo and Sparse Attention")

    # --------------------------------------------------------------- #
    # Fast Rope Kernel
    # --------------------------------------------------------------- #
    import types

    import torch
    from paifuser.ops import (ENABLE_KERNEL, usp_fast_rope_apply_qk,
                              usp_rope_apply_real_qk)

    def deepcopy_function(f):
        return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)

    local_rope_apply_qk = deepcopy_function(wan_xfuser.rope_apply_qk)

    if ENABLE_KERNEL:
        def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
            if torch.is_grad_enabled():
                return local_rope_apply_qk(q, k, grid_sizes, freqs)
            else:
                return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)

    else:
        def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
            return usp_rope_apply_real_qk(q, k, grid_sizes, freqs)

    wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
    rope_apply_qk = adaptive_fast_usp_rope_apply_qk
    print("Import PAI Fast rope")
videox_fun/dist/cogvideox_xfuser.py
ADDED
@@ -0,0 +1,93 @@
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from diffusers.models.embeddings import apply_rotary_emb

from .fuser import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    init_distributed_environment, initialize_model_parallel,
                    xFuserLongContextAttention)

class CogVideoXMultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
            if not attn.is_cross_attention:
                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)

        img_q = query[:, :, text_seq_length:].transpose(1, 2)
        txt_q = query[:, :, :text_seq_length].transpose(1, 2)
        img_k = key[:, :, text_seq_length:].transpose(1, 2)
        txt_k = key[:, :, :text_seq_length].transpose(1, 2)
        img_v = value[:, :, text_seq_length:].transpose(1, 2)
        txt_v = value[:, :, :text_seq_length].transpose(1, 2)

        hidden_states = xFuserLongContextAttention()(
            None,
            img_q, img_k, img_v, dropout_p=0.0, causal=False,
            joint_tensor_query=txt_q,
            joint_tensor_key=txt_k,
            joint_tensor_value=txt_v,
            joint_strategy='front',
        )

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        encoder_hidden_states, hidden_states = hidden_states.split(
            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
        )
        return hidden_states, encoder_hidden_states
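As a point of reference, a processor like this is normally installed through the standard diffusers entry point. The snippet below is a minimal sketch, not part of the commit; the model id and the single-processor form of set_attn_processor follow the usual diffusers conventions and are assumptions here.

# Hypothetical usage sketch (not in this commit): replace the attention
# processors of a diffusers CogVideoX transformer with the multi-GPU variant.
import torch
from diffusers import CogVideoXTransformer3DModel

from videox_fun.dist import (CogVideoXMultiGPUsAttnProcessor2_0,
                             set_multi_gpus_devices)

# Assumes a torchrun launch with 2 GPUs (ulysses_degree * ring_degree == world size).
device = set_multi_gpus_devices(ulysses_degree=2, ring_degree=1)

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
).to(device)
# set_attn_processor accepts a single processor instance and applies it to every attention layer.
transformer.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())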
videox_fun/dist/flux2_xfuser.py
ADDED
@@ -0,0 +1,194 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention

from .fuser import xFuserLongContextAttention


def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    return _get_projections(attn, hidden_states, encoder_hidden_states)


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, with shape [B, H, S, D].
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


class Flux2MultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the Flux2 model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Flux2MultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        text_seq_len: int = None,
    ) -> torch.FloatTensor:
        # Determine which type of attention we're processing
        is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None

        if is_parallel_self_attn:
            # Parallel in (QKV + MLP in) projection
            hidden_states = attn.to_qkv_mlp_proj(hidden_states)
            qkv, mlp_hidden_states = torch.split(
                hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
            )

            # Handle the attention logic
            query, key, value = qkv.chunk(3, dim=-1)

        else:
            query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
                attn, hidden_states, encoder_hidden_states
            )

        # Common processing for query, key, value
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # Handle encoder projections (only for standard attention)
        if not is_parallel_self_attn and attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        # Apply rotary embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        if not is_parallel_self_attn and attn.added_kv_proj_dim is not None and text_seq_len is None:
            text_seq_len = encoder_query.shape[1]

        txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
        img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]

        half_dtypes = (torch.float16, torch.bfloat16)
        def half(x):
            return x if x.dtype in half_dtypes else x.to(torch.bfloat16)

        hidden_states = xFuserLongContextAttention()(
            None,
            half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
            joint_tensor_query=half(txt_query) if txt_query is not None else None,
            joint_tensor_key=half(txt_key) if txt_key is not None else None,
            joint_tensor_value=half(txt_value) if txt_value is not None else None,
            joint_strategy='front',
        )
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        if is_parallel_self_attn:
            # Handle the feedforward (FF) logic
            mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)

            # Concatenate and parallel output projection
            hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
            hidden_states = attn.to_out(hidden_states)

            return hidden_states

        else:
            # Split encoder and latent hidden states if encoder was used
            if encoder_hidden_states is not None:
                encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                    [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
                )
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            # Project output
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)

            if encoder_hidden_states is not None:
                return hidden_states, encoder_hidden_states
            else:
                return hidden_states
videox_fun/dist/flux_xfuser.py
ADDED
@@ -0,0 +1,165 @@
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers.models.attention_processor import Attention

from .fuser import xFuserLongContextAttention


def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    query = attn.to_q(hidden_states)
    key = attn.to_k(hidden_states)
    value = attn.to_v(hidden_states)

    encoder_query = encoder_key = encoder_value = None
    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)

    return query, key, value, encoder_query, encoder_key, encoder_value


def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
    return _get_projections(attn, hidden_states, encoder_hidden_states)


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
    sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, with shape [B, H, S, D].
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        if sequence_dim == 2:
            cos = cos[None, None, :, :]
            sin = sin[None, None, :, :]
        elif sequence_dim == 1:
            cos = cos[None, :, None, :]
            sin = sin[None, :, None, :]
        else:
            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # used for lumina
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


class FluxMultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the Flux model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "FluxAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        text_seq_len: int = None,
    ) -> torch.FloatTensor:
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if attn.added_kv_proj_dim is not None:
            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))

            encoder_query = attn.norm_added_q(encoder_query)
            encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        # Apply rotary embeddings
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        if attn.added_kv_proj_dim is not None and text_seq_len is None:
            text_seq_len = encoder_query.shape[1]

        txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
        img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]

        half_dtypes = (torch.float16, torch.bfloat16)
        def half(x):
            return x if x.dtype in half_dtypes else x.to(torch.bfloat16)

        hidden_states = xFuserLongContextAttention()(
            None,
            half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
            joint_tensor_query=half(txt_query) if txt_query is not None else None,
            joint_tensor_key=half(txt_key) if txt_key is not None else None,
            joint_tensor_value=half(txt_value) if txt_value is not None else None,
            joint_strategy='front',
        )

        # Reshape back
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(img_query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
            )
            hidden_states = attn.to_out[0](hidden_states)
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states
videox_fun/dist/fsdp.py
ADDED
@@ -0,0 +1,44 @@
# Copied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import gc
from functools import partial

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
from torch.distributed.utils import _free_storage


def shard_model(
    model,
    device_id,
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
    process_group=None,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    sync_module_states=True,
    module_to_wrapper=None,
):
    model = FSDP(
        module=model,
        process_group=process_group,
        sharding_strategy=sharding_strategy,
        auto_wrap_policy=partial(
            lambda_auto_wrap_policy, lambda_fn=lambda m: m in (model.blocks if module_to_wrapper is None else module_to_wrapper)),
        mixed_precision=MixedPrecision(
            param_dtype=param_dtype,
            reduce_dtype=reduce_dtype,
            buffer_dtype=buffer_dtype),
        device_id=device_id,
        sync_module_states=sync_module_states)
    return model

def free_model(model):
    for m in model.modules():
        if isinstance(m, FSDP):
            _free_storage(m._handle.flat_param.data)
    del model
    gc.collect()
    torch.cuda.empty_cache()
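For orientation, here is a minimal sketch of how shard_model is typically driven under torchrun. It is illustrative only: build_transformer is a placeholder for however the model is constructed elsewhere, and it assumes the model exposes a .blocks list, which is what the default module_to_wrapper=None path wraps.

# Hypothetical usage sketch (not in this commit).
import torch
import torch.distributed as dist

from videox_fun.dist import shard_model

dist.init_process_group("nccl")                      # one process per GPU, e.g. via torchrun
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

transformer = build_transformer()                    # placeholder: any module with a .blocks list
# Each transformer block becomes its own FSDP unit; parameters are kept in bf16
# and gradients are reduced in fp32, matching the function's defaults.
transformer = shard_model(transformer, device_id=local_rank)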
videox_fun/dist/fuser.py
ADDED
@@ -0,0 +1,87 @@
import importlib.util

import torch
import torch.distributed as dist

try:
    # The paifuser package is an internally developed acceleration package, which can be used on PAI.
    if importlib.util.find_spec("paifuser") is not None:
        import paifuser
        from paifuser.xfuser.core.distributed import (
            get_sequence_parallel_rank, get_sequence_parallel_world_size,
            get_sp_group, get_world_group, init_distributed_environment,
            initialize_model_parallel, model_parallel_is_initialized)
        from paifuser.xfuser.core.long_ctx_attention import \
            xFuserLongContextAttention
        print("Import PAI DiT Turbo")
    else:
        import xfuser
        from xfuser.core.distributed import (get_sequence_parallel_rank,
                                             get_sequence_parallel_world_size,
                                             get_sp_group, get_world_group,
                                             init_distributed_environment,
                                             initialize_model_parallel,
                                             model_parallel_is_initialized)
        from xfuser.core.long_ctx_attention import xFuserLongContextAttention
        print("xfuser import successful")
except Exception as ex:
    get_sequence_parallel_world_size = None
    get_sequence_parallel_rank = None
    xFuserLongContextAttention = None
    get_sp_group = None
    get_world_group = None
    init_distributed_environment = None
    initialize_model_parallel = None

def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
    if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1:
        if get_sp_group is None:
            raise RuntimeError("xfuser is not installed.")
        dist.init_process_group("nccl")
        print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=%d rank=%d world_size=%d' % (
            ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(),
            dist.get_world_size()))
        assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \
            "number of GPUs (%d) should be equal to ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size()
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
                                  classifier_free_guidance_degree=classifier_free_guidance_degree,
                                  ring_degree=ring_degree,
                                  ulysses_degree=ulysses_degree)
        # device = torch.device("cuda:%d" % dist.get_rank())
        device = torch.device(f"cuda:{get_world_group().local_rank}")
        print('rank=%d device=%s' % (get_world_group().rank, str(device)))
    else:
        device = "cuda"
    return device

def sequence_parallel_chunk(x, dim=1):
    if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
        return x

    sp_world_size = get_sequence_parallel_world_size()
    if sp_world_size <= 1:
        return x

    sp_rank = get_sequence_parallel_rank()
    sp_group = get_sp_group()

    if x.size(1) % sp_world_size != 0:
        raise ValueError(f"Dim 1 of x ({x.size(1)}) is not divisible by the SP world size ({sp_world_size})")

    chunks = torch.chunk(x, sp_world_size, dim=1)
    x = chunks[sp_rank]

    return x

def sequence_parallel_all_gather(x, dim=1):
    if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
        return x

    sp_world_size = get_sequence_parallel_world_size()
    if sp_world_size <= 1:
        return x  # No gathering needed

    sp_group = get_sp_group()
    gathered_x = sp_group.all_gather(x, dim=dim)
    return gathered_x
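These helpers are the glue used by the attention processors: set_multi_gpus_devices initializes the xfuser sequence-parallel groups, sequence_parallel_chunk keeps only the local slice of a sequence, and sequence_parallel_all_gather reassembles it. Below is a minimal sketch, assuming a torchrun launch whose world size equals ulysses_degree * ring_degree and a sequence length divisible by that product; the tensor shape is illustrative only.

# Hypothetical usage sketch (not in this commit).
import torch

from videox_fun.dist import (sequence_parallel_all_gather,
                             sequence_parallel_chunk, set_multi_gpus_devices)

# e.g. torchrun --nproc_per_node=8 demo.py
device = set_multi_gpus_devices(ulysses_degree=4, ring_degree=2)

latents = torch.randn(1, 4096, 3072, device=device)          # illustrative shape
local_latents = sequence_parallel_chunk(latents, dim=1)      # each rank keeps 4096 / 8 tokens
# ... run the sequence-parallel transformer on local_latents ...
full_latents = sequence_parallel_all_gather(local_latents, dim=1)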
videox_fun/dist/hunyuanvideo_xfuser.py
ADDED
@@ -0,0 +1,166 @@
from typing import Optional

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention
from diffusers.models.embeddings import apply_rotary_emb

from .fuser import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    init_distributed_environment, initialize_model_parallel,
                    xFuserLongContextAttention)

def extract_seqlens_from_mask(attn_mask, text_seq_length):
    if attn_mask is None:
        return None

    if len(attn_mask.shape) == 4:
        bs, _, _, seq_len = attn_mask.shape

        if attn_mask.dtype == torch.bool:
            valid_mask = attn_mask.squeeze(1).squeeze(1)
        else:
            valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1))
    elif len(attn_mask.shape) == 3:
        raise ValueError(
            "attn_mask should be 2D or 4D tensor, but got {}".format(
                attn_mask.shape))

    seqlens = valid_mask[:, -text_seq_length:].sum(dim=1)
    return seqlens

class HunyuanVideoMultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the HunyuanVideo model. It applies a rotary embedding
    on query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if xFuserLongContextAttention is not None:
            try:
                self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
            except Exception:
                self.hybrid_seq_parallel_attn = None
        else:
            self.hybrid_seq_parallel_attn = None
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("HunyuanVideoMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if attn.add_q_proj is None and encoder_hidden_states is not None:
            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)

        # 1. QKV projections
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        # 2. QK normalization
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # 3. Rotational positional embeddings applied to latent stream
        if image_rotary_emb is not None:
            if attn.add_q_proj is None and encoder_hidden_states is not None:
                query = torch.cat(
                    [
                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        query[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
                key = torch.cat(
                    [
                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
                        key[:, :, -encoder_hidden_states.shape[1] :],
                    ],
                    dim=2,
                )
            else:
                query = apply_rotary_emb(query, image_rotary_emb)
                key = apply_rotary_emb(key, image_rotary_emb)

        # 4. Encoder condition QKV projection and normalization
        if attn.add_q_proj is not None and encoder_hidden_states is not None:
            encoder_query = attn.add_q_proj(encoder_hidden_states)
            encoder_key = attn.add_k_proj(encoder_hidden_states)
            encoder_value = attn.add_v_proj(encoder_hidden_states)

            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_query = attn.norm_added_q(encoder_query)
            if attn.norm_added_k is not None:
                encoder_key = attn.norm_added_k(encoder_key)

            query = torch.cat([query, encoder_query], dim=2)
            key = torch.cat([key, encoder_key], dim=2)
            value = torch.cat([value, encoder_value], dim=2)

        # 5. Attention
        if encoder_hidden_states is not None:
            text_seq_length = encoder_hidden_states.size(1)

            q_lens = k_lens = extract_seqlens_from_mask(attention_mask, text_seq_length)

            img_q = query[:, :, :-text_seq_length].transpose(1, 2)
            txt_q = query[:, :, -text_seq_length:].transpose(1, 2)
            img_k = key[:, :, :-text_seq_length].transpose(1, 2)
            txt_k = key[:, :, -text_seq_length:].transpose(1, 2)
            img_v = value[:, :, :-text_seq_length].transpose(1, 2)
            txt_v = value[:, :, -text_seq_length:].transpose(1, 2)

            hidden_states = torch.zeros_like(query.transpose(1, 2))
            local_q_length = img_q.size()[1]
            for i in range(len(q_lens)):
                hidden_states[i][:local_q_length + q_lens[i]] = self.hybrid_seq_parallel_attn(
                    None,
                    img_q[i].unsqueeze(0), img_k[i].unsqueeze(0), img_v[i].unsqueeze(0), dropout_p=0.0, causal=False,
                    joint_tensor_query=txt_q[i][:q_lens[i]].unsqueeze(0),
                    joint_tensor_key=txt_k[i][:q_lens[i]].unsqueeze(0),
                    joint_tensor_value=txt_v[i][:q_lens[i]].unsqueeze(0),
                    joint_strategy='rear',
                )
        else:
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            hidden_states = self.hybrid_seq_parallel_attn(
                None,
                query, key, value, dropout_p=0.0, causal=False
            )

        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        # 6. Output projection
        if encoder_hidden_states is not None:
            hidden_states, encoder_hidden_states = (
                hidden_states[:, : -encoder_hidden_states.shape[1]],
                hidden_states[:, -encoder_hidden_states.shape[1] :],
            )

            if getattr(attn, "to_out", None) is not None:
                hidden_states = attn.to_out[0](hidden_states)
                hidden_states = attn.to_out[1](hidden_states)

            if getattr(attn, "to_add_out", None) is not None:
                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
videox_fun/dist/qwen_xfuser.py
ADDED
@@ -0,0 +1,176 @@
import functools
import glob
import json
import math
import os
import types
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
                             scale_lora_layers, unscale_lora_layers)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn

from .fuser import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    init_distributed_environment, initialize_model_parallel,
                    xFuserLongContextAttention)

def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
    tensors contain rotary embeddings and are returned as real tensors.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings to, with shape [B, S, H, D].
        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(1)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


class QwenImageMultiGPUsAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention for the QwenImage model. It applies a rotary embedding on
    query and key vectors, but does not include spatial normalization.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("QwenImageMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        seq_txt = encoder_hidden_states.shape[1]

        # Compute QKV for image stream (sample projections)
        img_query = attn.to_q(hidden_states)
        img_key = attn.to_k(hidden_states)
        img_value = attn.to_v(hidden_states)

        # Compute QKV for text stream (context projections)
        txt_query = attn.add_q_proj(encoder_hidden_states)
        txt_key = attn.add_k_proj(encoder_hidden_states)
        txt_value = attn.add_v_proj(encoder_hidden_states)

        # Reshape for multi-head attention
        img_query = img_query.unflatten(-1, (attn.heads, -1))
        img_key = img_key.unflatten(-1, (attn.heads, -1))
        img_value = img_value.unflatten(-1, (attn.heads, -1))

        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
        txt_value = txt_value.unflatten(-1, (attn.heads, -1))

        # Apply QK normalization
        if attn.norm_q is not None:
            img_query = attn.norm_q(img_query)
        if attn.norm_k is not None:
            img_key = attn.norm_k(img_key)
        if attn.norm_added_q is not None:
            txt_query = attn.norm_added_q(txt_query)
        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        # Apply RoPE
        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)

        # Concatenate for joint attention
        # Order: [text, image]
        # joint_query = torch.cat([txt_query, img_query], dim=1)
        # joint_key = torch.cat([txt_key, img_key], dim=1)
        # joint_value = torch.cat([txt_value, img_value], dim=1)

        half_dtypes = (torch.float16, torch.bfloat16)
        def half(x):
            return x if x.dtype in half_dtypes else x.to(torch.bfloat16)

        joint_hidden_states = xFuserLongContextAttention()(
            None,
            half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
            joint_tensor_query=half(txt_query),
            joint_tensor_key=half(txt_key),
            joint_tensor_value=half(txt_value),
            joint_strategy='front',
        )

        # Reshape back
        joint_hidden_states = joint_hidden_states.flatten(2, 3)
        joint_hidden_states = joint_hidden_states.to(img_query.dtype)

        # Split attention outputs back
        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part

        # Apply output projections
        img_attn_output = attn.to_out[0](img_attn_output)
        if len(attn.to_out) > 1:
            img_attn_output = attn.to_out[1](img_attn_output)  # dropout

        txt_attn_output = attn.to_add_out(txt_attn_output)

        return img_attn_output, txt_attn_output
videox_fun/dist/wan_xfuser.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
import torch.cuda.amp as amp

from .fuser import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    init_distributed_environment, initialize_model_parallel,
                    xFuserLongContextAttention)


def pad_freqs(original_tensor, target_len):
    seq_len, s1, s2 = original_tensor.shape
    pad_size = target_len - seq_len
    padding_tensor = torch.ones(
        pad_size,
        s1,
        s2,
        dtype=original_tensor.dtype,
        device=original_tensor.device)
    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
    return padded_tensor


@amp.autocast(enabled=False)
@torch.compiler.disable()
def rope_apply(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
    freqs: [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        sp_size = get_sequence_parallel_world_size()
        sp_rank = get_sequence_parallel_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output)


def rope_apply_qk(q, k, grid_sizes, freqs):
    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)
    return q, k


def usp_attn_forward(self,
                     x,
                     seq_lens,
                     grid_sizes,
                     freqs,
                     dtype=torch.bfloat16,
                     t=0):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    q, k = rope_apply_qk(q, k, grid_sizes, freqs)

    # TODO: We should use unpadded q, k, v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)

    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)

    # output
    x = x.flatten(2)
    x = self.o(x)
    return x


@amp.autocast(enabled=False)
@torch.compiler.disable()
def s2v_rope_apply(x, grid_sizes, freqs):
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # loop over samples
    output = []
    for i, _ in enumerate(x):
        s = x.size(1)
        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = freqs[i]
        freqs_i_rank = pad_freqs(freqs_i, s)
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])
        # append to collection
        output.append(x_i)
    return torch.stack(output).float()


def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
    q = s2v_rope_apply(q, grid_sizes, freqs)
    k = s2v_rope_apply(k, grid_sizes, freqs)
    return q, k


def usp_attn_s2v_forward(self,
                         x,
                         seq_lens,
                         grid_sizes,
                         freqs,
                         dtype=torch.bfloat16,
                         t=0):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)
    q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)

    # TODO: We should use unpadded q, k, v for attention.
    # k_lens = seq_lens // get_sequence_parallel_world_size()
    # if k_lens is not None:
    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)

    x = xFuserLongContextAttention()(
        None,
        query=half(q),
        key=half(k),
        value=half(v),
        window_size=self.window_size)

    # TODO: padding after attention.
    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)

    # output
    x = x.flatten(2)
    x = self.o(x)
    return x
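usp_attn_forward and usp_attn_s2v_forward are written as drop-in forward methods for Wan self-attention blocks: each rank holds only its slice of the sequence, rope_apply pads the RoPE frequencies to the full length and slices out this rank's portion, and xFuserLongContextAttention exchanges keys and values across ranks. A rough patching sketch follows; the block and attribute names are assumptions about the Wan transformer layout, not part of this commit.

# Hypothetical patching sketch (not part of this commit): rebind the
# sequence-parallel forward onto each Wan self-attention module.
import types
from videox_fun.dist import usp_attn_forward

def patch_wan_self_attention(transformer):
    # Assumes a Wan-style transformer with a .blocks list whose items carry
    # a .self_attn module exposing q/k/v/o projections and norm_q/norm_k,
    # as used by the forward above.
    for block in transformer.blocks:
        block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)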
videox_fun/dist/z_image_xfuser.py
ADDED
@@ -0,0 +1,85 @@
from typing import Optional

import torch
import torch.cuda.amp as amp
import torch.nn.functional as F
from diffusers.models.attention import Attention

from .fuser import (get_sequence_parallel_rank,
                    get_sequence_parallel_world_size, get_sp_group,
                    init_distributed_environment, initialize_model_parallel,
                    xFuserLongContextAttention)


class ZMultiGPUsSingleStreamAttnProcessor:
    """
    Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
    original Z-Image Attention module.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))

        # Apply norms
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # Apply RoPE
        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
            with torch.amp.autocast("cuda", enabled=False):
                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
                freqs_cis = freqs_cis.unsqueeze(2)
                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
            return x_out.type_as(x_in)  # todo

        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)

        # Cast to the correct dtype
        dtype = query.dtype
        query, key = query.to(dtype), key.to(dtype)

        # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]

        # Compute joint attention
        hidden_states = xFuserLongContextAttention()(
            query,
            key,
            value,
        )

        # Reshape back
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.to(dtype)

        output = attn.to_out[0](hidden_states)
        if len(attn.to_out) > 1:  # dropout
            output = attn.to_out[1](output)

        return output
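ZMultiGPUsSingleStreamAttnProcessor expects the hidden states arriving at each rank to already be that rank's chunk of the sequence; xFuserLongContextAttention then performs the cross-rank exchange, and the caller gathers the full sequence afterwards. A rough end-to-end sketch follows; the helper call signatures (sequence_parallel_chunk, sequence_parallel_all_gather) and the transformer invocation are assumptions for illustration, not part of this commit.

# Hypothetical end-to-end sketch (not part of this commit).
from videox_fun.dist import (ZMultiGPUsSingleStreamAttnProcessor,
                             sequence_parallel_all_gather,
                             sequence_parallel_chunk)

def run_z_image_sequence_parallel(transformer, hidden_states, freqs_cis):
    # Install the multi-GPU processor on every attention module.
    for module in transformer.modules():
        if hasattr(module, "set_processor"):
            module.set_processor(ZMultiGPUsSingleStreamAttnProcessor())
    # Keep only this rank's slice of the sequence (assumed helper signature).
    local_states = sequence_parallel_chunk(hidden_states)
    local_out = transformer(local_states, freqs_cis=freqs_cis)
    # Gather the per-rank outputs back into the full sequence (assumed helper signature).
    return sequence_parallel_all_gather(local_out)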