
Commit c79990e

enable users to select triton fa for MLA backend
Signed-off-by: qli88 <qiang.li2@amd.com>
1 parent 900edbf commit c79990e
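
The new prefill branch below is gated on vLLM's VLLM_USE_TRITON_FLASH_ATTN environment variable, so selecting the Triton flash-attention kernel for the MLA backend comes down to setting that variable before the engine starts. A minimal sketch, assuming a ROCm build of vLLM with the Triton op available; the model name and sampling call are placeholders, not part of this change:

    import os

    # Must be set before the engine is constructed; vLLM reads it via vllm.envs.
    os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "1"

    from vllm import LLM

    # Placeholder checkpoint: any MLA-based (DeepSeek-style) model exercises
    # the _forward_prefill_flash path touched by this commit.
    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
    outputs = llm.generate("Hello, my name is")
    print(outputs[0].outputs[0].text)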

File tree

1 file changed (+30 -12)

vllm/attention/backends/mla/utils.py

Lines changed: 30 additions & 12 deletions
@@ -30,13 +30,13 @@
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+from vllm.attention.ops.triton_flash_attention import triton_attention
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
 except ImportError:
     from flash_attn import flash_attn_varlen_func
 
-
 @dataclass
 class MLACommonMetadata(AttentionMetadata):
     # Input positions for rotrary embeddings since for MLA the rotary
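
The hunk above keeps the existing fallback, preferring the flash_attn_varlen_func bundled with vllm.vllm_flash_attn and falling back to the upstream flash_attn package, and adds an unconditional import of the Triton kernel. A quick way to check which flash-attention implementation a given install actually resolved to, assuming at least one of the two packages is importable:

    # Mirrors the try/except in the hunk above.
    try:
        from vllm.vllm_flash_attn import flash_attn_varlen_func
    except ImportError:
        from flash_attn import flash_attn_varlen_func

    # The module path shows whether the bundled or the upstream kernel was picked up.
    print(flash_attn_varlen_func.__module__)
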
@@ -187,6 +187,7 @@ def __init__(
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
+        self.triton_flash_attn_func = triton_attention
         self.flash_attn_varlen_func = flash_attn_varlen_func
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = \
@@ -497,17 +498,34 @@ def _forward_prefill_flash(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        attn_output = self.flash_attn_varlen_func(
-            q=q,
-            k=k,
-            v=v_padded,
-            cu_seqlens_q=seq_start_loc,
-            cu_seqlens_k=seq_start_loc,
-            max_seqlen_q=max_prefill_seq_len,
-            max_seqlen_k=max_prefill_seq_len,
-            softmax_scale=self.scale,
-            causal=True,
-        )
+        if envs.VLLM_USE_TRITON_FLASH_ATTN:
+            attn_output, _ = self.triton_flash_attn_func(
+                q,
+                k,
+                v_padded,
+                None,
+                seq_start_loc,
+                seq_start_loc,
+                max_prefill_seq_len,
+                max_prefill_seq_len,
+                True,
+                self.scale,
+                None,  # attn_mask is None unless applying ALiBi mask
+                None,  # fp8 scales need additional work to integrate
+            )
+        else:
+            attn_output = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=v_padded,
+                cu_seqlens_q=seq_start_loc,
+                cu_seqlens_k=seq_start_loc,
+                max_seqlen_q=max_prefill_seq_len,
+                max_seqlen_k=max_prefill_seq_len,
+                softmax_scale=self.scale,
+                causal=True,
+            )
+
         attn_output = attn_output\
             .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
             .reshape(-1, self.num_heads * v.shape[-1])
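
Both branches consume v_padded rather than v: MLA's value head dimension is smaller than the query/key head dimension here, so the surrounding code pads v up to q.shape[-1] before the kernel call and slices the padding back off afterwards. A minimal sketch of that bookkeeping with dummy tensors, independent of either attention kernel; the shapes are illustrative only:

    import torch

    num_tokens, num_heads = 4, 2
    qk_head_dim, v_head_dim = 192, 128   # illustrative MLA-style head dims

    q = torch.randn(num_tokens, num_heads, qk_head_dim)
    v = torch.randn(num_tokens, num_heads, v_head_dim)

    # Pad v on its last dim so both kernels see matching head sizes.
    v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0)
    assert v_padded.shape[-1] == q.shape[-1]

    # Stand-in for the kernel output, which comes back with the padded head dim.
    attn_output = torch.randn(num_tokens, num_heads, q.shape[-1])

    # The same slice-and-flatten the code applies after either kernel returns.
    out = attn_output \
        .view(-1, num_heads, q.shape[-1])[..., :v.shape[-1]] \
        .reshape(-1, num_heads * v.shape[-1])
    assert out.shape == (num_tokens, num_heads * v_head_dim)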
