diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py index cfeda8520e..ba21c49d4d 100644 --- a/vllm_flash_attn/flash_attn_interface.py +++ b/vllm_flash_attn/flash_attn_interface.py @@ -142,6 +142,7 @@ def flash_attn_varlen_func( q_descale=None, k_descale=None, v_descale=None, + num_splits: int = 0, # Version selector fa_version: int = DEFAULT_FA_VERSION, ): @@ -224,6 +225,8 @@ def flash_attn_varlen_func( "FA2 does not support scheduler_metadata, q_descale, " "k_descale, v_descale" ) + if num_splits > 1: + raise NotImplementedError("FA2 does not support num_splits > 1") out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd( q, k, v, out, @@ -270,7 +273,7 @@ def flash_attn_varlen_func( softcap, True, # rotary_interleaved scheduler_metadata, - 0, # num_splits + num_splits, None, # pack_gqa 0, # sm_margin )