diff --git a/tests/distributed/test_context_parallel.py b/tests/distributed/test_context_parallel.py index 5495640af07e..7f8e77a75621 100644 --- a/tests/distributed/test_context_parallel.py +++ b/tests/distributed/test_context_parallel.py @@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple): tp_size: int pp_size: int dcp_size: int + dcp_kv_cache_interleave_size: int eager_mode: bool chunked_prefill: bool @@ -52,6 +53,7 @@ def detailed( tp_base: int = 4, pp_base: int = 1, dcp_base: int = 1, + dcp_kv_cache_interleave_size: int = 1, multi_node_only: bool = False, runner: RunnerOption = "auto", load_format: str | None = None, @@ -66,6 +68,7 @@ def detailed( tp_size=tp_base, pp_size=pp_multiplier * pp_base, dcp_size=int(dcp_multiplier * tp_base), + dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size, eager_mode=eager_mode_val, chunked_prefill=chunked_prefill_val, ) @@ -108,6 +111,7 @@ def _compare_cp_with_tp( tp_size, pp_size, dcp_size, + dcp_kv_cache_interleave_size, eager_mode, chunked_prefill, ) = parallel_setup @@ -180,6 +184,8 @@ def _compare_cp_with_tp( str(pp_size), "--decode-context-parallel-size", str(dcp_size), + "--dcp-kv-cache-interleave-size", + str(dcp_kv_cache_interleave_size), "--distributed-executor-backend", distributed_backend, ] @@ -207,6 +213,7 @@ def _compare_cp_with_tp( "deepseek-ai/DeepSeek-V2-Lite-Chat": [ CPTestSettings.detailed(), CPTestSettings.detailed(tp_base=2), + CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64), ], "bigcode/gpt_bigcode-santacoder": [ CPTestSettings.detailed(), diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index b6b7ecd2552a..75fdcb8f48b2 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -53,6 +53,7 @@ def _correct_attn_cp_out_kernel( lse = tl.load(lses_ptr + lse_offsets) lse = tl.where((lse != lse) | (lse == float("inf")), -float("inf"), lse) lse_max = tl.max(lse, axis=0) + lse_max = tl.where(lse_max == -float("inf"), 0, lse_max) lse -= lse_max lse_exp = tl.exp(lse) lse_acc = tl.sum(lse_exp, axis=0) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index b7ef0fef6833..6ee1270df569 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -223,6 +223,17 @@ class is dynamically inherited by the worker class. This is used to inject not change by dcp, it simply reuse the GPUs of TP group, and tp_size needs to be divisible by dcp_size.""" + dcp_kv_cache_interleave_size: int = 1 + """Interleave size for KV cache storage when using dcp or cp > 1: + store interleave_size tokens on (d)cp rank i, + then store the next interleave_size tokens on (d)cp rank i+1. + interleave_size=1: token-level alignment, token i is stored on rank i % (d)cp_size. + interleave_size=block_size: block-level alignment, block j on rank i is filled first, + and tokens then go to block j on rank i+1 once it is full. + block_size must be greater than or equal to, and divisible by, + dcp_kv_cache_interleave_size. + """ + _api_process_count: int = Field(default=1, gt=0) """ The number of API processes initialized.
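For intuition, here is a minimal standalone sketch of the layout described by the new `dcp_kv_cache_interleave_size` docstring. The helper name `dcp_rank_and_local_offset` is illustrative and not part of the patch, but the arithmetic is the same as what the patch applies to `virtual_block_offsets` in `vllm/v1/worker/block_table.py` further below.

```python
def dcp_rank_and_local_offset(
    position: int, dcp_world_size: int, interleave_size: int
) -> tuple[int, int]:
    """Map a global token position to the (d)cp rank that stores it and to
    the token's index among the tokens held by that rank."""
    rank = (position // interleave_size) % dcp_world_size
    local_offset = (
        position // (dcp_world_size * interleave_size) * interleave_size
        + position % interleave_size
    )
    return rank, local_offset


# interleave_size=1: token-level round-robin, token i lands on rank i % dcp_world_size.
assert dcp_rank_and_local_offset(5, dcp_world_size=4, interleave_size=1) == (1, 1)
# interleave_size=64: runs of 64 consecutive tokens stay on one rank before moving on.
assert dcp_rank_and_local_offset(130, dcp_world_size=4, interleave_size=64) == (2, 2)
```

The same knob is exposed on the CLI as `--dcp-kv-cache-interleave-size` (see the `EngineArgs` changes below), subject to the new check that `block_size` is a multiple of it.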
diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index dabd06c32054..544484e3cf9f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -480,6 +480,23 @@ def __post_init__(self): ) current_platform.check_and_update_config(self) + assert ( + self.parallel_config.dcp_kv_cache_interleave_size + <= self.cache_config.block_size + and self.cache_config.block_size + % self.parallel_config.dcp_kv_cache_interleave_size + == 0 + ), ( + f"Block_size({self.cache_config.block_size}) should be " + "greater than or equal to and divisible by dcp_kv_cache_interleave_size " + f"({self.parallel_config.dcp_kv_cache_interleave_size})." + ) + + assert ( + self.parallel_config.dcp_kv_cache_interleave_size == 1 + or self.speculative_config is None + ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now." + # Do this after all the updates to compilation_config.mode if ( envs.VLLM_USE_V1 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 801c30dc9478..9fbb2d70ad6e 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -362,6 +362,7 @@ class EngineArgs: pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size + dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: int | None = None data_parallel_start_rank: int | None = None @@ -717,6 +718,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "-dcp", **parallel_kwargs["decode_context_parallel_size"], ) + parallel_group.add_argument( + "--dcp-kv-cache-interleave-size", + **parallel_kwargs["dcp_kv_cache_interleave_size"], + ) parallel_group.add_argument( "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] ) @@ -1482,6 +1487,7 @@ def create_engine_config( worker_cls=self.worker_cls, worker_extension_cls=self.worker_extension_cls, decode_context_parallel_size=self.decode_context_parallel_size, + dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size, _api_process_count=self._api_process_count, _api_process_rank=self._api_process_rank, ) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 9e0c125d9edb..ae06928ba260 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -40,6 +40,7 @@ AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, + get_dcp_local_seq_lens, get_kv_cache_layout, ) from vllm.v1.kv_cache_interface import AttentionSpec @@ -232,6 +233,10 @@ def __init__( self.dcp_world_size = 1 self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = ( + self.parallel_config.dcp_kv_cache_interleave_size + ) + self.use_full_cuda_graph = ( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) @@ -350,8 +355,12 @@ def schedule( - common_attn_metadata.query_start_loc_cpu[:-1] ) dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu - dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + ( - self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size + + dcp_context_kv_lens_cpu = get_dcp_local_seq_lens( + dcp_context_kv_lens_cpu, + self.dcp_world_size, + self.dcp_rank, + self.dcp_kv_cache_interleave_size, ) dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device) max_dcp_context_kv_len = dcp_context_kv_lens.max().item() diff --git 
a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index f7e6f12363ad..e75d4c075a83 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -222,6 +222,7 @@ from vllm.v1.attention.backends.utils import ( AttentionMetadataBuilder, CommonAttentionMetadata, + get_dcp_local_seq_lens, get_per_layer_parameters, infer_global_hyperparameters, split_decodes_and_prefills, @@ -358,10 +359,9 @@ class ChunkedContextMetadata: workspace: torch.Tensor # for mla DCP - cp_chunk_seq_lens: list[list[int]] | None = None - origin_context_lens: list[int] | None = None - cp_cu_seq_lens: torch.Tensor | None = None - chunk_size: int | None = None + local_chunk_seq_lens: list[list[int]] | None = None + local_context_lens_allrank: list[list[int]] | None = None + local_cu_seq_lens: torch.Tensor | None = None cu_seq_lens_lst: list[list[int]] | None = None block_table: torch.Tensor @@ -554,6 +554,7 @@ def __init__( # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = parallel_config.dcp_kv_cache_interleave_size # Don't try to access the runner on AMD if self.aot_schedule: @@ -774,15 +775,6 @@ def build( ) ) - # Note(hc): update seq_lens of decode reqs under DCP. - if self.dcp_world_size > 1: - assert dcp_local_seq_lens is not None - dcp_local_seq_lens[:num_decodes] = seq_lens[ - :num_decodes - ] // self.dcp_world_size + ( - self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size - ) - assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -791,11 +783,6 @@ def build( reqs_start = num_decodes # prefill_start context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] - # Note(hc): The context lengths in the perspective of dcp rank0. - cp_context_lens_cpu = torch.ceil( - context_lens_cpu.float() / self.dcp_world_size - ).int() - origin_context_lens = context_lens_cpu.tolist() max_context_len_cpu = context_lens_cpu.max().item() num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() prefill_query_start_loc = ( @@ -851,32 +838,52 @@ def build( ) if self.dcp_world_size > 1: + local_context_lens_allrank = get_dcp_local_seq_lens( + context_lens_cpu, + self.dcp_world_size, + None, + self.dcp_kv_cache_interleave_size, + ) + # Note(qcs): The context lengths in the perspective of dcp rank0 + # padded to `dcp_kv_cache_interleave_size`. + local_context_lens_cpu = ( + torch.ceil( + context_lens_cpu.float() + / (self.dcp_world_size * self.dcp_kv_cache_interleave_size) + ).int() + * self.dcp_kv_cache_interleave_size + ) # Note(hc): The above max_context_chunk already enforces # block_size alignment, DCP just need the block_size can # be divisible by dcp_world_size, because DCP use # cp_gather_cache which not require `cp_chunk_starts` # aligned to page_size. 
assert max_context_chunk % self.dcp_world_size == 0 - cp_max_context_chunk = max_context_chunk // self.dcp_world_size - cp_chunk_starts = ( + local_max_context_chunk = ( + max_context_chunk + // (self.dcp_world_size * self.dcp_kv_cache_interleave_size) + ) * self.dcp_kv_cache_interleave_size + local_chunk_starts = ( torch.arange(num_chunks, dtype=torch.int32) .unsqueeze(1) .expand(-1, num_prefills) - * cp_max_context_chunk + * local_max_context_chunk ) - cp_chunk_ends = torch.min( - cp_context_lens_cpu.unsqueeze(0), - cp_chunk_starts + cp_max_context_chunk, + local_chunk_ends = torch.min( + local_context_lens_cpu.unsqueeze(0), + local_chunk_starts + local_max_context_chunk, ) - cp_chunk_seq_lens = (cp_chunk_ends - cp_chunk_starts).clamp(min=0) + local_chunk_seq_lens = ( + local_chunk_ends - local_chunk_starts + ).clamp(min=0) - cp_cu_seq_lens_cpu = torch.zeros( + local_cu_chunk_seq_lens_cpu = torch.zeros( num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True ) torch.cumsum( - cp_chunk_seq_lens, + local_chunk_seq_lens, dim=1, - out=cp_cu_seq_lens_cpu[:, 1:], + out=local_cu_chunk_seq_lens_cpu[:, 1:], dtype=torch.int32, ) @@ -888,15 +895,16 @@ def build( if self.dcp_world_size > 1: chunked_context_metadata = chunked_context_metadata_cls( cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=cp_chunk_starts.to(device, non_blocking=True), - seq_tot=cp_chunk_seq_lens.sum(dim=1).tolist(), + starts=local_chunk_starts.to(device, non_blocking=True), + seq_tot=local_chunk_seq_lens.sum(dim=1).tolist(), max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), seq_lens=chunk_seq_lens, workspace=self.chunked_prefill_workspace, - cp_chunk_seq_lens=cp_chunk_seq_lens.tolist(), - origin_context_lens=origin_context_lens, - cp_cu_seq_lens=cp_cu_seq_lens_cpu.to(device, non_blocking=True), - chunk_size=max_context_chunk, + local_chunk_seq_lens=local_chunk_seq_lens.tolist(), + local_context_lens_allrank=local_context_lens_allrank.tolist(), + local_cu_seq_lens=local_cu_chunk_seq_lens_cpu.to( + device, non_blocking=True + ), cu_seq_lens_lst=cu_seq_lens_cpu.tolist(), ) else: @@ -973,64 +981,48 @@ def build( def reorg_kvcache( allgatered_kv_c_normed: torch.Tensor, allgatered_k_pe: torch.Tensor, - cp_chunk_seq_lens_lst: list[int], - origin_context_lens: list[int], - cp_world_size: int, + local_chunk_seq_lens_lst: list[int], + local_context_lens_allrank: list[list[int]], sum_seq_len: int, max_seq_len: int, - chunk_size: int, - chunk_idx: int, toks: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ reorg kvcache after cp local gather to tp layout for attn kernel. Args: - cp_chunk_seq_lens_lst: chunk context lengths under CP. - origin_context_lens: origin full context lengths under CP. - cp_world_size: CP size. + local_chunk_seq_lens_lst: local chunk context lengths + under current CP rank. + local_context_lens_allrank: local context lengths on each CP rank. sum_seq_len: the sum of cp_chunk_seq_lens_lst. max_seq_len: the max value of cp_chunk_seq_lens_lst. - chunk_size: equals to max_context_chunk from - chunked_context_metadata building. - chunk_idx: chunk idx of chunked_prefill. toks: the number of tokens for local gather cache. 
""" kv_c_segments = [] k_pe_segments = [] src_token_idx = 0 max_seq_len_check = 0 - for cp_chunk_seq_len, origin_context_len in zip( - cp_chunk_seq_lens_lst, origin_context_lens + for local_chunk_seq_len, local_context_lens in zip( + local_chunk_seq_lens_lst, local_context_lens_allrank ): - chunk_context_len = chunk_size - if cp_chunk_seq_len != 0: - chunk_context_len = min( - chunk_context_len, origin_context_len - chunk_size * chunk_idx - ) - cp_target_rank = (chunk_context_len - 1) % cp_world_size cur_seq_len = 0 - for rank in range(cp_world_size): - if rank > cp_target_rank and cp_chunk_seq_len: - real_cp_chunk_seq_len = cp_chunk_seq_len - 1 - else: - real_cp_chunk_seq_len = cp_chunk_seq_len - if real_cp_chunk_seq_len: + for rank, local_context_len in enumerate(local_context_lens): + if local_context_len != 0: kv_c_segment = allgatered_kv_c_normed[ rank * toks + src_token_idx : rank * toks + src_token_idx - + real_cp_chunk_seq_len + + local_context_len ] k_pe_segment = allgatered_k_pe[ rank * toks + src_token_idx : rank * toks + src_token_idx - + real_cp_chunk_seq_len + + local_context_len ] kv_c_segments.append(kv_c_segment) k_pe_segments.append(k_pe_segment) - cur_seq_len += real_cp_chunk_seq_len + cur_seq_len += local_context_len max_seq_len_check = max(max_seq_len_check, cur_seq_len) - src_token_idx += cp_chunk_seq_len + src_token_idx += local_chunk_seq_len reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0) reorganized_k_pe = torch.cat(k_pe_segments, dim=0) assert reorganized_kv_c_normed.shape[0] == sum_seq_len @@ -1263,6 +1255,9 @@ def __init__(self, *args, **kwargs) -> None: get_current_vllm_config() ) ) + self.dcp_kv_cache_interleave_size: int = ( + get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size + ) def _flash_attn_varlen_diff_headdims( self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs @@ -1585,10 +1580,9 @@ def _context_parallel_compute_prefill_context( assert attn_metadata.prefill is not None prefill_metadata = attn_metadata.prefill assert prefill_metadata.chunked_context is not None - assert prefill_metadata.chunked_context.cp_chunk_seq_lens is not None - assert prefill_metadata.chunked_context.origin_context_lens is not None - assert prefill_metadata.chunked_context.cp_cu_seq_lens is not None - assert prefill_metadata.chunked_context.chunk_size is not None + assert prefill_metadata.chunked_context.local_chunk_seq_lens is not None + assert prefill_metadata.chunked_context.local_context_lens_allrank is not None + assert prefill_metadata.chunked_context.local_cu_seq_lens is not None assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None output = None @@ -1601,7 +1595,7 @@ def _context_parallel_compute_prefill_context( src_cache=kv_c_and_k_pe_cache, dst=workspace, block_table=prefill_metadata.block_table, - cu_seq_lens=prefill_metadata.chunked_context.cp_cu_seq_lens[i], + cu_seq_lens=prefill_metadata.chunked_context.local_cu_seq_lens[i], batch_size=attn_metadata.num_prefills, seq_starts=prefill_metadata.chunked_context.starts[i], ) @@ -1631,15 +1625,12 @@ def _context_parallel_compute_prefill_context( kv_c_normed, k_pe = reorg_kvcache( allgatered_kv_c_normed, allgatered_k_pe, - cp_chunk_seq_lens_lst=prefill_metadata.chunked_context.cp_chunk_seq_lens[ + local_chunk_seq_lens_lst=prefill_metadata.chunked_context.local_chunk_seq_lens[ i ], - origin_context_lens=prefill_metadata.chunked_context.origin_context_lens, - cp_world_size=dcp_world_size, + 
local_context_lens_allrank=prefill_metadata.chunked_context.local_context_lens_allrank, sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1], max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i], - chunk_size=prefill_metadata.chunked_context.chunk_size, - chunk_idx=i, toks=toks, ) diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index cb5855548098..9234835e7c20 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -992,3 +992,41 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor): nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore return nums_dict, batch_ptr, token_chunk_offset_ptr + + +def get_dcp_local_seq_lens( + seq_lens: torch.Tensor, + dcp_world_size: int = 1, + dcp_rank: int | None = None, + dcp_kv_cache_interleave_size: int = 1, +) -> torch.Tensor: + """When using dcp, the kv_cache size stored on each rank may differ, + so use this function to compute the per-rank split of decode seq_lens. + Only dcp is handled for now; the cp case can be built on top of this. + """ + num_requests = seq_lens.size(0) + if dcp_rank is None: + rank_offsets = ( + torch.arange(dcp_world_size, dtype=torch.int32) + .unsqueeze(0) + .repeat(num_requests, 1) + ) + else: + rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32) + seq_lens_tiled = ( + seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1]) + ) + base = ( + seq_lens_tiled + // dcp_kv_cache_interleave_size + // dcp_world_size + * dcp_kv_cache_interleave_size + ) + remainder = seq_lens_tiled - base * dcp_world_size + remainder = torch.clip( + remainder - rank_offsets * dcp_kv_cache_interleave_size, + 0, + dcp_kv_cache_interleave_size, + ) + dcp_local_seq_lens = base + remainder + return dcp_local_seq_lens.squeeze(1) diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 9bf06d51609f..a86042120783 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -22,6 +22,7 @@ def __init__( pin_memory: bool, device: torch.device, kernel_block_size: int, + dcp_kv_cache_interleave_size: int, ): """ Args: @@ -86,6 +87,7 @@ def __init__( # DCP might not be initialized in testing self.dcp_world_size = 1 self.dcp_rank = 0 + self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size def append_row( self, @@ -144,9 +146,19 @@ def compute_slot_mapping( # Use virtual_block_size for mask calculation, which marks local # tokens.
virtual_block_offsets = positions % virtual_block_size - mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank + mask = ( + virtual_block_offsets + // self.dcp_kv_cache_interleave_size + % self.dcp_world_size + == self.dcp_rank + ) # Calculate local block_offsets - block_offsets = virtual_block_offsets // self.dcp_world_size + block_offsets = ( + virtual_block_offsets + // (self.dcp_world_size * self.dcp_kv_cache_interleave_size) + * self.dcp_kv_cache_interleave_size + + virtual_block_offsets % self.dcp_kv_cache_interleave_size + ) # Calculate slot_mapping slot_mapping = block_numbers * self.block_size + block_offsets # Write final slots, use -1 for not-local @@ -234,6 +246,7 @@ def __init__( block_sizes: list[int], kernel_block_sizes: list[int], num_speculative_tokens: int = 0, + dcp_kv_cache_interleave_size: int = 1, ) -> None: # Note(hc): each dcp rank only store # (max_model_len//dcp_world_size) tokens in kvcache, @@ -263,6 +276,7 @@ def __init__( pin_memory, device, kernel_block_size, + dcp_kv_cache_interleave_size, ) for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes) ] diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index b8751546f767..101dabcadbb7 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -83,6 +83,7 @@ def __init__( is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, + dcp_kv_cache_interleave_size: int = 1, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode @@ -135,6 +136,7 @@ def __init__( block_sizes=block_sizes, kernel_block_sizes=kernel_block_sizes, num_speculative_tokens=num_speculative_tokens, + dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size, ) # Sampling-related. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d995a609318c..41516b9fc12b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -35,6 +35,7 @@ from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( + get_dcp_group, get_pp_group, get_tp_group, graph_capture, @@ -92,6 +93,7 @@ AttentionMetadataBuilder, CommonAttentionMetadata, create_fast_prefill_custom_backend, + get_dcp_local_seq_lens, reorder_batch_to_split_decodes_and_prefills, split_attn_metadata, ) @@ -256,6 +258,7 @@ def __init__( self.is_multimodal_pruning_enabled = False self.max_model_len = model_config.max_model_len self.dcp_world_size = self.parallel_config.decode_context_parallel_size + self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_reqs = scheduler_config.max_num_seqs @@ -372,6 +375,7 @@ def __init__( # uses output token ids so we set this conservatively. logitsprocs_need_output_token_ids=bool(custom_logitsprocs), is_pooling_model=self.is_pooling_model, + dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -1274,6 +1278,16 @@ def _prepare_inputs( logits_indices ) + # update seq_lens of decode reqs under DCP. 
+ if self.dcp_world_size > 1: + self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( + self.seq_lens.cpu[:num_reqs], + self.dcp_world_size, + self.dcp_rank, + self.parallel_config.dcp_kv_cache_interleave_size, + ) + self.dcp_local_seq_lens.copy_to_gpu(num_reqs) + attn_metadata: PerLayerAttnMetadata = {} if ubatch_slices is not None: attn_metadata = [dict() for _ in range(len(ubatch_slices))]
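To see what the new `get_dcp_local_seq_lens` helper returns, here is a small usage sketch (assuming the patched tree is importable); the expected values follow from the base/remainder arithmetic in `vllm/v1/attention/backends/utils.py` above.

```python
import torch

from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens

seq_lens = torch.tensor([7, 130])

# interleave_size=1: tokens are distributed round-robin across the 4 dcp ranks,
# so 7 -> [2, 2, 2, 1] and 130 -> [33, 33, 32, 32].
print(get_dcp_local_seq_lens(seq_lens, dcp_world_size=4,
                             dcp_kv_cache_interleave_size=1))

# interleave_size=64: whole 64-token runs fill one rank before moving to the next,
# so 7 -> [7, 0, 0, 0] and 130 -> [64, 64, 2, 0].
print(get_dcp_local_seq_lens(seq_lens, dcp_world_size=4,
                             dcp_kv_cache_interleave_size=64))

# Passing dcp_rank selects a single rank's column, as the flash-attn builder and
# gpu_model_runner do for the local rank; the MLA builder passes None to get all ranks.
print(get_dcp_local_seq_lens(seq_lens, dcp_world_size=4, dcp_rank=0,
                             dcp_kv_cache_interleave_size=64))  # -> [7, 64] for rank 0
```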