alex committed · b2ed9cf
1 Parent(s): 49ba373

3 is good enough

Files changed: ovi/modules/attention.py (+29, -9)
ovi/modules/attention.py (CHANGED)
@@ -55,7 +55,7 @@ def flash_attention(
     assert q.device.type == 'cuda' and q.size(-1) <= 256
 
     # params
-    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+    b, lq, nheads, lk, out_dtype = q.size(0), q.size(1), q.size(2), k.size(1), q.dtype
 
     def half(x):
         return x if x.dtype in half_dtypes else x.to(dtype)
@@ -93,26 +93,46 @@ def flash_attention(
 
     # apply attention
     if FLASH_ATTN_3_AVAILABLE:
-
-        x = flash_attn_interface.flash_attn_varlen_func(
+        ret = flash_attn_interface.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v,
             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                 0, dtype=torch.int32).to(q.device, non_blocking=True),
             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
+                0, dtype=torch.int32).to(k.device, non_blocking=True),
             seqused_q=None,
             seqused_k=None,
             max_seqlen_q=lq,
             max_seqlen_k=lk,
             softmax_scale=softmax_scale,
             causal=causal,
-            deterministic=deterministic)[0].unflatten(0, (b, lq))
-
-
-
-
+            deterministic=deterministic
+        )
+
+        # Some FA3 wheels return (out, softmax_lse); some return just out.
+        out0 = ret[0] if isinstance(ret, (tuple, list)) else ret
+
+        # Normalize FA3 output layout to (total_q, nheads, headdim)
+        total_q = b * lq
+        if out0.dim() == 3:
+            if out0.shape[0] == total_q:
+                pass  # (total_q, nheads, headdim) -> good
+            elif out0.shape[0] == nheads and out0.shape[1] == total_q:
+                # heads-first -> transpose to (total_q, nheads, headdim)
+                out0 = out0.transpose(0, 1).contiguous()
+            else:
+                raise RuntimeError(
+                    f"Unexpected FA3 output shape {tuple(out0.shape)}; "
+                    f"expected (total_q, nheads, headdim) or (nheads, total_q, headdim)"
+                )
+        else:
+            raise RuntimeError(
+                f"Unexpected FA3 output rank {out0.dim()} with shape {tuple(out0.shape)}; "
+                f"expected a 3D tensor."
+            )
+
+        x = out0.unflatten(0, (b, lq))
 
     else:
         assert FLASH_ATTN_2_AVAILABLE
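For quick sanity-checking, here is a minimal sketch of the return-value normalization this commit adds, exercised with dummy tensors rather than a real flash_attn_varlen_func call. normalize_fa3_output and every shape below are illustrative stand-ins, not code from the repo.

# Minimal sketch (illustration only): reproduces the commit's FA3
# return-value normalization against dummy tensors, so the branch
# logic can be checked without a GPU or a flash-attn install.
import torch

def normalize_fa3_output(ret, b, lq, nheads):
    # Some FA3 wheels return (out, softmax_lse); some return just out.
    out0 = ret[0] if isinstance(ret, (tuple, list)) else ret

    total_q = b * lq
    if out0.dim() != 3:
        raise RuntimeError(f"Unexpected FA3 output rank {out0.dim()}")
    if out0.shape[0] == total_q:
        pass  # already (total_q, nheads, headdim)
    elif out0.shape[0] == nheads and out0.shape[1] == total_q:
        # heads-first -> transpose to (total_q, nheads, headdim)
        out0 = out0.transpose(0, 1).contiguous()
    else:
        raise RuntimeError(f"Unexpected FA3 output shape {tuple(out0.shape)}")
    return out0.unflatten(0, (b, lq))

# Feed it each return style the commit guards against.
b, lq, nheads, headdim = 2, 4, 3, 64
tokens_first = torch.randn(b * lq, nheads, headdim)      # (total_q, nheads, headdim)
heads_first = tokens_first.transpose(0, 1).contiguous()  # (nheads, total_q, headdim)
lse = torch.randn(nheads, b * lq)                        # placeholder softmax_lse

for ret in (tokens_first, (tokens_first, lse), heads_first, (heads_first, lse)):
    x = normalize_fa3_output(ret, b, lq, nheads)
    assert x.shape == (b, lq, nheads, headdim)

One caveat the shape heuristic carries: when total_q happens to equal nheads, a heads-first tensor also satisfies the tokens-first check, so the two layouts cannot be distinguished by shape alone; the sketch above deliberately picks total_q != nheads.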