diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py
index e31a78ba33ba..5be95901d660 100644
--- a/vllm/config/__init__.py
+++ b/vllm/config/__init__.py
@@ -617,6 +617,10 @@ def __post_init__(self):
                     "Modify KVEventsConfig.enable_kv_cache_events"
                     "to True to enable.")
         current_platform.check_and_update_config(self)
+        assert self.parallel_config.cp_kv_cache_interleave_size <= self.cache_config.block_size and \
+            self.cache_config.block_size % self.parallel_config.cp_kv_cache_interleave_size == 0, \
+            f"Block_size({self.cache_config.block_size}) should be greater than or equal to and " \
+            f"divisible by cp_kv_cache_interleave_size({self.parallel_config.cp_kv_cache_interleave_size})."
 
         # final check of cudagraph mode after platform-specific update
         if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index e92aaa3b4b77..ba223efd442e 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -195,6 +195,15 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""
 
+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using dcp or cp > 1:
+    store interleave_size tokens on (d)cp rank i, then the next interleave_size tokens on rank i+1.
+    interleave_size=1: token-level alignment, token i is stored on rank i % (d)cp_size.
+    interleave_size=block_size: block-level alignment, fill a block on one rank first;
+    a token goes to block j on rank i+1 only after block j on rank i is full.
+    block_size should be greater than or equal to and divisible by cp_kv_cache_interleave_size.
+    """
+
     _api_process_count: int = 1
     """
     The number of API processes initialized.
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index afb384ebdd37..a9a0cd8bf2c9 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -318,6 +318,7 @@ class EngineArgs:
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = \
         ParallelConfig.decode_context_parallel_size
+    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
     context_parallel_size: int = ParallelConfig.context_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: Optional[int] = None
@@ -654,6 +655,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--decode-context-parallel-size", "-dcp",
             **parallel_kwargs["decode_context_parallel_size"])
+        parallel_group.add_argument(
+            "--cp-kv-cache-interleave-size",
+            **parallel_kwargs["cp_kv_cache_interleave_size"])
         parallel_group.add_argument(
             "--context-parallel-size", "-cp",
             **parallel_kwargs["context_parallel_size"])
@@ -1338,6 +1342,7 @@ def create_engine_config(
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
+            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
             _api_process_count=self._api_process_count,
             _api_process_rank=self._api_process_rank,
         )
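For reference, a minimal sketch (not part of the patch) of the layout this flag selects, mirroring the rank-selection arithmetic used in the block_table.py hunk further below. The helper name rank_of_token and the concrete sizes are illustrative assumptions only.

    # Which (d)cp rank stores the token at position `pos` under an interleaved layout.
    # interleave_size=1 gives token-level round-robin; interleave_size=block_size fills
    # a whole block on one rank before moving to the next rank.
    def rank_of_token(pos: int, interleave_size: int, world_size: int) -> int:
        return (pos // interleave_size) % world_size

    assert [rank_of_token(p, 1, 2) for p in range(6)] == [0, 1, 0, 1, 0, 1]
    assert [rank_of_token(p, 2, 2) for p in range(6)] == [0, 0, 1, 1, 0, 0]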
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index 834ec9b1d30b..2f882765eec5 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -3470,3 +3470,28 @@ def length_from_prompt_token_ids_or_embeds(
         f" prompt_token_ids={prompt_token_len}"
         f" prompt_embeds={prompt_embeds_len}")
     return prompt_token_len
+
+
+def get_split_computed_tokens(
+    num_computed_tokens: np.ndarray,
+    cp_world_size: int = 1,
+    dcp_world_size: int = 1,
+    cp_kv_cache_interleave_size: int = 1
+):
+    """When using cp or dcp, the kv_cache size stored on each rank may differ;
+    use this function to compute the per-rank split of num_computed_tokens.
+    """
+    num_requests = len(num_computed_tokens)
+    total_world_size = cp_world_size * dcp_world_size
+    num_computed_tokens_tiled = np.tile(
+        num_computed_tokens[:, np.newaxis], (1, total_world_size)
+    )
+    rank_offsets = np.tile(np.arange(total_world_size)[np.newaxis, :], (num_requests, 1))
+    base = num_computed_tokens_tiled // cp_kv_cache_interleave_size // total_world_size \
+        * cp_kv_cache_interleave_size
+    remainder = num_computed_tokens_tiled - base * total_world_size
+    remainder = np.clip(
+        remainder - rank_offsets * cp_kv_cache_interleave_size, 0, cp_kv_cache_interleave_size
+    )
+    num_computed_tokens_of_cp_dcp = base + remainder
+    return num_computed_tokens_of_cp_dcp.reshape(-1, cp_world_size, dcp_world_size)
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 5b307810de93..652b6197977d 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -668,10 +668,11 @@ def build(self,
             decode_threshold=self.reorder_batch_threshold)
 
         # Note(hc): update seq_lens of decode reqs under DCP.
-        if self.dcp_world_size > 1:
-            seq_lens[:num_decodes] = seq_lens[:num_decodes] \
-                // self.dcp_world_size + (self.dcp_rank <= \
-                (seq_lens[:num_decodes] - 1) % self.dcp_world_size)
+        if self.dcp_world_size > 1 and num_decodes > 0:
+            seq_lens[:num_decodes] = torch.tensor(
+                common_attn_metadata.num_computed_tokens_of_cp_dcp[:num_decodes, 0, self.dcp_rank],
+                dtype=seq_lens.dtype, device=seq_lens.device
+            )
 
         assert num_decodes + num_prefills == num_reqs
         assert num_decode_tokens + num_prefill_tokens == num_tokens
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index ff4e10e82edd..39d5aea659ba 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -84,6 +84,7 @@ class CommonAttentionMetadata:
     # Needed by custom mask calc for context parallelism
     query_positions: Optional[np.ndarray] = None
     cp_kv_recover_idx: Optional[torch.Tensor] = None
+    num_computed_tokens_of_cp_dcp: Optional[np.ndarray] = None
 
 def slice_query_start_locs(
     query_start_loc: torch.Tensor,
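For reference, a standalone sketch (not part of the patch) of the base/remainder arithmetic in the new get_split_computed_tokens helper. The function name split_computed_tokens and the example sizes are illustrative assumptions.

    import numpy as np

    # Split per-request computed-token counts across cp*dcp ranks, mirroring the
    # arithmetic of get_split_computed_tokens in the vllm/utils hunk above.
    def split_computed_tokens(num_computed_tokens, world_size, interleave_size):
        n = np.asarray(num_computed_tokens)[:, np.newaxis]          # (num_requests, 1)
        ranks = np.arange(world_size)[np.newaxis, :]                 # (1, world_size)
        base = n // interleave_size // world_size * interleave_size  # full interleave rounds
        remainder = np.clip(n - base * world_size - ranks * interleave_size,
                            0, interleave_size)                      # last, partial round
        return base + remainder                                      # (num_requests, world_size)

    # 7 computed tokens, 2 ranks, interleave_size=2: tokens 0,1,4,5 sit on rank 0
    # and tokens 2,3,6 on rank 1, so the split is [4, 3].
    assert split_computed_tokens([7], world_size=2, interleave_size=2).tolist() == [[4, 3]]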
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index d25cb699d346..2be5514fb70f 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -78,7 +78,7 @@ def swap_row(self, src: int, tgt: int) -> None:
         self.block_table.np[src_tgt] = self.block_table.np[tgt_src]
 
     def compute_slot_mapping(self, req_indices: np.ndarray,
-                             positions: np.ndarray) -> None:
+                             positions: np.ndarray, cp_kv_cache_interleave_size: int = 1) -> None:
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
         # where K is the max_num_blocks_per_req and the block size is 2.
@@ -100,10 +100,12 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
         # tokens.
         virtual_block_offsets = positions % virtual_block_size
         self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank
-        mask = (virtual_block_offsets %
+        mask = (virtual_block_offsets // cp_kv_cache_interleave_size %
                 (self.dcp_world_size * self.cp_world_size) == self.current_rank)
         # Calculate local block_offsets
-        block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size)
+        block_offsets = virtual_block_offsets \
+            // (self.dcp_world_size * self.cp_world_size * cp_kv_cache_interleave_size) \
+            * cp_kv_cache_interleave_size + virtual_block_offsets % cp_kv_cache_interleave_size
         # Calculate slot_mapping
         slot_mapping = block_numbers * self.block_size + block_offsets
         # Write final slots, use -1 for not-local
@@ -147,27 +149,6 @@ def _make_buffer(self, *size: Union[int, torch.SymInt],
                             device=self.device,
                             pin_memory=self.pin_memory)
 
-    def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]:
-        "Splits computed token counts across dcp and sp dimensions for distributed allocation."
-        num_requests = len(num_computed_tokens)
-        num_computed_tokens_of_dcp_sp = [[
-            [0] * self.dcp_world_size for _ in range(self.cp_world_size)
-        ] for _ in range(num_requests)]
-        total_ranks = self.cp_world_size * self.dcp_world_size
-        for req_idx in range(num_requests):
-            total_tokens = num_computed_tokens[req_idx]
-            if total_tokens <= 0:
-                continue
-            base = int(total_tokens) // total_ranks
-            remainder = int(total_tokens) % total_ranks
-            for rank_idx in range(total_ranks):
-                cp_idx = rank_idx // self.dcp_world_size
-                sp_idx = rank_idx % self.dcp_world_size
-                num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] = base
-                if rank_idx < remainder:
-                    num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] += 1
-        return num_computed_tokens_of_dcp_sp
-
 
 class MultiGroupBlockTable:
     """The BlockTables for each KV cache group."""
@@ -216,9 +197,9 @@ def swap_row(self, src: int, tgt: int) -> None:
             block_table.swap_row(src, tgt)
 
     def compute_slot_mapping(self, req_indices: np.ndarray,
-                             positions: np.ndarray) -> None:
+                             positions: np.ndarray, cp_kv_cache_interleave_size: int = 1) -> None:
         for block_table in self.block_tables:
-            block_table.compute_slot_mapping(req_indices, positions)
+            block_table.compute_slot_mapping(req_indices, positions, cp_kv_cache_interleave_size)
 
     def commit_block_table(self, num_reqs: int) -> None:
         for block_table in self.block_tables:
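For reference, the new offset math in compute_slot_mapping can be checked in isolation with the sketch below (not part of the patch). The sizes block_size=4, a combined cp*dcp world size of 2, and interleave size 2 are illustrative assumptions.

    import numpy as np

    block_size, world_size, interleave = 4, 2, 2
    # One "virtual block" spans block_size * world_size token positions.
    virtual_offsets = np.arange(block_size * world_size)
    for rank in range(world_size):
        # Same mask and local-offset formulas as the block_table.py hunk above.
        mask = (virtual_offsets // interleave) % world_size == rank
        local = (virtual_offsets // (world_size * interleave) * interleave
                 + virtual_offsets % interleave)
        print(rank, virtual_offsets[mask].tolist(), local[mask].tolist())
    # rank 0 stores virtual offsets [0, 1, 4, 5] at local offsets [0, 1, 2, 3];
    # rank 1 stores virtual offsets [2, 3, 6, 7] at local offsets [0, 1, 2, 3].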
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 757cc5e7fccc..8ed7beaa30a0 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -58,7 +58,7 @@
                         GiB_bytes, check_use_alibi, get_dtype_size,
                         is_pin_memory_available,
                         length_from_prompt_token_ids_or_embeds, round_up,
-                        supports_dynamo, cdiv)
+                        supports_dynamo, cdiv, get_split_computed_tokens)
 from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
@@ -1123,7 +1123,10 @@ def _prepare_inputs(
             output_idx += num_sched
 
         self.input_batch.block_table.compute_slot_mapping(
-            req_indices_for_slotmapping, positions_np_for_slotmapping)
+            req_indices_for_slotmapping,
+            positions_np_for_slotmapping,
+            cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
+        )
         self.input_batch.block_table.commit_slot_mapping(
             total_num_scheduled_tokens_for_slotmapping)
 
@@ -1305,6 +1308,12 @@ def _prepare_inputs(
             encoder_seq_lens=encoder_seq_lens,
             query_positions=positions_np,
             cp_kv_recover_idx=self.cp_kv_recover_idx,
+            num_computed_tokens_of_cp_dcp=get_split_computed_tokens(
+                self.input_batch.num_tokens[:num_reqs],
+                self.cp_world_size,
+                self.dcp_world_size,
+                self.parallel_config.cp_kv_cache_interleave_size,
+            ),
         )
 
         if self.speculative_config and \