From 503843e90262945f839a061f030cc95c1b1402b0 Mon Sep 17 00:00:00 2001
From: Mickael Seznec
Date: Wed, 3 Sep 2025 16:06:41 +0200
Subject: [PATCH 1/2] fix: reenable fa3 path for v0

Signed-off-by: Mickael Seznec
---
 vllm_flash_attn/flash_attn_interface.py | 84 ++++++++++++++++---------
 1 file changed, 56 insertions(+), 28 deletions(-)

diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index 06de7fd17b..be557fabcd 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -416,34 +416,62 @@ def flash_attn_with_kvcache(
     cache_batch_idx = maybe_contiguous(cache_batch_idx)
     block_table = maybe_contiguous(block_table)
 
-    if s_aux is not None:
-        raise NotImplementedError("FA2 does not support s_aux")
-    if scheduler_metadata is not None and q_descale is not None \
-        and k_descale is not None and v_descale is not None:
-        raise NotImplementedError(
-            "FA2 does not support scheduler_metadata, q_descale, "
-            "k_descale, v_descale"
-        )
-
-    out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
-        q, k_cache, v_cache,
-        k, v,  # k_new, v_new
-        cache_seqlens,
-        rotary_cos,
-        rotary_sin,
-        cache_batch_idx,
-        cache_leftpad,
-        block_table,
-        alibi_slopes,
-        out,
-        softmax_scale,
-        causal,
-        window_size[0],
-        window_size[1],
-        softcap,
-        rotary_interleaved,
-        num_splits,
-    )
+    if fa_version == 2:
+        if s_aux is not None:
+            raise NotImplementedError("FA2 does not support s_aux")
+        if scheduler_metadata is not None and q_descale is not None \
+            and k_descale is not None and v_descale is not None:
+            raise NotImplementedError(
+                "FA2 does not support scheduler_metadata, q_descale, "
+                "k_descale, v_descale"
+            )
+
+        out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
+            q, k_cache, v_cache,
+            k, v,  # k_new, v_new
+            cache_seqlens,
+            rotary_cos,
+            rotary_sin,
+            cache_batch_idx,
+            cache_leftpad,
+            block_table,
+            alibi_slopes,
+            out,
+            softmax_scale,
+            causal,
+            window_size[0],
+            window_size[1],
+            softcap,
+            rotary_interleaved,
+            num_splits,
+        )
+    elif fa_version == 3:
+        assert alibi_slopes is None, "Alibi is not supported in FA3"
+        out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
+            q, k_cache, v_cache,  # q, k, v
+            k, v,  # k_new, v_new
+            None,  # q_v
+            out,
+            None, None,  # cu_seqlens_q, cu_seqlens_k
+            None,  # cu_seqlens_k_new
+            None, cache_seqlens,  # seqused_q, seqused_k
+            None, None,  # max_seqlen_q, max_seqlen_k
+            block_table,
+            cache_batch_idx,  # kv_batch_idx
+            None,  # leftpad_k
+            None, None, None,  # rotary_cos, rotary_sin, seqlens_rotary
+            q_descale, k_descale, v_descale,
+            softmax_scale,
+            causal,
+            window_size[0], window_size[1],
+            softcap,
+            rotary_interleaved,  # rotary_interleaved
+            scheduler_metadata,
+            num_splits,  # num_splits
+            None,  # pack_gqa
+            0,  # sm_margin
+            s_aux,  # s_aux
+        )
 
     return (out, softmax_lse) if return_softmax_lse else out
 

From b93c1eec3cf010612e6ae48d664e1835001ffeb8 Mon Sep 17 00:00:00 2001
From: Mickael Seznec
Date: Wed, 3 Sep 2025 16:08:22 +0200
Subject: [PATCH 2/2] check fa3

Signed-off-by: Mickael Seznec
---
 vllm_flash_attn/flash_attn_interface.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index be557fabcd..cf66c32306 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -445,7 +445,8 @@ def flash_attn_with_kvcache(
             rotary_interleaved,
             num_splits,
         )
-    elif fa_version == 3:
+    else:
+        assert fa_version == 3
         assert alibi_slopes is None, "Alibi is not supported in FA3"
         out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
             q, k_cache, v_cache,  # q, k, v
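
For reference, a minimal calling sketch for the re-enabled FA3 path (not part of the patches above). It assumes `flash_attn_with_kvcache` exposes `fa_version` as a keyword argument (as the dispatch in the diff suggests), an FA3-capable build of vllm_flash_attn, and illustrative tensor shapes:

    import torch
    from vllm_flash_attn.flash_attn_interface import flash_attn_with_kvcache

    # Illustrative decode step: batch=2, 1 new query token, 8 heads, head_dim=128.
    q = torch.randn(2, 1, 8, 128, dtype=torch.bfloat16, device="cuda")
    # Paged KV cache: 32 blocks of 16 tokens each.
    k_cache = torch.randn(32, 16, 8, 128, dtype=torch.bfloat16, device="cuda")
    v_cache = torch.randn_like(k_cache)
    # Number of cached tokens per sequence, plus a per-sequence block table.
    cache_seqlens = torch.tensor([5, 9], dtype=torch.int32, device="cuda")
    block_table = torch.arange(32, dtype=torch.int32, device="cuda").reshape(2, 16)

    # fa_version=3 dispatches to torch.ops._vllm_fa3_C.fwd; alibi_slopes must be
    # None on this path (asserted in PATCH 2/2).
    out = flash_attn_with_kvcache(
        q, k_cache, v_cache,
        cache_seqlens=cache_seqlens,
        block_table=block_table,
        causal=True,
        fa_version=3,
    )

With PATCH 2/2 applied, any `fa_version` other than 2 must be 3: the `elif` becomes an `else` with an `assert`, so an unsupported version now fails loudly instead of silently returning an uninitialized `out`.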