Commit c1167e7

minosfuture authored and 0xrushi committed
[Attention][DCP] Support DCP with query length > 1 (MTP) with FA3 (vllm-project#25049)
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
1 parent d794213 commit c1167e7

8 files changed: +45 −13 lines changed


cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 4695e6bed5366c41e28c06cd86170166e4f43d00
+        GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

vllm/v1/attention/backends/mla/common.py

Lines changed: 14 additions & 2 deletions
@@ -370,6 +370,7 @@ class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
class MLACommonDecodeMetadata:
    block_table: torch.Tensor
    seq_lens: torch.Tensor
+   dcp_tot_seq_lens: Optional[torch.Tensor]


D = TypeVar("D", bound=MLACommonDecodeMetadata)
@@ -682,10 +683,12 @@ def _build_decode(
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
+       dcp_tot_seq_lens_device: Optional[torch.Tensor],
    ) -> MLACommonDecodeMetadata:
        return MLACommonDecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
+           dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )

    def build_for_cudagraph_capture(
@@ -727,6 +730,7 @@ def build(
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        seq_lens = common_attn_metadata.seq_lens
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
+       dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

@@ -742,7 +746,10 @@

        # Note(hc): update seq_lens of decode reqs under DCP.
        if self.dcp_world_size > 1:
-           seq_lens[:num_decodes] = seq_lens[:num_decodes] // self.dcp_world_size + (
+           assert dcp_local_seq_lens is not None
+           dcp_local_seq_lens[:num_decodes] = seq_lens[
+               :num_decodes
+           ] // self.dcp_world_size + (
                self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
            )

@@ -899,10 +906,15 @@ def build(
        decode_metadata = self._build_decode(
            block_table_tensor=block_table_tensor[:num_decodes, ...],
            seq_lens_cpu=seq_lens_cpu[:num_decodes],
-           seq_lens_device=seq_lens[:num_decodes],
+           seq_lens_device=dcp_local_seq_lens[:num_decodes]
+           if self.dcp_world_size > 1 and dcp_local_seq_lens is not None
+           else seq_lens[:num_decodes],
            query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
            query_start_loc_device=query_start_loc[: num_decodes + 1],
            num_decode_tokens=num_decode_tokens,
+           dcp_tot_seq_lens_device=seq_lens[:num_decodes]
+           if self.dcp_world_size > 1
+           else None,
        )

        attn_metadata = self.metadata_cls(
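
For reference, the length split in the build() hunk above distributes each decode request's KV tokens across DCP ranks: every rank gets seq_len // dcp_world_size tokens, and ranks whose index is at most (seq_len - 1) % dcp_world_size get one extra. A minimal standalone sketch of that arithmetic (the helper name and example values are illustrative, not part of this change):

import torch

def dcp_local_lengths(tot_seq_lens: torch.Tensor, dcp_world_size: int) -> torch.Tensor:
    # Rows are DCP ranks, columns are decode requests; mirrors the formula
    # used to fill dcp_local_seq_lens in MLACommonMetadataBuilder.build().
    ranks = torch.arange(dcp_world_size).unsqueeze(1)
    return tot_seq_lens // dcp_world_size + (ranks <= (tot_seq_lens - 1) % dcp_world_size).long()

tot = torch.tensor([10, 7, 1])                # total KV tokens of three decode requests
local = dcp_local_lengths(tot, dcp_world_size=4)
print(local.tolist())                         # [[3, 2, 1], [3, 2, 0], [2, 2, 0], [2, 1, 0]]
assert torch.equal(local.sum(dim=0), tot)     # the per-rank shards always sum back to the total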

vllm/v1/attention/backends/mla/flashattn_mla.py

Lines changed: 7 additions & 9 deletions
@@ -17,7 +17,6 @@
    get_flash_attn_version,
)
from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_dcp_group
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
@@ -107,12 +106,6 @@ def __init__(
        # pre-allocated during capture.
        self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

-       # TODO(lucas): Until we add support for the DCP custom masking we need
-       # to restrict decodes to q_len == 1 when DCP is enabled.
-       self.reorder_batch_threshold = (
-           1 if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
-       )
-
    def _schedule_decode(
        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
    ):
@@ -121,7 +114,7 @@
            batch_size=num_reqs,
            max_seqlen_q=max_query_len,
            max_seqlen_k=max_seq_len,
-           num_heads_q=self.num_heads,
+           num_heads_q=self.num_heads * self.dcp_world_size,
            num_heads_kv=1,
            headdim=self.mla_dims.qk_rope_head_dim,
            cache_seqlens=seqlens,
@@ -142,10 +135,11 @@ def _build_decode(
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
+       dcp_tot_seq_lens_device: Optional[torch.Tensor],
    ) -> FlashAttnMLADecodeMetadata:
        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        max_query_len = query_lens_cpu.max().item()
-       max_seq_len = seq_lens_cpu.max().item()
+       max_seq_len = seq_lens_device.max().item()

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
@@ -188,6 +182,7 @@ def _build_decode(
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
+           dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )


@@ -289,6 +284,9 @@ def _forward_decode(
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=attn_metadata.decode.scheduler_metadata,
            num_splits=attn_metadata.decode.max_num_splits,
+           cp_world_size=self.dcp_world_size,
+           cp_rank=self.dcp_rank,
+           cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
        )

        if self.need_to_return_lse_for_decode:
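
One way to read the _schedule_decode change above (an interpretation, not stated in the diff): under DCP each rank schedules over only its local KV shard, roughly 1/dcp_world_size of every request, so reporting num_heads * dcp_world_size query heads keeps the scheduler's per-request work estimate close to the undistributed case. A back-of-envelope sketch with made-up numbers:

# Illustrative numbers only; none of these values come from the diff.
num_heads, dcp_world_size, tot_seq_len = 16, 4, 4096
local_seq_len = tot_seq_len // dcp_world_size           # KV tokens held by one rank

work_without_dcp = num_heads * tot_seq_len               # 65_536 head-token units
work_reported_to_scheduler = (num_heads * dcp_world_size) * local_seq_len
assert work_without_dcp == work_reported_to_scheduler    # 65_536 on each side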

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 2 additions & 0 deletions
@@ -106,6 +106,7 @@ def _build_decode(
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
+       dcp_tot_seq_lens_device: Optional[torch.Tensor],
    ) -> FlashMLADecodeMetadata:
        tile_scheduler_metadata, num_splits = get_mla_metadata(
            seq_lens_device,
@@ -146,6 +147,7 @@
            seq_lens=seq_lens_device,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
+           dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )


vllm/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 2 additions & 0 deletions
@@ -116,6 +116,7 @@ def _build_decode(
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
+       dcp_tot_seq_lens_device: Optional[torch.Tensor],
    ) -> AiterMLADecodeMetadata:
        page_size = self.kv_cache_spec.block_size
        block_table_bounds = (seq_lens_device + page_size - 1) // page_size
@@ -174,6 +175,7 @@
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            qo_indptr=qo_indptr,
+           dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )

        return attn_metadata

vllm/v1/attention/backends/utils.py

Lines changed: 3 additions & 0 deletions
@@ -93,6 +93,9 @@ class CommonAttentionMetadata:
    # Needed by CrossAttentionBuilder
    encoder_seq_lens: Optional[np.ndarray] = None

+   dcp_local_seq_lens: Optional[torch.Tensor] = None
+   """Sequence lengths of the local rank in decode context parallelism world"""
+

def slice_query_start_locs(
    query_start_loc: torch.Tensor,

vllm/v1/spec_decode/eagle.py

Lines changed: 2 additions & 0 deletions
@@ -597,6 +597,7 @@ def prepare_inputs_padded(
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
+           dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        token_indices_to_sample = (
@@ -868,6 +869,7 @@ def prepare_inputs(
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
+           dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens,
        )

        return spec_common_attn_metadata, token_indices

vllm/v1/worker/gpu_model_runner.py

Lines changed: 14 additions & 1 deletion
@@ -398,6 +398,10 @@ def __init__(
            self.max_num_reqs + 1, dtype=torch.int32
        )
        self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
+       if self.dcp_world_size > 1:
+           self.dcp_local_seq_lens = self._make_buffer(
+               self.max_num_reqs, dtype=torch.int32
+           )
        # Because inputs_embeds may be bfloat16 and we don't need a numpy
        # version of this tensor, avoid a RuntimeError by not creating a
        # numpy buffer.
@@ -581,7 +585,10 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
        # NOTE(lucas): currently no backend supports the custom masking
        # required for DCP with q_len > 1, so we assert here. Remove this
        # assert once the custom mask is support is added to FA3.
-       if self.dcp_world_size > 1:
+       if (
+           self.dcp_world_size > 1
+           and envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA"
+       ):
            assert self.reorder_batch_threshold == 1, (
                "DCP not support reorder_batch_threshold > 1 now."
            )
@@ -1335,6 +1342,9 @@ def _prepare_inputs(
            num_logits_indices=logits_indices.size(0),
            causal=True,
            encoder_seq_lens=encoder_seq_lens,
+           dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
+           if self.dcp_world_size > 1
+           else None,
        )

        if self.speculative_config and spec_decode_common_attn_metadata is None:
@@ -3310,6 +3320,9 @@ def _dummy_run(
                kv_cache_group_id
            ].slot_mapping.gpu[:num_tokens],
            causal=True,
+           dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs]
+           if self.dcp_world_size > 1
+           else None,
        )
        for attn_group in self.attn_groups[kv_cache_group_id]:
            if ubatch_slices is not None:
