Skip to content

Commit 93cf5a0

Browse files
authored
Pass s_aux through flash_attn_with_kvcache (#79)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent 6dbc6e0 commit 93cf5a0

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

vllm_flash_attn/flash_attn_interface.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ def flash_attn_with_kvcache(
315315
v_descale=None,
316316
# Version selector
317317
fa_version: int = DEFAULT_FA_VERSION,
318+
s_aux=None,
318319
):
319320
"""
320321
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -422,6 +423,8 @@ def flash_attn_with_kvcache(
422423
"FA2 does not support scheduler_metadata, q_descale, "
423424
"k_descale, v_descale"
424425
)
426+
if s_aux is not None:
427+
raise NotImplementedError("FA2 does not support s_aux")
425428
out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
426429
q, k_cache, v_cache,
427430
k, v, # k_new, v_new
@@ -466,6 +469,7 @@ def flash_attn_with_kvcache(
466469
num_splits, # num_splits
467470
None, # pack_gqa
468471
0, # sm_margin
472+
s_aux, # s_aux
469473
)
470474
else:
471475
raise ValueError(f"Unsupported FA version: {fa_version}")

0 commit comments

Comments
 (0)