Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 57 additions & 28 deletions vllm_flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,34 +416,63 @@ def flash_attn_with_kvcache(
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)

if s_aux is not None:
raise NotImplementedError("FA2 does not support s_aux")
if scheduler_metadata is not None and q_descale is not None \
and k_descale is not None and v_descale is not None:
raise NotImplementedError(
"FA2 does not support scheduler_metadata, q_descale, "
"k_descale, v_descale"
)

out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
q, k_cache, v_cache,
k, v, # k_new, v_new
cache_seqlens,
rotary_cos,
rotary_sin,
cache_batch_idx,
cache_leftpad,
block_table,
alibi_slopes,
out,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
num_splits,
)
if fa_version == 2:
if s_aux is not None:
raise NotImplementedError("FA2 does not support s_aux")
if scheduler_metadata is not None and q_descale is not None \
and k_descale is not None and v_descale is not None:
raise NotImplementedError(
"FA2 does not support scheduler_metadata, q_descale, "
"k_descale, v_descale"
)

out, softmax_lse = torch.ops._vllm_fa2_C.fwd_kvcache(
q, k_cache, v_cache,
k, v, # k_new, v_new
cache_seqlens,
rotary_cos,
rotary_sin,
cache_batch_idx,
cache_leftpad,
block_table,
alibi_slopes,
out,
softmax_scale,
causal,
window_size[0],
window_size[1],
softcap,
rotary_interleaved,
num_splits,
)
else:
assert fa_version == 3
assert alibi_slopes is None, "Alibi is not supported in FA3"
out, softmax_lse, _, _ = torch.ops._vllm_fa3_C.fwd(
q, k_cache, v_cache, # q, k, v
k, v, # k_new, v_new
None, # q_v
out,
None, None, # cu_seqlens_q, cu_seqlens_k
None, # cu_seqlens_k_new
None, cache_seqlens, # seqused_q, seqused_k
None, None, # max_seqlen_q, max_seqlen_k
block_table,
cache_batch_idx, # kv_batch_idx
None, # leftpad_k
None, None, None, # rotary_cos, rotary_sin, seqlens_rotary
q_descale, k_descale, v_descale,
softmax_scale,
causal,
window_size[0], window_size[1],
softcap,
rotary_interleaved, # rotary_interleaved
scheduler_metadata,
num_splits, # num_splits
None, # pack_gqa
0, # sm_margin
s_aux, # s_aux
)
return (out, softmax_lse) if return_softmax_lse else out


Expand Down