From 8cdbd20246b3247e07ad792176e8bd3a93a10110 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Tue, 16 Sep 2025 22:08:07 -0700 Subject: [PATCH 1/8] [Attention][DCP] Support DCP with query len > 1 with FA3 Signed-off-by: Ming Yang --- vllm/v1/attention/backends/mla/flashattn_mla.py | 10 ++++++---- vllm/v1/worker/gpu_model_runner.py | 6 ------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 652b1cdb6b76..4c1b51f8cfa4 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -12,7 +12,6 @@ from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, get_flash_attn_version) from vllm.config import VllmConfig -from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, MLACommonDecodeMetadata, @@ -97,9 +96,10 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.max_num_splits = ( envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) - # TODO(lucas): Until we add support for the DCP custom masking we need - # to restrict decodes to q_len == 1 when DCP is enabled. - self.reorder_batch_threshold = 1 \ + # TODO(ming): increase the threshold to larger number would cause + # accuracy issue when dcp is enabled. We will remove this once the + # accuracy issue is resolved + self.reorder_batch_threshold = 16 \ if get_dcp_group().world_size > 1 else self.reorder_batch_threshold def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, @@ -260,6 +260,8 @@ def _forward_decode( fa_version=3, # only version 3 is supported scheduler_metadata=attn_metadata.decode.scheduler_metadata, num_splits=attn_metadata.decode.max_num_splits, + cp_world_size=self.dcp_world_size, + cp_rank=self.dcp_rank, ) if self.need_to_return_lse_for_decode: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0960fe3a25fb..6bc6852eb41e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -517,12 +517,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: return if self.reorder_batch_threshold is not None: - # NOTE(lucas): currently no backend supports the custom masking - # required for DCP with q_len > 1, so we assert here. Remove this - # assert once the custom mask is support is added to FA3. - if self.dcp_world_size > 1: - assert self.reorder_batch_threshold == 1, \ - "DCP not support reorder_batch_threshold > 1 now." 
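A minimal sketch of the masking problem this patch series addresses, assuming the DCP KV cache is token-interleaved (global token i of a sequence stored on rank i % dcp_world_size, consistent with the remainder arithmetic used later in this series); names here are illustrative, not vLLM APIs:

    def visible_local_kv(q_pos: int, dcp_rank: int, dcp_world_size: int) -> list[int]:
        # Global indices of this rank's local KV slots that a query token at
        # position q_pos may attend to under causal masking. With q_len > 1,
        # different query positions of one request see different local
        # prefixes, which is why FA3 needs cp_rank / cp_world_size to build
        # the mask instead of restricting decodes to q_len == 1.
        return list(range(dcp_rank, q_pos + 1, dcp_world_size))

    # e.g. with dcp_world_size=4, a query at position 10 sees local slots
    # [2, 6, 10] on rank 2 but only [3, 7] on rank 3.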
reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, From 45bb7e8f0e02ad2e02a213f5e4f2242d32ae3d35 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Thu, 25 Sep 2025 08:28:45 -0700 Subject: [PATCH 2/8] fix num_heads Signed-off-by: Ming Yang --- vllm/v1/attention/backends/mla/flashattn_mla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 4c1b51f8cfa4..8674c41f9b40 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -109,7 +109,7 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, batch_size=num_reqs, max_seqlen_q=max_query_len, max_seqlen_k=max_seq_len, - num_heads_q=self.num_heads, + num_heads_q=self.num_heads * self.dcp_world_size, num_heads_kv=1, headdim=self.mla_dims.qk_rope_head_dim, cache_seqlens=seqlens, From 9c0176b0474f21db3e2d7fdeabfd5680e6da9c19 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Sun, 28 Sep 2025 19:22:40 -0700 Subject: [PATCH 3/8] keep reorder_batch_threshold check for other backends; also fix max_seq_len Signed-off-by: Ming Yang --- vllm/v1/attention/backends/mla/flashattn_mla.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index 8674c41f9b40..a643717391ca 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -100,7 +100,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # accuracy issue when dcp is enabled. We will remove this once the # accuracy issue is resolved self.reorder_batch_threshold = 16 \ - if get_dcp_group().world_size > 1 else self.reorder_batch_threshold + if self.dcp_world_size > 1 else self.reorder_batch_threshold def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): @@ -130,7 +130,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, num_decode_tokens: int) -> FlashAttnMLADecodeMetadata: query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) max_query_len = query_lens_cpu.max().item() - max_seq_len = seq_lens_cpu.max().item() + max_seq_len = seq_lens_device.max().item() scheduler_metadata = self._schedule_decode( num_reqs=seq_lens_cpu.numel(), diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6bc6852eb41e..880fc63f6106 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -517,6 +517,13 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: return if self.reorder_batch_threshold is not None: + # NOTE(lucas): currently no backend supports the custom masking + # required for DCP with q_len > 1, so we assert here. Remove this + # assert once the custom mask is support is added to FA3. + if self.dcp_world_size > 1 and \ + envs.VLLM_ATTENTION_BACKEND != "FLASH_ATTN_MLA": + assert self.reorder_batch_threshold == 1, \ + "DCP not support reorder_batch_threshold > 1 now." 
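As a rough, hedged sketch of what this guard protects (the real helper is reorder_batch_to_split_decodes_and_prefills; the function below only illustrates the classification criterion):

    def split_decodes_and_prefills(query_lens: list[int], threshold: int) -> tuple[list[int], list[int]]:
        # Requests whose query length is at most the threshold are grouped
        # with decodes and handled ahead of prefills; the rest take the
        # prefill path.
        decodes = [i for i, q in enumerate(query_lens) if q <= threshold]
        prefills = [i for i, q in enumerate(query_lens) if q > threshold]
        return decodes, prefills

    # Under DCP, only the FlashAttention MLA backend accepts the cp_* kernel
    # arguments introduced in this series, so every other backend must keep
    # the effective threshold at 1 (single-token decode queries only).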
reorder_batch_to_split_decodes_and_prefills( self.input_batch, scheduler_output, From 4cc05c6888a11a5565cc51fb254336b9c3286723 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Mon, 29 Sep 2025 17:24:58 -0700 Subject: [PATCH 4/8] pass cp_tot_seqused_k Signed-off-by: Ming Yang --- vllm/v1/attention/backends/mla/common.py | 19 ++++++++------- .../attention/backends/mla/flashattn_mla.py | 24 +++++++++---------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 1053fde09910..5dd03a3e148d 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -359,6 +359,7 @@ class ChunkedContextMetadata( class MLACommonDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor + cp_tot_seq_lens: torch.Tensor D = TypeVar("D", bound=MLACommonDecodeMetadata) @@ -624,15 +625,15 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> MLACommonDecodeMetadata: + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, num_decode_tokens: int, + cp_tot_seq_lens_device: torch.Tensor) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, + cp_tot_seq_lens=cp_tot_seq_lens_device, ) def build_for_cudagraph_capture( @@ -683,7 +684,7 @@ def build(self, # Note(hc): update seq_lens of decode reqs under DCP. if self.dcp_world_size > 1: - seq_lens[:num_decodes] = seq_lens[:num_decodes] \ + cp_seq_lens = seq_lens[:num_decodes] \ // self.dcp_world_size + (self.dcp_rank <= \ (seq_lens[:num_decodes] - 1) % self.dcp_world_size) @@ -836,10 +837,12 @@ def build(self, decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], seq_lens_cpu=seq_lens_cpu[:num_decodes], - seq_lens_device=seq_lens[:num_decodes], + seq_lens_device=cp_seq_lens + if self.dcp_world_size > 1 else seq_lens[:num_decodes], query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], query_start_loc_device=query_start_loc[:num_decodes + 1], num_decode_tokens=num_decode_tokens, + cp_tot_seq_lens_device=seq_lens[:num_decodes], ) attn_metadata = self.metadata_cls( diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index a643717391ca..df5be314fdad 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -96,12 +96,6 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.max_num_splits = ( envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH) - # TODO(ming): increase the threshold to larger number would cause - # accuracy issue when dcp is enabled. 
We will remove this once the - # accuracy issue is resolved - self.reorder_batch_threshold = 16 \ - if self.dcp_world_size > 1 else self.reorder_batch_threshold - def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal): if self.fa_aot_schedule: @@ -122,12 +116,16 @@ def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, ) return None - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> FlashAttnMLADecodeMetadata: + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + cp_tot_seq_lens_device: torch.Tensor, + ) -> FlashAttnMLADecodeMetadata: query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) max_query_len = query_lens_cpu.max().item() max_seq_len = seq_lens_device.max().item() @@ -172,6 +170,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, max_seq_len=max_seq_len, scheduler_metadata=scheduler_metadata, max_num_splits=max_num_splits, + cp_tot_seq_lens=cp_tot_seq_lens_device, ) @@ -262,6 +261,7 @@ def _forward_decode( num_splits=attn_metadata.decode.max_num_splits, cp_world_size=self.dcp_world_size, cp_rank=self.dcp_rank, + cp_tot_seqused_k=attn_metadata.decode.cp_tot_seq_lens, ) if self.need_to_return_lse_for_decode: From a6efa960647a22223a54ee43cad5e771da0789f8 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Mon, 29 Sep 2025 18:29:42 -0700 Subject: [PATCH 5/8] fix pre-commit check Signed-off-by: Ming Yang --- vllm/v1/attention/backends/mla/common.py | 15 +++++++++----- .../attention/backends/mla/flashattn_mla.py | 2 +- vllm/v1/attention/backends/mla/flashmla.py | 17 ++++++++++------ .../attention/backends/mla/rocm_aiter_mla.py | 20 ++++++++++++------- 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 5dd03a3e148d..25fa1eca858c 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -359,7 +359,7 @@ class ChunkedContextMetadata( class MLACommonDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor - cp_tot_seq_lens: torch.Tensor + cp_tot_seq_lens: Optional[torch.Tensor] D = TypeVar("D", bound=MLACommonDecodeMetadata) @@ -626,10 +626,15 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): prefill.prefill_chunks = self._fi_prefill_chunks def _build_decode( - self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, num_decode_tokens: int, - cp_tot_seq_lens_device: torch.Tensor) -> MLACommonDecodeMetadata: + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + cp_tot_seq_lens_device: Optional[torch.Tensor], + ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index df5be314fdad..0cd827d2134e 100644 --- 
a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -124,7 +124,7 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - cp_tot_seq_lens_device: torch.Tensor, + cp_tot_seq_lens_device: Optional[torch.Tensor], ) -> FlashAttnMLADecodeMetadata: query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) max_query_len = query_lens_cpu.max().item() diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index ac0524ba088b..2477523d1295 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -84,12 +84,16 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], device=self.device, dtype=torch.int32) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> FlashMLADecodeMetadata: + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + cp_tot_seq_lens_device: Optional[torch.Tensor], + ) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( seq_lens_device, @@ -129,6 +133,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, + cp_tot_seq_lens=cp_tot_seq_lens_device, ) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 79247e569b1c..c479cf08219b 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -104,12 +104,16 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=device) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens_cpu: torch.Tensor, - seq_lens_device: torch.Tensor, - query_start_loc_cpu: torch.Tensor, - query_start_loc_device: torch.Tensor, - num_decode_tokens: int) -> AiterMLADecodeMetadata: + def _build_decode( + self, + block_table_tensor: torch.Tensor, + seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, + query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor, + num_decode_tokens: int, + cp_tot_seq_lens_device: Optional[torch.Tensor], + ) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device @@ -164,7 +168,9 @@ def _build_decode(self, block_table_tensor: torch.Tensor, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, - qo_indptr=qo_indptr) + qo_indptr=qo_indptr, + cp_tot_seq_lens=cp_tot_seq_lens_device, + ) return attn_metadata From ed6dcdd5e5d6185f9eae60dc4b7c69d954eb8e07 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Tue, 30 Sep 2025 15:19:53 -0700 Subject: [PATCH 6/8] make cp_seq_lens cuda graph compatible Signed-off-by: Ming Yang --- vllm/v1/attention/backends/mla/common.py | 12 ++++++++---- vllm/v1/attention/backends/utils.py | 3 +++ vllm/v1/spec_decode/eagle.py | 2 ++ vllm/v1/worker/gpu_model_runner.py | 10 +++++++++- 4 files changed, 22 insertions(+), 5 
deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 25fa1eca858c..cff63d234be7 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -677,6 +677,7 @@ def build(self, query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu + cp_seq_lens = common_attn_metadata.cp_seq_lens query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -689,7 +690,8 @@ def build(self, # Note(hc): update seq_lens of decode reqs under DCP. if self.dcp_world_size > 1: - cp_seq_lens = seq_lens[:num_decodes] \ + assert cp_seq_lens is not None + cp_seq_lens[:num_decodes] = seq_lens[:num_decodes] \ // self.dcp_world_size + (self.dcp_rank <= \ (seq_lens[:num_decodes] - 1) % self.dcp_world_size) @@ -842,12 +844,14 @@ def build(self, decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], seq_lens_cpu=seq_lens_cpu[:num_decodes], - seq_lens_device=cp_seq_lens - if self.dcp_world_size > 1 else seq_lens[:num_decodes], + seq_lens_device=cp_seq_lens[:num_decodes] + if self.dcp_world_size > 1 and cp_seq_lens is not None else + seq_lens[:num_decodes], query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], query_start_loc_device=query_start_loc[:num_decodes + 1], num_decode_tokens=num_decode_tokens, - cp_tot_seq_lens_device=seq_lens[:num_decodes], + cp_tot_seq_lens_device=seq_lens[:num_decodes] + if self.dcp_world_size > 1 else None, ) attn_metadata = self.metadata_cls( diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index f37a829f401c..927a4a9a9aa1 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -83,6 +83,9 @@ class CommonAttentionMetadata: # Needed by CrossAttentionBuilder encoder_seq_lens: Optional[np.ndarray] = None + cp_seq_lens: Optional[torch.Tensor] = None + """Sequence lengths of the local rank in context parallelism world""" + def slice_query_start_locs( query_start_loc: torch.Tensor, diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1b5bafb9ca1b..a515a527e725 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -558,6 +558,7 @@ def prepare_inputs_padded(self, block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, + cp_seq_lens=common_attn_metadata.cp_seq_lens, ) token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \ @@ -831,6 +832,7 @@ def prepare_inputs( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, + cp_seq_lens=common_attn_metadata.cp_seq_lens, ) return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 880fc63f6106..4fc0ab25e6c7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -350,6 +350,9 @@ def __init__( self.query_start_loc = self._make_buffer(self.max_num_reqs + 1, dtype=torch.int32) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + if self.dcp_world_size > 1: + self.cp_seq_lens = self._make_buffer(self.max_num_reqs, + dtype=torch.int32) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy 
buffer. @@ -1222,6 +1225,8 @@ def _prepare_inputs( num_logits_indices=logits_indices.size(0), causal=True, encoder_seq_lens=encoder_seq_lens, + cp_seq_lens=self.cp_seq_lens.gpu[:num_reqs] + if self.dcp_world_size > 1 else None, ) if (self.speculative_config @@ -3129,7 +3134,10 @@ def _dummy_run( block_table[kv_cache_group_id].get_device_tensor(num_reqs), slot_mapping=self.input_batch.block_table[ kv_cache_group_id].slot_mapping.gpu[:num_tokens], - causal=True) + causal=True, + cp_seq_lens=self.cp_seq_lens.gpu[:num_reqs] + if self.dcp_world_size > 1 else None, + ) for attn_group in self.attn_groups[kv_cache_group_id]: if ubatch_slices is not None: common_attn_metadata_list = split_attn_metadata( From efaf7fa1e4fc42ec100bdde171690248bfa5aaf1 Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Mon, 6 Oct 2025 10:50:21 -0700 Subject: [PATCH 7/8] update flash-attn commit hash to pick up cp-related changes Signed-off-by: Ming Yang --- cmake/external_projects/vllm_flash_attn.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 3d32121f13ac..d4908772c69e 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a + GIT_TAG 8f468e7da54a8e2f98abfa7c38636aac91c0cba1 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn From 19a7f8c4905e09de7326f93eb085f4739acffc7a Mon Sep 17 00:00:00 2001 From: Ming Yang Date: Mon, 6 Oct 2025 11:59:27 -0700 Subject: [PATCH 8/8] change cp_ prefix to dcp_ Signed-off-by: Ming Yang --- vllm/v1/attention/backends/mla/common.py | 18 +++++++++--------- .../v1/attention/backends/mla/flashattn_mla.py | 6 +++--- vllm/v1/attention/backends/mla/flashmla.py | 4 ++-- .../attention/backends/mla/rocm_aiter_mla.py | 4 ++-- vllm/v1/attention/backends/utils.py | 4 ++-- vllm/v1/spec_decode/eagle.py | 4 ++-- vllm/v1/worker/gpu_model_runner.py | 8 +++++--- 7 files changed, 25 insertions(+), 23 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index eea7618b9ec8..9356e3b3a50c 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -370,7 +370,7 @@ class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata): class MLACommonDecodeMetadata: block_table: torch.Tensor seq_lens: torch.Tensor - cp_tot_seq_lens: Optional[torch.Tensor] + dcp_tot_seq_lens: Optional[torch.Tensor] D = TypeVar("D", bound=MLACommonDecodeMetadata) @@ -663,12 +663,12 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - cp_tot_seq_lens_device: Optional[torch.Tensor], + dcp_tot_seq_lens_device: Optional[torch.Tensor], ) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, seq_lens=seq_lens_device, - cp_tot_seq_lens=cp_tot_seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) def build_for_cudagraph_capture( @@ -710,7 +710,7 @@ def build( query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens seq_lens_cpu = common_attn_metadata.seq_lens_cpu - cp_seq_lens = common_attn_metadata.cp_seq_lens + dcp_local_seq_lens = 
common_attn_metadata.dcp_local_seq_lens query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -724,8 +724,8 @@ def build( # Note(hc): update seq_lens of decode reqs under DCP. if self.dcp_world_size > 1: - assert cp_seq_lens is not None - cp_seq_lens[:num_decodes] = seq_lens[ + assert dcp_local_seq_lens is not None + dcp_local_seq_lens[:num_decodes] = seq_lens[ :num_decodes ] // self.dcp_world_size + ( self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size @@ -884,13 +884,13 @@ def build( decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], seq_lens_cpu=seq_lens_cpu[:num_decodes], - seq_lens_device=cp_seq_lens[:num_decodes] - if self.dcp_world_size > 1 and cp_seq_lens is not None + seq_lens_device=dcp_local_seq_lens[:num_decodes] + if self.dcp_world_size > 1 and dcp_local_seq_lens is not None else seq_lens[:num_decodes], query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1], query_start_loc_device=query_start_loc[: num_decodes + 1], num_decode_tokens=num_decode_tokens, - cp_tot_seq_lens_device=seq_lens[:num_decodes] + dcp_tot_seq_lens_device=seq_lens[:num_decodes] if self.dcp_world_size > 1 else None, ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index c838f95eafea..c043990ffcc6 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -135,7 +135,7 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - cp_tot_seq_lens_device: Optional[torch.Tensor], + dcp_tot_seq_lens_device: Optional[torch.Tensor], ) -> FlashAttnMLADecodeMetadata: query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] max_query_len = query_lens_cpu.max().item() @@ -182,7 +182,7 @@ def _build_decode( max_seq_len=max_seq_len, scheduler_metadata=scheduler_metadata, max_num_splits=max_num_splits, - cp_tot_seq_lens=cp_tot_seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) @@ -286,7 +286,7 @@ def _forward_decode( num_splits=attn_metadata.decode.max_num_splits, cp_world_size=self.dcp_world_size, cp_rank=self.dcp_rank, - cp_tot_seqused_k=attn_metadata.decode.cp_tot_seq_lens, + cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens, ) if self.need_to_return_lse_for_decode: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 78f8b170e423..c1eb12442e0c 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -102,7 +102,7 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, num_decode_tokens: int, - cp_tot_seq_lens_device: Optional[torch.Tensor], + dcp_tot_seq_lens_device: Optional[torch.Tensor], ) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = get_mla_metadata( seq_lens_device, @@ -143,7 +143,7 @@ def _build_decode( seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, - cp_tot_seq_lens=cp_tot_seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 2c5f6f3b3ad6..195b05e0a301 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -116,7 +116,7 @@ def _build_decode( query_start_loc_cpu: torch.Tensor, query_start_loc_device: torch.Tensor, 
num_decode_tokens: int, - cp_tot_seq_lens_device: Optional[torch.Tensor], + dcp_tot_seq_lens_device: Optional[torch.Tensor], ) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size block_table_bounds = (seq_lens_device + page_size - 1) // page_size @@ -175,7 +175,7 @@ def _build_decode( paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, qo_indptr=qo_indptr, - cp_tot_seq_lens=cp_tot_seq_lens_device, + dcp_tot_seq_lens=dcp_tot_seq_lens_device, ) return attn_metadata diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 4249ff7f5e6e..711f22e8494e 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -93,8 +93,8 @@ class CommonAttentionMetadata: # Needed by CrossAttentionBuilder encoder_seq_lens: Optional[np.ndarray] = None - cp_seq_lens: Optional[torch.Tensor] = None - """Sequence lengths of the local rank in context parallelism world""" + dcp_local_seq_lens: Optional[torch.Tensor] = None + """Sequence lengths of the local rank in decode context parallelism world""" def slice_query_start_locs( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a9f9c9ededc6..c511e627c15a 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -601,7 +601,7 @@ def prepare_inputs_padded( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, - cp_seq_lens=common_attn_metadata.cp_seq_lens, + dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) token_indices_to_sample = ( @@ -873,7 +873,7 @@ def prepare_inputs( block_table_tensor=common_attn_metadata.block_table_tensor, slot_mapping=common_attn_metadata.slot_mapping[token_indices], causal=True, - cp_seq_lens=common_attn_metadata.cp_seq_lens, + dcp_local_seq_lens=common_attn_metadata.dcp_local_seq_lens, ) return spec_common_attn_metadata, token_indices diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ca8b3c00fe18..25e1f9fc1142 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -393,7 +393,9 @@ def __init__( ) self.seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) if self.dcp_world_size > 1: - self.cp_seq_lens = self._make_buffer(self.max_num_reqs, dtype=torch.int32) + self.dcp_local_seq_lens = self._make_buffer( + self.max_num_reqs, dtype=torch.int32 + ) # Because inputs_embeds may be bfloat16 and we don't need a numpy # version of this tensor, avoid a RuntimeError by not creating a # numpy buffer. @@ -1334,7 +1336,7 @@ def _prepare_inputs( num_logits_indices=logits_indices.size(0), causal=True, encoder_seq_lens=encoder_seq_lens, - cp_seq_lens=self.cp_seq_lens.gpu[:num_reqs] + dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None, ) @@ -3381,7 +3383,7 @@ def _dummy_run( kv_cache_group_id ].slot_mapping.gpu[:num_tokens], causal=True, - cp_seq_lens=self.cp_seq_lens.gpu[:num_reqs] + dcp_local_seq_lens=self.dcp_local_seq_lens.gpu[:num_reqs] if self.dcp_world_size > 1 else None, )
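For reference, a minimal standalone sketch (plain PyTorch; the function name is invented here) of the local/total sequence-length split that build() performs for decode requests under DCP. The local lengths populate the decode metadata's seq_lens, while the untouched totals are forwarded to FA3 as cp_tot_seqused_k:

    import torch

    def split_decode_seq_lens(seq_lens: torch.Tensor, dcp_world_size: int,
                              dcp_rank: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Mirrors the expression in MLACommonMetadataBuilder.build(): each DCP
        # rank holds an interleaved shard of the KV cache, so its local length
        # is roughly seq_len / dcp_world_size, with lower ranks picking up one
        # extra token when the length does not divide evenly.
        local = seq_lens // dcp_world_size + (
            dcp_rank <= (seq_lens - 1) % dcp_world_size)
        # The totals stay as-is; FA3 uses them (via cp_tot_seqused_k) together
        # with cp_rank / cp_world_size to apply the global causal mask.
        return local, seq_lens

    local, total = split_decode_seq_lens(torch.tensor([5, 7, 9]),
                                         dcp_world_size=4, dcp_rank=0)
    # rank 0 of 4 -> local == tensor([2, 2, 3]), total == tensor([5, 7, 9])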