From 0a87c8824693c5f88491aa32eb3cec838b6e3d9e Mon Sep 17 00:00:00 2001
From: zhangsicheng5
Date: Fri, 17 Oct 2025 14:52:19 +0800
Subject: [PATCH 01/10] support dcp kv_cache interleave size > 1

Signed-off-by: zhangsicheng5
---
 tests/distributed/test_context_parallel.py |  7 +++++
 vllm/config/parallel.py                     | 11 ++++++++
 vllm/config/vllm.py                         | 17 ++++++++++++
 vllm/engine/arg_utils.py                    |  6 ++++
 vllm/v1/attention/backends/mla/common.py    |  9 ------
 vllm/v1/attention/backends/utils.py         | 32 ++++++++++++++++++++++
 vllm/v1/worker/block_table.py               | 18 ++++++++++--
 vllm/v1/worker/gpu_input_batch.py           |  2 ++
 vllm/v1/worker/gpu_model_runner.py          | 13 +++++++++
 9 files changed, 104 insertions(+), 11 deletions(-)

diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py
index 5495640af07e..7f8e77a75621 100644
--- a/tests/distributed/test_context_parallel.py
+++ b/tests/distributed/test_context_parallel.py
@@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
     dcp_size: int
+    dcp_kv_cache_interleave_size: int
     eager_mode: bool
     chunked_prefill: bool
 
@@ -52,6 +53,7 @@ def detailed(
         tp_base: int = 4,
         pp_base: int = 1,
         dcp_base: int = 1,
+        dcp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,
@@ -66,6 +68,7 @@ def detailed(
                                 tp_size=tp_base,
                                 pp_size=pp_multiplier * pp_base,
                                 dcp_size=int(dcp_multiplier * tp_base),
+                                dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
                                 eager_mode=eager_mode_val,
                                 chunked_prefill=chunked_prefill_val,
                             )
@@ -108,6 +111,7 @@ def _compare_cp_with_tp(
         tp_size,
         pp_size,
         dcp_size,
+        dcp_kv_cache_interleave_size,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup
@@ -180,6 +184,8 @@ def _compare_cp_with_tp(
         str(pp_size),
         "--decode-context-parallel-size",
         str(dcp_size),
+        "--dcp-kv-cache-interleave-size",
+        str(dcp_kv_cache_interleave_size),
         "--distributed-executor-backend",
         distributed_backend,
     ]
@@ -207,6 +213,7 @@ def _compare_cp_with_tp(
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
+        CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
     ],
     "bigcode/gpt_bigcode-santacoder": [
         CPTestSettings.detailed(),
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index b7ef0fef6833..6ee1270df569 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -223,6 +223,17 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""
 
+    dcp_kv_cache_interleave_size: int = 1
+    """Interleave size of KV cache storage when using DCP or CP > 1:
+    store interleave_size consecutive tokens on (d)cp rank i, then the
+    next interleave_size tokens on (d)cp rank i+1, and so on.
+    interleave_size=1 is token-level alignment: token i is stored on rank i % (d)cp_size.
+    interleave_size=block_size is block-level alignment: block j on rank i
+    is filled completely before tokens are written to block j on rank i+1.
+    block_size must be greater than or equal to, and divisible by,
+    dcp_kv_cache_interleave_size.
+    """
+
     _api_process_count: int = Field(default=1, gt=0)
     """
     The number of API processes initialized.
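The layout described by the dcp_kv_cache_interleave_size docstring above can be summarized with a small standalone sketch (illustrative only, not part of the patch; the helper names are made up). It mirrors the slot-mapping arithmetic the block_table.py hunk below adds: a token position is grouped into runs of interleave_size, runs are assigned to (d)cp ranks round-robin, and the token's rank-local offset is recovered from its group index plus its offset within the run.

def dcp_rank_of_token(position: int, dcp_world_size: int, interleave_size: int) -> int:
    # Runs of `interleave_size` consecutive tokens are assigned to ranks round-robin.
    return (position // interleave_size) % dcp_world_size

def dcp_local_offset(position: int, dcp_world_size: int, interleave_size: int) -> int:
    # Offset of the token within its rank's local KV layout.
    group = position // (dcp_world_size * interleave_size)
    return group * interleave_size + position % interleave_size

# Example: dcp_world_size=2, interleave_size=4 -> tokens 0-3 land on rank 0
# (local offsets 0-3), tokens 4-7 on rank 1 (local offsets 0-3), tokens 8-11
# back on rank 0 (local offsets 4-7); interleave_size=1 reduces to the old
# token-level round-robin (token i on rank i % dcp_world_size).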
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index dabd06c32054..544484e3cf9f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -480,6 +480,23 @@ def __post_init__(self): ) current_platform.check_and_update_config(self) + assert ( + self.parallel_config.dcp_kv_cache_interleave_size + <= self.cache_config.block_size + and self.cache_config.block_size + % self.parallel_config.dcp_kv_cache_interleave_size + == 0 + ), ( + f"Block_size({self.cache_config.block_size}) should be " + "greater than or equal to and divisible by dcp_kv_cache_interleave_size " + f"({self.parallel_config.dcp_kv_cache_interleave_size})." + ) + + assert ( + self.parallel_config.dcp_kv_cache_interleave_size == 1 + or self.speculative_config is None + ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now." + # Do this after all the updates to compilation_config.mode if ( envs.VLLM_USE_V1 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 801c30dc9478..9fbb2d70ad6e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -362,6 +362,7 @@ class EngineArgs: pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size + dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: int | None = None data_parallel_start_rank: int | None = None @@ -717,6 +718,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "-dcp", **parallel_kwargs["decode_context_parallel_size"], ) + parallel_group.add_argument( + "--dcp-kv-cache-interleave-size", + **parallel_kwargs["dcp_kv_cache_interleave_size"], + ) parallel_group.add_argument( "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] ) @@ -1482,6 +1487,7 @@ def create_engine_config( worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, decode_context_parallel_size=self.decode_context_parallel_size, + dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size, _api_process_count=self._api_process_count, _api_process_rank=self._api_process_rank, ) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f7e6f12363ad..995b6bf02fa5 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -774,15 +774,6 @@ def build( ) ) - # Note(hc): update seq_lens of decode reqs under DCP. 
- if self.dcp_world_size > 1: - assert dcp_local_seq_lens is not None - dcp_local_seq_lens[:num_decodes] = seq_lens[ - :num_decodes - ] // self.dcp_world_size + ( - self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size - ) - assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index cb5855548098..afdcea982309 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -992,3 +992,35 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore return nums_dict, batch_ptr, token_chunk_offset_ptr + + +def get_dcp_local_seq_lens( + seq_lens: torch.Tensor, + dcp_world_size: int = 1, + dcp_kv_cache_interleave_size: int = 1, +) -> torch.Tensor: + """While using dcp, kv_cache size stored on each rank may be different, + use this function to calculate split decode seq_lens of each dcp rank. + Only consider dcp now, we can extend the case of cp based on this. + """ + num_requests = seq_lens.size(0) + seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, dcp_world_size) + rank_offsets = ( + torch.arange(dcp_world_size, dtype=torch.int32) + .unsqueeze(0) + .repeat(num_requests, 1) + ) + base = ( + seq_lens_tiled + // dcp_kv_cache_interleave_size + // dcp_world_size + * dcp_kv_cache_interleave_size + ) + remainder = seq_lens_tiled - base * dcp_world_size + remainder = torch.clip( + remainder - rank_offsets * dcp_kv_cache_interleave_size, + 0, + dcp_kv_cache_interleave_size, + ) + dcp_local_seq_lens = base + remainder + return dcp_local_seq_lens diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 9bf06d51609f..a86042120783 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -22,6 +22,7 @@ def __init__( pin_memory: bool, device: torch.device, kernel_block_size: int, + dcp_kv_cache_interleave_size: int, ): """ Args: @@ -86,6 +87,7 @@ def __init__( # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size def append_row( self, @@ -144,9 +146,19 @@ def compute_slot_mapping( # Use virtual_block_size for mask calculation, which marks local # tokens. 
virtual_block_offsets = positions % virtual_block_size - mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + mask = ( + virtual_block_offsets + // self.dcp_kv_cache_interleave_size + % self.dcp_world_size + == self.dcp_rank + ) # Calculate local block_offsets - block_offsets = virtual_block_offsets // self.dcp_world_size + block_offsets = ( + virtual_block_offsets + // (self.dcp_world_size * self.dcp_kv_cache_interleave_size) + * self.dcp_kv_cache_interleave_size + + virtual_block_offsets % self.dcp_kv_cache_interleave_size + ) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local @@ -234,6 +246,7 @@ def __init__( block_sizes: list[int], kernel_block_sizes: list[int], num_speculative_tokens: int = 0, + dcp_kv_cache_interleave_size: int = 1, ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, @@ -263,6 +276,7 @@ def __init__( pin_memory, device, kernel_block_size, + dcp_kv_cache_interleave_size, ) for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) ] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index b8751546f767..101dabcadbb7 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -83,6 +83,7 @@ def __init__( is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, + dcp_kv_cache_interleave_size: int = 1, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -135,6 +136,7 @@ def __init__( block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, num_speculative_tokens=num_speculative_tokens, + dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size, ) # Sampling-related. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d995a609318c..a778caafce89 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,6 +35,7 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( + get_dcp_group, get_pp_group, get_tp_group, graph_capture, @@ -92,6 +93,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, + get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) @@ -256,6 +258,7 @@ def __init__( self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len self.dcp_world_size = self.parallel_config.decode_context_parallel_size + self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -372,6 +375,7 @@ def __init__( # uses output token ids so we set this conservatively. logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, + dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -1274,6 +1278,15 @@ def _prepare_inputs( logits_indices ) + # update seq_lens of decode reqs under DCP. 
+ if self.dcp_world_size > 1: + self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( + self.seq_lens.cpu[:num_reqs], + self.dcp_world_size, + self.parallel_config.dcp_kv_cache_interleave_size, + )[:, self.dcp_rank] + self.dcp_local_seq_lens.copy_to_gpu(num_reqs) + attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: attn_metadata = [dict() for _ in range(len(ubatch_slices))] From c5dc44a846ca7f9815d5a8682a97af55479f2935 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 21 Oct 2025 19:57:34 +0800 Subject: [PATCH 02/10] [bugfix] fix wrong cp_context_lens and cp_target_rank caused by interleave size > 1 Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/mla/common.py | 33 ++++++++++++++++++++---- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 995b6bf02fa5..f7a4f69b3dc7 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -554,6 +554,7 @@ def __init__( # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = parallel_config.dcp_kv_cache_interleave_size # Don't try to access the runner on AMD if self.aot_schedule: @@ -783,9 +784,13 @@ def build( context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] # Note(hc): The context lengths in the perspective of dcp rank0. - cp_context_lens_cpu = torch.ceil( - context_lens_cpu.float() / self.dcp_world_size - ).int() + cp_context_lens_cpu = ( + torch.ceil( + context_lens_cpu.float() + / (self.dcp_world_size * self.dcp_kv_cache_interleave_size) + ).int() + * self.dcp_kv_cache_interleave_size + ) origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() @@ -972,6 +977,7 @@ def reorg_kvcache( chunk_size: int, chunk_idx: int, toks: int, + interleave_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ reorg kvcache after cp local gather to tp layout for attn kernel. @@ -986,6 +992,7 @@ def reorg_kvcache( chunked_context_metadata building. chunk_idx: chunk idx of chunked_prefill. toks: the number of tokens for local gather cache. + interleave_size: Interleave size of kv_cache storage. 
""" kv_c_segments = [] k_pe_segments = [] @@ -999,11 +1006,23 @@ def reorg_kvcache( chunk_context_len = min( chunk_context_len, origin_context_len - chunk_size * chunk_idx ) - cp_target_rank = (chunk_context_len - 1) % cp_world_size + + interleave_remainder = chunk_context_len % (interleave_size * cp_world_size) + if interleave_remainder > 0: + cp_target_rank = interleave_remainder // interleave_size + cp_target_rank_remainder = interleave_remainder % interleave_size + else: + cp_target_rank = cp_world_size + cp_target_rank_remainder = interleave_size + cur_seq_len = 0 for rank in range(cp_world_size): if rank > cp_target_rank and cp_chunk_seq_len: - real_cp_chunk_seq_len = cp_chunk_seq_len - 1 + real_cp_chunk_seq_len = cp_chunk_seq_len - interleave_size + elif rank == cp_target_rank and cp_chunk_seq_len: + real_cp_chunk_seq_len = ( + cp_chunk_seq_len - interleave_size + cp_target_rank_remainder + ) else: real_cp_chunk_seq_len = cp_chunk_seq_len if real_cp_chunk_seq_len: @@ -1254,6 +1273,9 @@ def __init__(self, *args, **kwargs) -> None: get_current_vllm_config() ) ) + self.dcp_kv_cache_interleave_size: int = ( + get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size + ) def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs @@ -1632,6 +1654,7 @@ def _context_parallel_compute_prefill_context( chunk_size=prefill_metadata.chunked_context.chunk_size, chunk_idx=i, toks=toks, + interleave_size=self.dcp_kv_cache_interleave_size, ) kv_nope = self.kv_b_proj(kv_c_normed)[0].view( From f695f85fa2a418f2f9d12e8327801000be33b00c Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Wed, 22 Oct 2025 10:31:01 +0800 Subject: [PATCH 03/10] [refactor] add dcp_rank params for get_dcp_local_seq_lens Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/utils.py | 18 +++++++++++------- vllm/v1/worker/gpu_model_runner.py | 3 ++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index afdcea982309..3bc5f8e422e2 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -997,6 +997,7 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): def get_dcp_local_seq_lens( seq_lens: torch.Tensor, dcp_world_size: int = 1, + dcp_rank: int | None = None, dcp_kv_cache_interleave_size: int = 1, ) -> torch.Tensor: """While using dcp, kv_cache size stored on each rank may be different, @@ -1004,12 +1005,15 @@ def get_dcp_local_seq_lens( Only consider dcp now, we can extend the case of cp based on this. 
""" num_requests = seq_lens.size(0) - seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, dcp_world_size) - rank_offsets = ( - torch.arange(dcp_world_size, dtype=torch.int32) - .unsqueeze(0) - .repeat(num_requests, 1) - ) + if dcp_rank is None: + rank_offsets = ( + torch.arange(dcp_world_size, dtype=torch.int32) + .unsqueeze(0) + .repeat(num_requests, 1) + ) + else: + rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32) + seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, rank_offsets.shape[1]) base = ( seq_lens_tiled // dcp_kv_cache_interleave_size @@ -1023,4 +1027,4 @@ def get_dcp_local_seq_lens( dcp_kv_cache_interleave_size, ) dcp_local_seq_lens = base + remainder - return dcp_local_seq_lens + return dcp_local_seq_lens \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a778caafce89..41516b9fc12b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1283,8 +1283,9 @@ def _prepare_inputs( self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( self.seq_lens.cpu[:num_reqs], self.dcp_world_size, + self.dcp_rank, self.parallel_config.dcp_kv_cache_interleave_size, - )[:, self.dcp_rank] + ) self.dcp_local_seq_lens.copy_to_gpu(num_reqs) attn_metadata: PerLayerAttnMetadata = {} From 3c99f2919dabbc8a7d6f0694d8e864635302fc42 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Wed, 22 Oct 2025 11:03:27 +0800 Subject: [PATCH 04/10] [refactor] Reuse get_dcp_local_seq_lens Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/mla/common.py | 31 +++++++++--------------- vllm/v1/attention/backends/utils.py | 2 +- 2 files changed, 13 insertions(+), 20 deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f7a4f69b3dc7..ec851e6a830f 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -222,6 +222,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, + get_dcp_local_seq_lens, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, @@ -998,8 +999,15 @@ def reorg_kvcache( k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 - for cp_chunk_seq_len, origin_context_len in zip( - cp_chunk_seq_lens_lst, origin_context_lens + local_context_lens_allrank = get_dcp_local_seq_lens( + torch.Tensor(origin_context_lens), + cp_world_size, + None, + interleave_size, + ) + # print(origin_context_lens, local_context_lens_allrank) + for cp_chunk_seq_len, origin_context_len, local_context_lens in zip( + cp_chunk_seq_lens_lst, origin_context_lens, local_context_lens_allrank ): chunk_context_len = chunk_size if cp_chunk_seq_len != 0: @@ -1007,25 +1015,10 @@ def reorg_kvcache( chunk_context_len, origin_context_len - chunk_size * chunk_idx ) - interleave_remainder = chunk_context_len % (interleave_size * cp_world_size) - if interleave_remainder > 0: - cp_target_rank = interleave_remainder // interleave_size - cp_target_rank_remainder = interleave_remainder % interleave_size - else: - cp_target_rank = cp_world_size - cp_target_rank_remainder = interleave_size - cur_seq_len = 0 for rank in range(cp_world_size): - if rank > cp_target_rank and cp_chunk_seq_len: - real_cp_chunk_seq_len = cp_chunk_seq_len - interleave_size - elif rank == cp_target_rank and cp_chunk_seq_len: - real_cp_chunk_seq_len = ( - cp_chunk_seq_len - interleave_size + cp_target_rank_remainder - ) - else: - real_cp_chunk_seq_len = cp_chunk_seq_len - if 
real_cp_chunk_seq_len: + real_cp_chunk_seq_len = local_context_lens[rank] + if real_cp_chunk_seq_len != 0: kv_c_segment = allgatered_kv_c_normed[ rank * toks + src_token_idx : rank * toks + src_token_idx diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 3bc5f8e422e2..c8954e0c2d55 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1013,7 +1013,7 @@ def get_dcp_local_seq_lens( ) else: rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32) - seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, rank_offsets.shape[1]) + seq_lens_tiled = seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1]) base = ( seq_lens_tiled // dcp_kv_cache_interleave_size From 67b38209a5bf993823f4961a3d64036a85039153 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Wed, 22 Oct 2025 11:06:36 +0800 Subject: [PATCH 05/10] [backend] support interleave size > 1 for flash_attn Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/flash_attn.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 9e0c125d9edb..3ad2a12ea614 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -40,6 +40,7 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + get_dcp_local_seq_lens, get_kv_cache_layout, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -232,6 +233,9 @@ def __init__( self.dcp_world_size = 1 self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = \ + self.parallel_config.dcp_kv_cache_interleave_size + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) @@ -350,8 +354,12 @@ def schedule( - common_attn_metadata.query_start_loc_cpu[:-1] ) dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu - dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + ( - self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size + + dcp_context_kv_lens_cpu = get_dcp_local_seq_lens( + dcp_context_kv_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_kv_cache_interleave_size, ) dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) max_dcp_context_kv_len = dcp_context_kv_lens.max().item() From 72e3a0f6a9e5d2ee0daee8068d4d277f65e2fd78 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Thu, 23 Oct 2025 09:39:12 +0800 Subject: [PATCH 06/10] [bugfix] wrong dim of specific rank Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index c8954e0c2d55..aa0b6e9bdcb2 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1027,4 +1027,4 @@ def get_dcp_local_seq_lens( dcp_kv_cache_interleave_size, ) dcp_local_seq_lens = base + remainder - return dcp_local_seq_lens \ No newline at end of file + return dcp_local_seq_lens.squeeze(1) \ No newline at end of file From 5d7184afb48bce88970aedabd0a38a697b4edaf3 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Thu, 23 Oct 2025 19:23:19 +0800 Subject: [PATCH 07/10] [bugfix] wrong dcp_context_kv_lens_cpu Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/flash_attn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 
3ad2a12ea614..ed4c4d5adb9d 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -360,7 +360,7 @@ def schedule( self.dcp_world_size, self.dcp_rank, self.dcp_kv_cache_interleave_size, - ) + ) + (dcp_context_kv_lens_cpu % self.dcp_world_size == 0) dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) max_dcp_context_kv_len = dcp_context_kv_lens.max().item() From 68623614ed048b6579d6b7a8aaae2e96c0b3d58b Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Fri, 24 Oct 2025 14:52:48 +0800 Subject: [PATCH 08/10] [lint] Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/flash_attn.py | 3 ++- vllm/v1/attention/backends/utils.py | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ed4c4d5adb9d..afc8dcbbcd41 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -233,8 +233,9 @@ def __init__( self.dcp_world_size = 1 self.dcp_rank = 0 - self.dcp_kv_cache_interleave_size = \ + self.dcp_kv_cache_interleave_size = ( self.parallel_config.dcp_kv_cache_interleave_size + ) self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index aa0b6e9bdcb2..9234835e7c20 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -1013,7 +1013,9 @@ def get_dcp_local_seq_lens( ) else: rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32) - seq_lens_tiled = seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1]) + seq_lens_tiled = ( + seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1]) + ) base = ( seq_lens_tiled // dcp_kv_cache_interleave_size @@ -1027,4 +1029,4 @@ def get_dcp_local_seq_lens( dcp_kv_cache_interleave_size, ) dcp_local_seq_lens = base + remainder - return dcp_local_seq_lens.squeeze(1) \ No newline at end of file + return dcp_local_seq_lens.squeeze(1) From 5735b8f1dd5b6cb44e7ac6cfbd473b5a42d3a593 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Tue, 28 Oct 2025 15:49:16 +0800 Subject: [PATCH 09/10] disable DCP with GQA-flash_attn Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/flash_attn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index afc8dcbbcd41..efd5d5c1af44 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -361,7 +361,7 @@ def schedule( self.dcp_world_size, self.dcp_rank, self.dcp_kv_cache_interleave_size, - ) + (dcp_context_kv_lens_cpu % self.dcp_world_size == 0) + ) dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) max_dcp_context_kv_len = dcp_context_kv_lens.max().item() @@ -446,7 +446,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool: class FlashAttentionImpl(AttentionImpl): - can_return_lse_for_decode: bool = True + # TODO(qcs): enable DCP when `flash_attn_varlen_func` supports ctxlen(seqused_k)=0 + can_return_lse_for_decode: bool = False def __init__( self, From 692fdf5454475e5d4c3157faac67fc98be5bb720 Mon Sep 17 00:00:00 2001 From: QiuChunshuo Date: Mon, 3 Nov 2025 10:31:41 +0800 Subject: [PATCH 10/10] [refactor] rename and clean code Signed-off-by: QiuChunshuo --- vllm/v1/attention/backends/mla/common.py | 134 ++++++++++------------- 1 file changed, 59 insertions(+), 75 
deletions(-) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index ec851e6a830f..e75d4c075a83 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -359,10 +359,9 @@ class ChunkedContextMetadata: workspace: torch.Tensor # for mla DCP - cp_chunk_seq_lens: list[list[int]] | None = None - origin_context_lens: list[int] | None = None - cp_cu_seq_lens: torch.Tensor | None = None - chunk_size: int | None = None + local_chunk_seq_lens: list[list[int]] | None = None + local_context_lens_allrank: list[list[int]] | None = None + local_cu_seq_lens: torch.Tensor | None = None cu_seq_lens_lst: list[list[int]] | None = None block_table: torch.Tensor @@ -784,15 +783,6 @@ def build( reqs_start = num_decodes # prefill_start context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] - # Note(hc): The context lengths in the perspective of dcp rank0. - cp_context_lens_cpu = ( - torch.ceil( - context_lens_cpu.float() - / (self.dcp_world_size * self.dcp_kv_cache_interleave_size) - ).int() - * self.dcp_kv_cache_interleave_size - ) - origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = ( @@ -848,32 +838,52 @@ def build( ) if self.dcp_world_size > 1: + local_context_lens_allrank = get_dcp_local_seq_lens( + context_lens_cpu, + self.dcp_world_size, + None, + self.dcp_kv_cache_interleave_size, + ) + # Note(qcs): The context lengths in the perspective of dcp rank0 + # padded to `dcp_kv_cache_interleave_size`. + local_context_lens_cpu = ( + torch.ceil( + context_lens_cpu.float() + / (self.dcp_world_size * self.dcp_kv_cache_interleave_size) + ).int() + * self.dcp_kv_cache_interleave_size + ) # Note(hc): The above max_context_chunk already enforces # block_size alignment, DCP just need the block_size can # be divisible by dcp_world_size, because DCP use # cp_gather_cache which not require `cp_chunk_starts` # aligned to page_size. 
assert max_context_chunk % self.dcp_world_size == 0 - cp_max_context_chunk = max_context_chunk // self.dcp_world_size - cp_chunk_starts = ( + local_max_context_chunk = ( + max_context_chunk + // (self.dcp_world_size * self.dcp_kv_cache_interleave_size) + ) * self.dcp_kv_cache_interleave_size + local_chunk_starts = ( torch.arange(num_chunks, dtype=torch.int32) .unsqueeze(1) .expand(-1, num_prefills) - * cp_max_context_chunk + * local_max_context_chunk ) - cp_chunk_ends = torch.min( - cp_context_lens_cpu.unsqueeze(0), - cp_chunk_starts + cp_max_context_chunk, + local_chunk_ends = torch.min( + local_context_lens_cpu.unsqueeze(0), + local_chunk_starts + local_max_context_chunk, ) - cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0) + local_chunk_seq_lens = ( + local_chunk_ends - local_chunk_starts + ).clamp(min=0) - cp_cu_seq_lens_cpu = torch.zeros( + local_cu_chunk_seq_lens_cpu = torch.zeros( num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True ) torch.cumsum( - cp_chunk_seq_lens, + local_chunk_seq_lens, dim=1, - out=cp_cu_seq_lens_cpu[:, 1:], + out=local_cu_chunk_seq_lens_cpu[:, 1:], dtype=torch.int32, ) @@ -885,15 +895,16 @@ def build( if self.dcp_world_size > 1: chunked_context_metadata = chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=cp_chunk_starts.to(device, non_blocking=True), - seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), + starts=local_chunk_starts.to(device, non_blocking=True), + seq_tot=local_chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), seq_lens=chunk_seq_lens, workspace=self.chunked_prefill_workspace, - cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), - origin_context_lens=origin_context_lens, - cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True), - chunk_size=max_context_chunk, + local_chunk_seq_lens=local_chunk_seq_lens.tolist(), + local_context_lens_allrank=local_context_lens_allrank.tolist(), + local_cu_seq_lens=local_cu_chunk_seq_lens_cpu.to( + device, non_blocking=True + ), cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), ) else: @@ -970,70 +981,48 @@ def build( def reorg_kvcache( allgatered_kv_c_normed: torch.Tensor, allgatered_k_pe: torch.Tensor, - cp_chunk_seq_lens_lst: list[int], - origin_context_lens: list[int], - cp_world_size: int, + local_chunk_seq_lens_lst: list[int], + local_context_lens_allrank: list[list[int]], sum_seq_len: int, max_seq_len: int, - chunk_size: int, - chunk_idx: int, toks: int, - interleave_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ reorg kvcache after cp local gather to tp layout for attn kernel. Args: - cp_chunk_seq_lens_lst: chunk context lengths under CP. - origin_context_lens: origin full context lengths under CP. - cp_world_size: CP size. + local_chunk_seq_lens_lst: local chunk context lengths + under current CP rank. + local_context_lens_allrank: local context lengths on each CP rank. sum_seq_len: the sum of cp_chunk_seq_lens_lst. max_seq_len: the max value of cp_chunk_seq_lens_lst. - chunk_size: equals to max_context_chunk from - chunked_context_metadata building. - chunk_idx: chunk idx of chunked_prefill. toks: the number of tokens for local gather cache. - interleave_size: Interleave size of kv_cache storage. 
""" kv_c_segments = [] k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 - local_context_lens_allrank = get_dcp_local_seq_lens( - torch.Tensor(origin_context_lens), - cp_world_size, - None, - interleave_size, - ) - # print(origin_context_lens, local_context_lens_allrank) - for cp_chunk_seq_len, origin_context_len, local_context_lens in zip( - cp_chunk_seq_lens_lst, origin_context_lens, local_context_lens_allrank + for local_chunk_seq_len, local_context_lens in zip( + local_chunk_seq_lens_lst, local_context_lens_allrank ): - chunk_context_len = chunk_size - if cp_chunk_seq_len != 0: - chunk_context_len = min( - chunk_context_len, origin_context_len - chunk_size * chunk_idx - ) - cur_seq_len = 0 - for rank in range(cp_world_size): - real_cp_chunk_seq_len = local_context_lens[rank] - if real_cp_chunk_seq_len != 0: + for rank, local_context_len in enumerate(local_context_lens): + if local_context_len != 0: kv_c_segment = allgatered_kv_c_normed[ rank * toks + src_token_idx : rank * toks + src_token_idx - + real_cp_chunk_seq_len + + local_context_len ] k_pe_segment = allgatered_k_pe[ rank * toks + src_token_idx : rank * toks + src_token_idx - + real_cp_chunk_seq_len + + local_context_len ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) - cur_seq_len += real_cp_chunk_seq_len + cur_seq_len += local_context_len max_seq_len_check = max(max_seq_len_check, cur_seq_len) - src_token_idx += cp_chunk_seq_len + src_token_idx += local_chunk_seq_len reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) reorganized_k_pe = torch.cat(k_pe_segments, dim=0) assert reorganized_kv_c_normed.shape[0] == sum_seq_len @@ -1591,10 +1580,9 @@ def _context_parallel_compute_prefill_context( assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill assert prefill_metadata.chunked_context is not None - assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None - assert prefill_metadata.chunked_context.origin_context_lens is not None - assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None - assert prefill_metadata.chunked_context.chunk_size is not None + assert prefill_metadata.chunked_context.local_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.local_context_lens_allrank is not None + assert prefill_metadata.chunked_context.local_cu_seq_lens is not None assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None output = None @@ -1607,7 +1595,7 @@ def _context_parallel_compute_prefill_context( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_table, - cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i], + cu_seq_lens=prefill_metadata.chunked_context.local_cu_seq_lens[i], batch_size=attn_metadata.num_prefills, seq_starts=prefill_metadata.chunked_context.starts[i], ) @@ -1637,17 +1625,13 @@ def _context_parallel_compute_prefill_context( kv_c_normed, k_pe = reorg_kvcache( allgatered_kv_c_normed, allgatered_k_pe, - cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ + local_chunk_seq_lens_lst=prefill_metadata.chunked_context.local_chunk_seq_lens[ i ], - origin_context_lens=prefill_metadata.chunked_context.origin_context_lens, - cp_world_size=dcp_world_size, + local_context_lens_allrank=prefill_metadata.chunked_context.local_context_lens_allrank, sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], - chunk_size=prefill_metadata.chunked_context.chunk_size, - chunk_idx=i, 
toks=toks, - interleave_size=self.dcp_kv_cache_interleave_size, ) kv_nope = self.kv_b_proj(kv_c_normed)[0].view(