Alexander Bagus committed on
Commit 26893dc · 1 Parent(s): be751d2
.gitignore CHANGED
@@ -1,9 +1,9 @@
 
+/models/
 
 # Packages
 *.egg
 *.egg-info
-dist
 build
 eggs
 parts
videox_fun/dist/__init__.py ADDED
@@ -0,0 +1,72 @@
+import importlib.util
+
+from .cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
+from .flux2_xfuser import Flux2MultiGPUsAttnProcessor2_0
+from .flux_xfuser import FluxMultiGPUsAttnProcessor2_0
+from .fsdp import shard_model
+from .fuser import (get_sequence_parallel_rank,
+                    get_sequence_parallel_world_size, get_sp_group,
+                    get_world_group, init_distributed_environment,
+                    initialize_model_parallel, sequence_parallel_all_gather,
+                    sequence_parallel_chunk, set_multi_gpus_devices,
+                    xFuserLongContextAttention)
+from .hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
+from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0
+from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward
+from .z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor
+
+# The pai_fuser is an internally developed acceleration package, which can be used on PAI.
+if importlib.util.find_spec("paifuser") is not None:
+    # --------------------------------------------------------------- #
+    # The simple_wrapper is used to solve the problem
+    # about conflicts between cython and torch.compile
+    # --------------------------------------------------------------- #
+    def simple_wrapper(func):
+        def inner(*args, **kwargs):
+            return func(*args, **kwargs)
+        return inner
+
+    # --------------------------------------------------------------- #
+    # Sparse Attention Kernel
+    # --------------------------------------------------------------- #
+    from paifuser.models import parallel_magvit_vae
+    from paifuser.ops import wan_usp_sparse_attention_wrapper
+
+    from . import wan_xfuser
+
+    # --------------------------------------------------------------- #
+    # Sparse Attention
+    # --------------------------------------------------------------- #
+    usp_sparse_attn_wrap_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
+    wan_xfuser.usp_attn_forward = usp_sparse_attn_wrap_forward
+    usp_attn_forward = usp_sparse_attn_wrap_forward
+    print("Import PAI VAE Turbo and Sparse Attention")
+
+    # --------------------------------------------------------------- #
+    # Fast Rope Kernel
+    # --------------------------------------------------------------- #
+    import types
+
+    import torch
+    from paifuser.ops import (ENABLE_KERNEL, usp_fast_rope_apply_qk,
+                              usp_rope_apply_real_qk)
+
+    def deepcopy_function(f):
+        return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
+
+    local_rope_apply_qk = deepcopy_function(wan_xfuser.rope_apply_qk)
+
+    if ENABLE_KERNEL:
+        def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
+            if torch.is_grad_enabled():
+                return local_rope_apply_qk(q, k, grid_sizes, freqs)
+            else:
+                return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)
+
+    else:
+        def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
+            return usp_rope_apply_real_qk(q, k, grid_sizes, freqs)
+
+    wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
+    rope_apply_qk = adaptive_fast_usp_rope_apply_qk
+    print("Import PAI Fast rope")
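
A minimal, self-contained sketch (not part of the commit) of the dispatch pattern used above: keep a detached copy of the original function via types.FunctionType, then route to the fast kernel only when autograd is disabled. The slow_rope/fast_rope functions below are placeholders, not paifuser APIs.

    import types

    import torch


    def deepcopy_function(f):
        # Rebind the code object so later monkey-patching of the module
        # attribute does not affect this saved reference.
        return types.FunctionType(f.__code__, f.__globals__, name=f.__name__,
                                  argdefs=f.__defaults__, closure=f.__closure__)


    def slow_rope(q):   # stands in for the original rope_apply_qk
        return q * 2


    def fast_rope(q):   # stands in for the accelerated kernel
        return q * 2


    local_slow_rope = deepcopy_function(slow_rope)


    def adaptive_rope(q):
        # Training (grad enabled) keeps the original implementation;
        # inference takes the fast path.
        if torch.is_grad_enabled():
            return local_slow_rope(q)
        return fast_rope(q)


    with torch.no_grad():
        print(adaptive_rope(torch.ones(2)))  # fast path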
videox_fun/dist/cogvideox_xfuser.py ADDED
@@ -0,0 +1,93 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention import Attention
+from diffusers.models.embeddings import apply_rotary_emb
+
+from .fuser import (get_sequence_parallel_rank,
+                    get_sequence_parallel_world_size, get_sp_group,
+                    init_distributed_environment, initialize_model_parallel,
+                    xFuserLongContextAttention)
+
+class CogVideoXMultiGPUsAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE if needed
+        if image_rotary_emb is not None:
+            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
+            if not attn.is_cross_attention:
+                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
+
+        img_q = query[:, :, text_seq_length:].transpose(1, 2)
+        txt_q = query[:, :, :text_seq_length].transpose(1, 2)
+        img_k = key[:, :, text_seq_length:].transpose(1, 2)
+        txt_k = key[:, :, :text_seq_length].transpose(1, 2)
+        img_v = value[:, :, text_seq_length:].transpose(1, 2)
+        txt_v = value[:, :, :text_seq_length].transpose(1, 2)
+
+        hidden_states = xFuserLongContextAttention()(
+            None,
+            img_q, img_k, img_v, dropout_p=0.0, causal=False,
+            joint_tensor_query=txt_q,
+            joint_tensor_key=txt_k,
+            joint_tensor_value=txt_v,
+            joint_strategy='front',
+        )
+
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        encoder_hidden_states, hidden_states = hidden_states.split(
+            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+        )
+        return hidden_states, encoder_hidden_states
+
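
A hedged usage sketch (not part of the commit): attaching this processor to a diffusers CogVideoX transformer. It assumes the model exposes diffusers' set_attn_processor, that the checkpoint id below is available, and that the distributed environment from fuser.py has already been initialized.

    import torch
    from diffusers import CogVideoXTransformer3DModel

    from videox_fun.dist import CogVideoXMultiGPUsAttnProcessor2_0

    transformer = CogVideoXTransformer3DModel.from_pretrained(
        "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # Route every attention block through the sequence-parallel processor.
    transformer.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())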
videox_fun/dist/flux2_xfuser.py ADDED
@@ -0,0 +1,194 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention_processor import Attention
+
+from .fuser import xFuserLongContextAttention
+
+
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)
+
+    encoder_query = encoder_key = encoder_value = None
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+    return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+    tensors contain rotary embeddings and are returned as real tensors.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        if sequence_dim == 2:
+            cos = cos[None, None, :, :]
+            sin = sin[None, None, :, :]
+        elif sequence_dim == 1:
+            cos = cos[None, :, None, :]
+            sin = sin[None, :, None, :]
+        else:
+            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Used for flux, cogvideox, hunyuan-dit
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+        return out
+    else:
+        # used for lumina
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(2)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+        return x_out.type_as(x)
+
+
+class Flux2MultiGPUsAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the Flux2 model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("Flux2MultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "FluxAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        text_seq_len: int = None,
+    ) -> torch.FloatTensor:
+        # Determine which type of attention we're processing
+        is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None
+
+        if is_parallel_self_attn:
+            # Parallel in (QKV + MLP in) projection
+            hidden_states = attn.to_qkv_mlp_proj(hidden_states)
+            qkv, mlp_hidden_states = torch.split(
+                hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
+            )
+
+            # Handle the attention logic
+            query, key, value = qkv.chunk(3, dim=-1)
+
+        else:
+            query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+                attn, hidden_states, encoder_hidden_states
+            )
+
+        # Common processing for query, key, value
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        # Handle encoder projections (only for standard attention)
+        if not is_parallel_self_attn and attn.added_kv_proj_dim is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
+        # Apply rotary embeddings
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+        if not is_parallel_self_attn and attn.added_kv_proj_dim is not None and text_seq_len is None:
+            text_seq_len = encoder_query.shape[1]
+
+        txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
+        img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
+
+        half_dtypes = (torch.float16, torch.bfloat16)
+        def half(x):
+            return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
+
+        hidden_states = xFuserLongContextAttention()(
+            None,
+            half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
+            joint_tensor_query=half(txt_query) if txt_query is not None else None,
+            joint_tensor_key=half(txt_key) if txt_key is not None else None,
+            joint_tensor_value=half(txt_value) if txt_value is not None else None,
+            joint_strategy='front',
+        )
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if is_parallel_self_attn:
+            # Handle the feedforward (FF) logic
+            mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
+
+            # Concatenate and parallel output projection
+            hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
+            hidden_states = attn.to_out(hidden_states)
+
+            return hidden_states
+
+        else:
+            # Split encoder and latent hidden states if encoder was used
+            if encoder_hidden_states is not None:
+                encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                    [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+                )
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            # Project output
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+
+            if encoder_hidden_states is not None:
+                return hidden_states, encoder_hidden_states
+            else:
+                return hidden_states
videox_fun/dist/flux_xfuser.py ADDED
@@ -0,0 +1,165 @@
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention_processor import Attention
+
+from .fuser import xFuserLongContextAttention
+
+
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    query = attn.to_q(hidden_states)
+    key = attn.to_k(hidden_states)
+    value = attn.to_v(hidden_states)
+
+    encoder_query = encoder_key = encoder_value = None
+    if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+        encoder_query = attn.add_q_proj(encoder_hidden_states)
+        encoder_key = attn.add_k_proj(encoder_hidden_states)
+        encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+    return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+    return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+def apply_rotary_emb(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+    tensors contain rotary embeddings and are returned as real tensors.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        if sequence_dim == 2:
+            cos = cos[None, None, :, :]
+            sin = sin[None, None, :, :]
+        elif sequence_dim == 1:
+            cos = cos[None, :, None, :]
+            sin = sin[None, :, None, :]
+        else:
+            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Used for flux, cogvideox, hunyuan-dit
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+        return out
+    else:
+        # used for lumina
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(2)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+        return x_out.type_as(x)
+
+
+class FluxMultiGPUsAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the Flux model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("FluxMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: "FluxAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+        text_seq_len: int = None,
+    ) -> torch.FloatTensor:
+        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+            attn, hidden_states, encoder_hidden_states
+        )
+
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        if attn.added_kv_proj_dim is not None:
+            encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+            encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+            encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+            encoder_query = attn.norm_added_q(encoder_query)
+            encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([encoder_query, query], dim=1)
+            key = torch.cat([encoder_key, key], dim=1)
+            value = torch.cat([encoder_value, value], dim=1)
+
+        # Apply rotary embeddings
+        if image_rotary_emb is not None:
+            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+        if attn.added_kv_proj_dim is not None and text_seq_len is None:
+            text_seq_len = encoder_query.shape[1]
+
+        txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
+        img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
+
+        half_dtypes = (torch.float16, torch.bfloat16)
+        def half(x):
+            return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
+
+        hidden_states = xFuserLongContextAttention()(
+            None,
+            half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
+            joint_tensor_query=half(txt_query) if txt_query is not None else None,
+            joint_tensor_key=half(txt_key) if txt_key is not None else None,
+            joint_tensor_value=half(txt_value) if txt_value is not None else None,
+            joint_strategy='front',
+        )
+
+        # Reshape back
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(img_query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+                [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+            )
+            hidden_states = attn.to_out[0](hidden_states)
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
videox_fun/dist/fsdp.py ADDED
@@ -0,0 +1,44 @@
+# Copied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import gc
+from functools import partial
+
+import torch
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
+from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
+from torch.distributed.utils import _free_storage
+
+
+def shard_model(
+    model,
+    device_id,
+    param_dtype=torch.bfloat16,
+    reduce_dtype=torch.float32,
+    buffer_dtype=torch.float32,
+    process_group=None,
+    sharding_strategy=ShardingStrategy.FULL_SHARD,
+    sync_module_states=True,
+    module_to_wrapper=None,
+):
+    model = FSDP(
+        module=model,
+        process_group=process_group,
+        sharding_strategy=sharding_strategy,
+        auto_wrap_policy=partial(
+            lambda_auto_wrap_policy, lambda_fn=lambda m: m in (model.blocks if module_to_wrapper is None else module_to_wrapper)),
+        mixed_precision=MixedPrecision(
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype),
+        device_id=device_id,
+        sync_module_states=sync_module_states)
+    return model
+
+def free_model(model):
+    for m in model.modules():
+        if isinstance(m, FSDP):
+            _free_storage(m._handle.flat_param.data)
+    del model
+    gc.collect()
+    torch.cuda.empty_cache()
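
A hedged usage sketch for shard_model: it assumes the script is launched with torchrun (one process per GPU) and that the model exposes a .blocks ModuleList, which the default auto-wrap policy above relies on. build_transformer is a hypothetical constructor, not part of this commit.

    import torch
    import torch.distributed as dist

    from videox_fun.dist import shard_model

    dist.init_process_group("nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    model = build_transformer()  # hypothetical: any module with a `.blocks` ModuleList
    # Wrap the blocks with FSDP, keeping parameters in bf16 and reductions in fp32.
    model = shard_model(model, device_id=local_rank, param_dtype=torch.bfloat16)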
videox_fun/dist/fuser.py ADDED
@@ -0,0 +1,87 @@
+import importlib.util
+
+import torch
+import torch.distributed as dist
+
+try:
+    # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
+    if importlib.util.find_spec("paifuser") is not None:
+        import paifuser
+        from paifuser.xfuser.core.distributed import (
+            get_sequence_parallel_rank, get_sequence_parallel_world_size,
+            get_sp_group, get_world_group, init_distributed_environment,
+            initialize_model_parallel, model_parallel_is_initialized)
+        from paifuser.xfuser.core.long_ctx_attention import \
+            xFuserLongContextAttention
+        print("Import PAI DiT Turbo")
+    else:
+        import xfuser
+        from xfuser.core.distributed import (get_sequence_parallel_rank,
+                                             get_sequence_parallel_world_size,
+                                             get_sp_group, get_world_group,
+                                             init_distributed_environment,
+                                             initialize_model_parallel,
+                                             model_parallel_is_initialized)
+        from xfuser.core.long_ctx_attention import xFuserLongContextAttention
+        print("Xfuser import successful")
+except Exception as ex:
+    get_sequence_parallel_world_size = None
+    get_sequence_parallel_rank = None
+    xFuserLongContextAttention = None
+    get_sp_group = None
+    get_world_group = None
+    init_distributed_environment = None
+    initialize_model_parallel = None
+
+def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
+    if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1:
+        if get_sp_group is None:
+            raise RuntimeError("xfuser is not installed.")
+        dist.init_process_group("nccl")
+        print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=%d rank=%d world_size=%d' % (
+            ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(),
+            dist.get_world_size()))
+        assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \
+            "number of GPUs(%d) should be equal to ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size()
+        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
+        initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
+                                  classifier_free_guidance_degree=classifier_free_guidance_degree,
+                                  ring_degree=ring_degree,
+                                  ulysses_degree=ulysses_degree)
+        # device = torch.device("cuda:%d" % dist.get_rank())
+        device = torch.device(f"cuda:{get_world_group().local_rank}")
+        print('rank=%d device=%s' % (get_world_group().rank, str(device)))
+    else:
+        device = "cuda"
+    return device
+
+def sequence_parallel_chunk(x, dim=1):
+    if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
+        return x
+
+    sp_world_size = get_sequence_parallel_world_size()
+    if sp_world_size <= 1:
+        return x
+
+    sp_rank = get_sequence_parallel_rank()
+    sp_group = get_sp_group()
+
+    if x.size(dim) % sp_world_size != 0:
+        raise ValueError(f"Dim {dim} of x ({x.size(dim)}) not divisible by SP world size ({sp_world_size})")
+
+    chunks = torch.chunk(x, sp_world_size, dim=dim)
+    x = chunks[sp_rank]
+
+    return x
+
+def sequence_parallel_all_gather(x, dim=1):
+    if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
+        return x
+
+    sp_world_size = get_sequence_parallel_world_size()
+    if sp_world_size <= 1:
+        return x  # No gathering needed
+
+    sp_group = get_sp_group()
+    gathered_x = sp_group.all_gather(x, dim=dim)
+    return gathered_x
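
A hedged sketch of the sequence-parallel helpers, assuming xfuser is installed and the script runs under torchrun with two processes (e.g. torchrun --nproc_per_node=2 demo.py): each rank keeps its slice of the sequence, and all_gather reassembles the full tensor on every rank.

    import torch

    from videox_fun.dist import (sequence_parallel_all_gather,
                                 sequence_parallel_chunk, set_multi_gpus_devices)

    device = set_multi_gpus_devices(ulysses_degree=2, ring_degree=1)
    x = torch.randn(1, 8, 16, device=device)           # [batch, seq, dim]
    local = sequence_parallel_chunk(x, dim=1)          # this rank's slice of the sequence
    full = sequence_parallel_all_gather(local, dim=1)  # reassembled on every rank
    assert full.shape == x.shape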
videox_fun/dist/hunyuanvideo_xfuser.py ADDED
@@ -0,0 +1,166 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention import Attention
+from diffusers.models.embeddings import apply_rotary_emb
+
+from .fuser import (get_sequence_parallel_rank,
+                    get_sequence_parallel_world_size, get_sp_group,
+                    init_distributed_environment, initialize_model_parallel,
+                    xFuserLongContextAttention)
+
+def extract_seqlens_from_mask(attn_mask, text_seq_length):
+    if attn_mask is None:
+        return None
+
+    if len(attn_mask.shape) == 4:
+        bs, _, _, seq_len = attn_mask.shape
+
+        if attn_mask.dtype == torch.bool:
+            valid_mask = attn_mask.squeeze(1).squeeze(1)
+        else:
+            valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1))
+    else:
+        raise ValueError(
+            "attn_mask should be a 4D tensor, but got {}".format(
+                attn_mask.shape))
+
+    seqlens = valid_mask[:, -text_seq_length:].sum(dim=1)
+    return seqlens
+
+class HunyuanVideoMultiGPUsAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the HunyuanVideo model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if xFuserLongContextAttention is not None:
+            try:
+                self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
+            except Exception:
+                self.hybrid_seq_parallel_attn = None
+        else:
+            self.hybrid_seq_parallel_attn = None
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("HunyuanVideoMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if attn.add_q_proj is None and encoder_hidden_states is not None:
+            hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+        # 1. QKV projections
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 2. QK normalization
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # 3. Rotational positional embeddings applied to latent stream
+        if image_rotary_emb is not None:
+            if attn.add_q_proj is None and encoder_hidden_states is not None:
+                query = torch.cat(
+                    [
+                        apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        query[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+                key = torch.cat(
+                    [
+                        apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                        key[:, :, -encoder_hidden_states.shape[1] :],
+                    ],
+                    dim=2,
+                )
+            else:
+                query = apply_rotary_emb(query, image_rotary_emb)
+                key = apply_rotary_emb(key, image_rotary_emb)
+
+        # 4. Encoder condition QKV projection and normalization
+        if attn.add_q_proj is not None and encoder_hidden_states is not None:
+            encoder_query = attn.add_q_proj(encoder_hidden_states)
+            encoder_key = attn.add_k_proj(encoder_hidden_states)
+            encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_query = attn.norm_added_q(encoder_query)
+            if attn.norm_added_k is not None:
+                encoder_key = attn.norm_added_k(encoder_key)
+
+            query = torch.cat([query, encoder_query], dim=2)
+            key = torch.cat([key, encoder_key], dim=2)
+            value = torch.cat([value, encoder_value], dim=2)
+
+        # 5. Attention
+        if encoder_hidden_states is not None:
+            text_seq_length = encoder_hidden_states.size(1)
+
+            q_lens = k_lens = extract_seqlens_from_mask(attention_mask, text_seq_length)
+
+            img_q = query[:, :, :-text_seq_length].transpose(1, 2)
+            txt_q = query[:, :, -text_seq_length:].transpose(1, 2)
+            img_k = key[:, :, :-text_seq_length].transpose(1, 2)
+            txt_k = key[:, :, -text_seq_length:].transpose(1, 2)
+            img_v = value[:, :, :-text_seq_length].transpose(1, 2)
+            txt_v = value[:, :, -text_seq_length:].transpose(1, 2)
+
+            hidden_states = torch.zeros_like(query.transpose(1, 2))
+            local_q_length = img_q.size()[1]
+            for i in range(len(q_lens)):
+                hidden_states[i][:local_q_length + q_lens[i]] = self.hybrid_seq_parallel_attn(
+                    None,
+                    img_q[i].unsqueeze(0), img_k[i].unsqueeze(0), img_v[i].unsqueeze(0), dropout_p=0.0, causal=False,
+                    joint_tensor_query=txt_q[i][:q_lens[i]].unsqueeze(0),
+                    joint_tensor_key=txt_k[i][:q_lens[i]].unsqueeze(0),
+                    joint_tensor_value=txt_v[i][:q_lens[i]].unsqueeze(0),
+                    joint_strategy='rear',
+                )
+        else:
+            query = query.transpose(1, 2)
+            key = key.transpose(1, 2)
+            value = value.transpose(1, 2)
+            hidden_states = self.hybrid_seq_parallel_attn(
+                None,
+                query, key, value, dropout_p=0.0, causal=False
+            )
+
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # 6. Output projection
+        if encoder_hidden_states is not None:
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : -encoder_hidden_states.shape[1]],
+                hidden_states[:, -encoder_hidden_states.shape[1] :],
+            )
+
+            if getattr(attn, "to_out", None) is not None:
+                hidden_states = attn.to_out[0](hidden_states)
+                hidden_states = attn.to_out[1](hidden_states)
+
+            if getattr(attn, "to_add_out", None) is not None:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        return hidden_states, encoder_hidden_states
+
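
A minimal sketch of what extract_seqlens_from_mask computes: for a 4D boolean mask of shape [B, 1, 1, S], it returns, per sample, how many of the last text_seq_length positions are valid. The toy tensors below are illustrative only.

    import torch

    mask = torch.ones(2, 1, 1, 10, dtype=torch.bool)
    mask[1, ..., -3:] = False            # second sample has 3 padded text tokens
    text_seq_length = 4
    valid = mask.squeeze(1).squeeze(1)   # [B, S]
    seqlens = valid[:, -text_seq_length:].sum(dim=1)
    print(seqlens)                       # tensor([4, 1])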
videox_fun/dist/qwen_xfuser.py ADDED
@@ -0,0 +1,176 @@
+import functools
+import glob
+import json
+import math
+import os
+import types
+import warnings
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+import torch.nn.functional as F
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.models.attention import FeedForward
+from diffusers.models.attention_processor import Attention
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
+from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
+                             scale_lora_layers, unscale_lora_layers)
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from torch import nn
+from .fuser import (get_sequence_parallel_rank,
+                    get_sequence_parallel_world_size, get_sp_group,
+                    init_distributed_environment, initialize_model_parallel,
+                    xFuserLongContextAttention)
+
+def apply_rotary_emb_qwen(
+    x: torch.Tensor,
+    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+    to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+    reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+    tensors contain rotary embeddings and are returned as real tensors.
+
+    Args:
+        x (`torch.Tensor`):
+            Query or key tensor to apply rotary embeddings. [B, S, H, D] xk (torch.Tensor): Key tensor to apply
+        freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+    """
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        cos = cos[None, None]
+        sin = sin[None, None]
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Used for flux, cogvideox, hunyuan-dit
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
+
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+        return out
+    else:
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(1)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+
+        return x_out.type_as(x)
+
+
+class QwenImageMultiGPUsAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention for the QwenImage model. It applies a rotary embedding on
+    query and key vectors, but does not include spatial normalization.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("QwenImageMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,  # Image stream
+        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
+        encoder_hidden_states_mask: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        if encoder_hidden_states is None:
+            raise ValueError("QwenImageMultiGPUsAttnProcessor2_0 requires encoder_hidden_states (text stream)")
+
+        seq_txt = encoder_hidden_states.shape[1]
+
+        # Compute QKV for image stream (sample projections)
+        img_query = attn.to_q(hidden_states)
+        img_key = attn.to_k(hidden_states)
+        img_value = attn.to_v(hidden_states)
+
+        # Compute QKV for text stream (context projections)
+        txt_query = attn.add_q_proj(encoder_hidden_states)
+        txt_key = attn.add_k_proj(encoder_hidden_states)
+        txt_value = attn.add_v_proj(encoder_hidden_states)
+
+        # Reshape for multi-head attention
+        img_query = img_query.unflatten(-1, (attn.heads, -1))
+        img_key = img_key.unflatten(-1, (attn.heads, -1))
+        img_value = img_value.unflatten(-1, (attn.heads, -1))
+
+        txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+        txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+        txt_value = txt_value.unflatten(-1, (attn.heads, -1))
+
+        # Apply QK normalization
+        if attn.norm_q is not None:
+            img_query = attn.norm_q(img_query)
+        if attn.norm_k is not None:
+            img_key = attn.norm_k(img_key)
+        if attn.norm_added_q is not None:
+            txt_query = attn.norm_added_q(txt_query)
+        if attn.norm_added_k is not None:
+            txt_key = attn.norm_added_k(txt_key)
+
+        # Apply RoPE
+        if image_rotary_emb is not None:
+            img_freqs, txt_freqs = image_rotary_emb
+            img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
+            img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
+            txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
+            txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+
+        # Concatenate for joint attention
+        # Order: [text, image]
+        # joint_query = torch.cat([txt_query, img_query], dim=1)
+        # joint_key = torch.cat([txt_key, img_key], dim=1)
+        # joint_value = torch.cat([txt_value, img_value], dim=1)
+
+        half_dtypes = (torch.float16, torch.bfloat16)
+        def half(x):
+            return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
+
+        joint_hidden_states = xFuserLongContextAttention()(
+            None,
+            half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
+            joint_tensor_query=half(txt_query),
+            joint_tensor_key=half(txt_key),
+            joint_tensor_value=half(txt_value),
+            joint_strategy='front',
+        )
+
+        # Reshape back
+        joint_hidden_states = joint_hidden_states.flatten(2, 3)
+        joint_hidden_states = joint_hidden_states.to(img_query.dtype)
+
+        # Split attention outputs back
+        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
+        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
+
+        # Apply output projections
+        img_attn_output = attn.to_out[0](img_attn_output)
+        if len(attn.to_out) > 1:
+            img_attn_output = attn.to_out[1](img_attn_output)  # dropout
+
+        txt_attn_output = attn.to_add_out(txt_attn_output)
+
+        return img_attn_output, txt_attn_output
videox_fun/dist/wan_xfuser.py ADDED
@@ -0,0 +1,180 @@
+import torch
+import torch.cuda.amp as amp
+
+from .fuser import (get_sequence_parallel_rank,
+                    get_sequence_parallel_world_size, get_sp_group,
+                    init_distributed_environment, initialize_model_parallel,
+                    xFuserLongContextAttention)
+
+
+def pad_freqs(original_tensor, target_len):
+    seq_len, s1, s2 = original_tensor.shape
+    pad_size = target_len - seq_len
+    padding_tensor = torch.ones(
+        pad_size,
+        s1,
+        s2,
+        dtype=original_tensor.dtype,
+        device=original_tensor.device)
+    padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
+    return padded_tensor
+
+@amp.autocast(enabled=False)
+@torch.compiler.disable()
+def rope_apply(x, grid_sizes, freqs):
+    """
+    x: [B, L, N, C].
+    grid_sizes: [B, 3].
+    freqs: [M, C // 2].
+    """
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    # split freqs
+    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+
+    # loop over samples
+    output = []
+    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
+        seq_len = f * h * w
+
+        # precompute multipliers
+        x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
+            s, n, -1, 2))
+        freqs_i = torch.cat([
+            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
+        ],
+                            dim=-1).reshape(seq_len, 1, -1)
+
+        # apply rotary embedding
+        sp_size = get_sequence_parallel_world_size()
+        sp_rank = get_sequence_parallel_rank()
+        freqs_i = pad_freqs(freqs_i, s * sp_size)
+        s_per_rank = s
+        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
+                                                       s_per_rank), :, :]
+        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+        x_i = torch.cat([x_i, x[i, s:]])
+
+        # append to collection
+        output.append(x_i)
+    return torch.stack(output)
+
+def rope_apply_qk(q, k, grid_sizes, freqs):
+    q = rope_apply(q, grid_sizes, freqs)
+    k = rope_apply(k, grid_sizes, freqs)
+    return q, k
+
+def usp_attn_forward(self,
+                     x,
+                     seq_lens,
+                     grid_sizes,
+                     freqs,
+                     dtype=torch.bfloat16,
+                     t=0):
+    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+    half_dtypes = (torch.float16, torch.bfloat16)
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # query, key, value function
+    def qkv_fn(x):
+        q = self.norm_q(self.q(x)).view(b, s, n, d)
+        k = self.norm_k(self.k(x)).view(b, s, n, d)
+        v = self.v(x).view(b, s, n, d)
+        return q, k, v
+
+    q, k, v = qkv_fn(x)
+    q, k = rope_apply_qk(q, k, grid_sizes, freqs)
+
+    # TODO: We should use unpadded q, k, v for attention.
+    # k_lens = seq_lens // get_sequence_parallel_world_size()
+    # if k_lens is not None:
+    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
+    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
+    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+
+    x = xFuserLongContextAttention()(
+        None,
+        query=half(q),
+        key=half(k),
+        value=half(v),
+        window_size=self.window_size)
+
+    # TODO: padding after attention.
+    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
+
+    # output
+    x = x.flatten(2)
+    x = self.o(x)
+    return x
+
+@amp.autocast(enabled=False)
+@torch.compiler.disable()
+def s2v_rope_apply(x, grid_sizes, freqs):
+    s, n, c = x.size(1), x.size(2), x.size(3) // 2
+    # loop over samples
+    output = []
+    for i, _ in enumerate(x):
+        s = x.size(1)
+        # precompute multipliers
+        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
+            s, n, -1, 2))
+        freqs_i = freqs[i]
+        freqs_i_rank = pad_freqs(freqs_i, s)
+        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
+        x_i = torch.cat([x_i, x[i, s:]])
+        # append to collection
+        output.append(x_i)
+    return torch.stack(output).float()
+
+def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
+    q = s2v_rope_apply(q, grid_sizes, freqs)
+    k = s2v_rope_apply(k, grid_sizes, freqs)
+    return q, k
+
+def usp_attn_s2v_forward(self,
+                         x,
+                         seq_lens,
+                         grid_sizes,
+                         freqs,
+                         dtype=torch.bfloat16,
+                         t=0):
+    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+    half_dtypes = (torch.float16, torch.bfloat16)
+
+    def half(x):
+        return x if x.dtype in half_dtypes else x.to(dtype)
+
+    # query, key, value function
+    def qkv_fn(x):
+        q = self.norm_q(self.q(x)).view(b, s, n, d)
+        k = self.norm_k(self.k(x)).view(b, s, n, d)
+        v = self.v(x).view(b, s, n, d)
+        return q, k, v
+
+    q, k, v = qkv_fn(x)
+    q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
+
+    # TODO: We should use unpadded q, k, v for attention.
+    # k_lens = seq_lens // get_sequence_parallel_world_size()
+    # if k_lens is not None:
+    #     q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
+    #     k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
+    #     v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
+
+    x = xFuserLongContextAttention()(
+        None,
+        query=half(q),
+        key=half(k),
+        value=half(v),
+        window_size=self.window_size)
+
+    # TODO: padding after attention.
+    # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
+
+    # output
+    x = x.flatten(2)
+    x = self.o(x)
+    return x
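
usp_attn_forward and usp_attn_s2v_forward are written as unbound forward functions. A hedged sketch of how they might be wired in, assuming a Wan-style transformer whose blocks expose a self_attn module with the q/k/v/o projections, norm_q/norm_k, num_heads, head_dim and window_size attributes these functions read:

    import types

    from videox_fun.dist import usp_attn_forward

    def enable_sequence_parallel_attention(transformer):
        # Replace each self-attention forward with the sequence-parallel version.
        for block in transformer.blocks:
            block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
        return transformer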
videox_fun/dist/z_image_xfuser.py ADDED
@@ -0,0 +1,85 @@
+import torch
+import torch.cuda.amp as amp
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.attention import Attention
+
+from .fuser import (get_sequence_parallel_rank,
+                    get_sequence_parallel_world_size, get_sp_group,
+                    init_distributed_environment, initialize_model_parallel,
+                    xFuserLongContextAttention)
+
+class ZMultiGPUsSingleStreamAttnProcessor:
+    """
+    Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
+    original Z-ImageAttention module.
+    """
+
+    _attention_backend = None
+    _parallel_config = None
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        freqs_cis: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = query.unflatten(-1, (attn.heads, -1))
+        key = key.unflatten(-1, (attn.heads, -1))
+        value = value.unflatten(-1, (attn.heads, -1))
+
+        # Apply Norms
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # Apply RoPE
+        def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+            with torch.amp.autocast("cuda", enabled=False):
+                x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+                freqs_cis = freqs_cis.unsqueeze(2)
+                x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+            return x_out.type_as(x_in)  # todo
+
+        if freqs_cis is not None:
+            query = apply_rotary_emb(query, freqs_cis)
+            key = apply_rotary_emb(key, freqs_cis)
+
+        # Cast to correct dtype
+        dtype = query.dtype
+        query, key = query.to(dtype), key.to(dtype)
+
+        # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
+        if attention_mask is not None and attention_mask.ndim == 2:
+            attention_mask = attention_mask[:, None, None, :]
+
+        # Compute joint attention
+        hidden_states = xFuserLongContextAttention()(
+            query,
+            key,
+            value,
+        )
+
+        # Reshape back
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.to(dtype)
+
+        output = attn.to_out[0](hidden_states)
+        if len(attn.to_out) > 1:  # dropout
+            output = attn.to_out[1](output)
+
+        return output