Skip to content

Commit 95d6186

Browse files
committed
[Misc] Add num_splits input arg to flash_attn_varlen_func
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 5f36441 commit 95d6186

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

vllm_flash_attn/flash_attn_interface.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def flash_attn_varlen_func(
142142
q_descale=None,
143143
k_descale=None,
144144
v_descale=None,
145+
num_splits: int = 0,
145146
# Version selector
146147
fa_version: int = DEFAULT_FA_VERSION,
147148
):
@@ -224,6 +225,8 @@ def flash_attn_varlen_func(
224225
"FA2 does not support scheduler_metadata, q_descale, "
225226
"k_descale, v_descale"
226227
)
228+
if num_splits > 1:
229+
raise NotImplementedError("FA2 does not support num_splits > 1")
227230
out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
228231
q, k, v,
229232
out,
@@ -270,7 +273,7 @@ def flash_attn_varlen_func(
270273
softcap,
271274
True, # rotary_interleaved
272275
scheduler_metadata,
273-
0, # num_splits
276+
num_splits,
274277
None, # pack_gqa
275278
0, # sm_margin
276279
)

0 commit comments

Comments
 (0)