alex committed on
Commit b2ed9cf · 1 Parent(s): 49ba373

3 is good enough

Files changed (1)
  1. ovi/modules/attention.py +29 -9
ovi/modules/attention.py CHANGED
@@ -55,7 +55,7 @@ def flash_attention(
     assert q.device.type == 'cuda' and q.size(-1) <= 256
 
     # params
-    b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
+    b, lq, nheads, lk, out_dtype = q.size(0), q.size(1), q.size(2), k.size(1), q.dtype
 
     def half(x):
         return x if x.dtype in half_dtypes else x.to(dtype)
@@ -93,26 +93,46 @@ def flash_attention(
 
     # apply attention
     if FLASH_ATTN_3_AVAILABLE:
-        # Note: dropout_p, window_size are not supported in FA3 now.
-        x = flash_attn_interface.flash_attn_varlen_func(
+        ret = flash_attn_interface.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v,
             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
                 0, dtype=torch.int32).to(q.device, non_blocking=True),
             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
-                0, dtype=torch.int32).to(q.device, non_blocking=True),
+                0, dtype=torch.int32).to(k.device, non_blocking=True),
             seqused_q=None,
             seqused_k=None,
             max_seqlen_q=lq,
             max_seqlen_k=lk,
             softmax_scale=softmax_scale,
             causal=causal,
-            deterministic=deterministic)
-
-        if isinstance(x, tuple):
-            x = x[0]
-        x = x.unflatten(0, (b, lq))
+            deterministic=deterministic
+        )
+
+        # Some FA3 wheels return (out, softmax_lse); some return just out.
+        out0 = ret[0] if isinstance(ret, (tuple, list)) else ret
+
+        # Normalize FA3 output layout to (total_q, nheads, headdim)
+        total_q = b * lq
+        if out0.dim() == 3:
+            if out0.shape[0] == total_q:
+                pass  # (total_q, nheads, headdim) -> good
+            elif out0.shape[0] == nheads and out0.shape[1] == total_q:
+                # heads-first -> transpose to (total_q, nheads, headdim)
+                out0 = out0.transpose(0, 1).contiguous()
+            else:
+                raise RuntimeError(
+                    f"Unexpected FA3 output shape {tuple(out0.shape)}; "
+                    f"expected (total_q, nheads, headdim) or (nheads, total_q, headdim)"
+                )
+        else:
+            raise RuntimeError(
+                f"Unexpected FA3 output rank {out0.dim()} with shape {tuple(out0.shape)}; "
+                f"expected a 3D tensor."
+            )
+
+        x = out0.unflatten(0, (b, lq))
 
     else:
         assert FLASH_ATTN_2_AVAILABLE
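
For context on what the new FA3 branch does, here is a small, self-contained sketch (not part of the commit) of the cu_seqlens construction used in the call above and of the output-layout normalization the commit adds. The sizes b=2, lq=4, nheads=6, headdim=64 are hypothetical, chosen only to make the shapes visible, and the helper normalize_fa3_out is an illustration of the same checks, not a function from the repo.

import torch

# Hypothetical sizes for illustration only; not taken from the repo.
b, lq, nheads, headdim = 2, 4, 6, 64
q_lens = torch.tensor([lq, lq], dtype=torch.int32)

# cu_seqlens as built in the call above: prepend a zero, then cumulative sum.
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
print(cu_seqlens_q.tolist())  # [0, 4, 8]

total_q = b * lq

def normalize_fa3_out(ret):
    # Mirrors the new logic: unwrap an (out, softmax_lse) tuple if present,
    # then accept either (total_q, nheads, headdim) or (nheads, total_q, headdim).
    out0 = ret[0] if isinstance(ret, (tuple, list)) else ret
    if out0.dim() != 3:
        raise RuntimeError(f"Unexpected FA3 output rank {out0.dim()}")
    if out0.shape[0] == total_q:
        return out0
    if out0.shape[0] == nheads and out0.shape[1] == total_q:
        return out0.transpose(0, 1).contiguous()
    raise RuntimeError(f"Unexpected FA3 output shape {tuple(out0.shape)}")

# Tokens-first output passes through unchanged.
tokens_first = torch.randn(total_q, nheads, headdim)
# Heads-first output (paired with a dummy softmax_lse) gets transposed.
heads_first = (torch.randn(nheads, total_q, headdim), torch.randn(nheads, total_q))

for ret in (tokens_first, heads_first):
    x = normalize_fa3_out(ret).unflatten(0, (b, lq))
    print(tuple(x.shape))  # (2, 4, 6, 64) in both cases

Either layout ends up as (b, lq, nheads, headdim) after the final unflatten, which is why the guard raises on anything unexpected instead of guessing a layout.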