
Commit c01c208

LucasWilkinson and MatthewBonanni authored and committed
[Attention] add DCP support for FLASH_ATTN_MLA backend (vllm-project#24453)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 8d3b842 commit c01c208

File tree

2 files changed: 19 additions & 2 deletions

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 16 additions & 2 deletions
@@ -11,6 +11,7 @@
 from vllm.attention.utils.fa_utils import (flash_attn_supports_mla,
                                            get_flash_attn_version)
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonDecodeMetadata,
@@ -98,6 +99,11 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         # pre-allocated during capture.
         self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
+        # TODO(lucas): Until we add support for the DCP custom masking we need
+        # to restrict decodes to q_len == 1 when DCP is enabled.
+        self.__class__.reorder_batch_threshold = 1 \
+            if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
+
     def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                          max_seq_len, causal):
         if self.fa_aot_schedule:
@@ -172,6 +178,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
 
 
 class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
+    can_return_lse_for_decode: bool = True
 
     def __init__(
         self,
@@ -239,7 +246,7 @@ def _forward_decode(
         # to prevent invalid grid configuration during graph capture.
         max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
 
-        o = flash_attn_varlen_func(
+        attn_out = flash_attn_varlen_func(
             q=q_pe,
             k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
             v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
@@ -251,9 +258,16 @@
             block_table=attn_metadata.decode.block_table,
             softmax_scale=self.scale,
             causal=True,
+            return_softmax_lse=self.need_to_return_lse_for_decode,
             fa_version=3,  # only version 3 is supported
             scheduler_metadata=attn_metadata.decode.scheduler_metadata,
             num_splits=attn_metadata.decode.max_num_splits,
         )
 
-        return self._v_up_proj(o)
+        if self.need_to_return_lse_for_decode:
+            o, lse = attn_out
+            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
+            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
+        else:
+            o = attn_out
+            return o, None
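
Why the backend now returns the log-sum-exp: under decode context parallelism (DCP) each rank holds only a shard of the KV cache, so its attention output is a partial softmax. The per-rank LSE is what lets those partials be combined exactly, which is also why it is transposed to [B, H] before being handed back. The snippet below is a minimal illustrative sketch of that merge, not vLLM's actual DCP reduction code; the helper name and tensor layout are assumptions for illustration.

# Minimal sketch (not vLLM code): merging per-rank partial attention results
# under decode context parallelism. Each rank attends over its shard of the
# KV cache and produces a partial output o_i of shape [B, H, D] plus the
# log-sum-exp of its scores lse_i of shape [B, H]; the LSE makes the partial
# softmaxes exactly combinable.
import torch

def merge_dcp_partials(outputs: list[torch.Tensor],
                       lses: list[torch.Tensor]) -> torch.Tensor:
    o = torch.stack(outputs)                       # [R, B, H, D]
    lse = torch.stack(lses)                        # [R, B, H]
    global_lse = torch.logsumexp(lse, dim=0)       # [B, H]
    weights = torch.exp(lse - global_lse)          # [R, B, H] rescale factors
    return (weights.unsqueeze(-1) * o).sum(dim=0)  # [B, H, D]

In the diff above, this is reflected by passing return_softmax_lse=self.need_to_return_lse_for_decode to flash_attn_varlen_func and returning (o, lse.transpose(0, 1)) instead of immediately applying the V up-projection.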

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,9 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
             return
 
         if self.reorder_batch_threshold is not None:
+            # NOTE(lucas): currently no backend supports the custom masking
+            # required for DCP with q_len > 1, so we assert here. Remove this
+            # assert once custom mask support is added to FA3.
             if self.dcp_world_size > 1:
                 assert self.reorder_batch_threshold == 1, \
                     "DCP not support reorder_batch_threshold > 1 now."
