4 changes: 4 additions & 0 deletions vllm/config/__init__.py
@@ -617,6 +617,10 @@ def __post_init__(self):
"Modify KVEventsConfig.enable_kv_cache_events"
"to True to enable.")
current_platform.check_and_update_config(self)
assert self.parallel_config.cp_kv_cache_interleave_size <= self.cache_config.block_size and \
self.cache_config.block_size % self.parallel_config.cp_kv_cache_interleave_size == 0, \
f"block_size ({self.cache_config.block_size}) must be greater than or equal to and " \
f"divisible by cp_kv_cache_interleave_size ({self.parallel_config.cp_kv_cache_interleave_size})."

# final check of cudagraph mode after platform-specific update
if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
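An illustrative consequence of the check above (values assumed for the sketch): the admissible interleave sizes are exactly the divisors of the block size.

block_size = 16  # assumed example value, not from the PR
valid_sizes = [s for s in range(1, block_size + 1) if block_size % s == 0]
# valid_sizes == [1, 2, 4, 8, 16]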
9 changes: 9 additions & 0 deletions vllm/config/parallel.py
@@ -195,6 +195,15 @@ class is dynamically inherited by the worker class. This is used to inject
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size."""

cp_kv_cache_interleave_size: int = 1
"""Interleave size of the KV cache storage when using DCP or CP > 1:
store interleave_size tokens on (d)cp rank i, then the next interleave_size
tokens on (d)cp rank i+1.
interleave_size=1: token-level alignment; token i is stored on rank i % (d)cp_size.
interleave_size=block_size: block-level alignment; tokens go to block j of
rank i+1 only after block j of rank i is full.
block_size must be greater than or equal to and divisible by
cp_kv_cache_interleave_size.
"""

_api_process_count: int = 1
"""
The number of API processes initialized.
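A minimal sketch (not part of the diff) of the mapping the cp_kv_cache_interleave_size docstring above describes, assuming the rank of the token at position p is (p // interleave_size) % world_size:

def rank_of_token(position: int, interleave_size: int, world_size: int) -> int:
    # Token positions are assigned to (d)cp ranks in chunks of interleave_size.
    return position // interleave_size % world_size

# interleave_size=1 (token-level): 8 tokens on 4 ranks -> [0, 1, 2, 3, 0, 1, 2, 3]
print([rank_of_token(p, 1, 4) for p in range(8)])
# interleave_size=4 (block-level when block_size=4): -> [0, 0, 0, 0, 1, 1, 1, 1]
print([rank_of_token(p, 4, 4) for p in range(8)])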
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
@@ -318,6 +318,7 @@ class EngineArgs:
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = \
ParallelConfig.decode_context_parallel_size
cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
context_parallel_size: int = ParallelConfig.context_parallel_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: Optional[int] = None
@@ -654,6 +655,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parallel_group.add_argument(
"--decode-context-parallel-size", "-dcp",
**parallel_kwargs["decode_context_parallel_size"])
parallel_group.add_argument(
"--cp-kv-cache-interleave-size",
**parallel_kwargs["cp_kv_cache_interleave_size"])
parallel_group.add_argument(
"--context-parallel-size", "-cp",
**parallel_kwargs["context_parallel_size"])
@@ -1338,6 +1342,7 @@ def create_engine_config(
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)
25 changes: 25 additions & 0 deletions vllm/utils/__init__.py
@@ -3470,3 +3470,28 @@ def length_from_prompt_token_ids_or_embeds(
f" prompt_token_ids={prompt_token_len}"
f" prompt_embeds={prompt_embeds_len}")
return prompt_token_len


def get_split_computed_tokens(
num_computed_tokens: np.ndarray,
cp_world_size: int = 1,
dcp_world_size: int = 1,
cp_kv_cache_interleave_size: int = 1
):
"""While using cp or dcp, kv_cache size stored on each rank may be different,
use this function to calculate split num_computed_tokens on each cp/dcp rank.
"""
num_requests = len(num_computed_tokens)
total_world_size = cp_world_size * dcp_world_size
num_computed_tokens_tiled = np.tile(
num_computed_tokens[:, np.newaxis], (1, total_world_size)
)
rank_offsets = np.tile(np.arange(total_world_size)[np.newaxis, :], (num_requests, 1))
base = num_computed_tokens_tiled // cp_kv_cache_interleave_size // total_world_size \
* cp_kv_cache_interleave_size
remainder = num_computed_tokens_tiled - base * total_world_size
remainder = np.clip(
remainder - rank_offsets * cp_kv_cache_interleave_size, 0, cp_kv_cache_interleave_size
)
num_computed_tokens_of_cp_dcp = base + remainder
return num_computed_tokens_of_cp_dcp.reshape(-1, cp_world_size, dcp_world_size)
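A hypothetical usage of the helper above (the concrete numbers are illustrative only):

import numpy as np

num_computed_tokens = np.array([13, 8])
split = get_split_computed_tokens(
    num_computed_tokens, cp_world_size=1, dcp_world_size=4,
    cp_kv_cache_interleave_size=2)
# split.shape == (2, 1, 4); each request's per-rank counts sum to its total:
# split[0, 0] -> [4, 4, 3, 2] and split[1, 0] -> [2, 2, 2, 2].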
9 changes: 5 additions & 4 deletions vllm/v1/attention/backends/mla/common.py
@@ -668,10 +668,11 @@ def build(self,
decode_threshold=self.reorder_batch_threshold)

# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
seq_lens[:num_decodes] = seq_lens[:num_decodes] \
// self.dcp_world_size + (self.dcp_rank <= \
(seq_lens[:num_decodes] - 1) % self.dcp_world_size)
if self.dcp_world_size > 1 and num_decodes > 0:
seq_lens[:num_decodes] = torch.tensor(
common_attn_metadata.num_computed_tokens_of_cp_dcp[:num_decodes, 0, self.dcp_rank],
dtype=seq_lens.dtype, device=seq_lens.device
)

assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
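A quick sanity check (an assumption, not part of the diff): with cp_kv_cache_interleave_size=1, the per-rank counts now read from num_computed_tokens_of_cp_dcp reproduce the modular formula that the removed lines computed inline.

seq_len, world = 13, 4
old = [seq_len // world + (rank <= (seq_len - 1) % world) for rank in range(world)]
# old == [4, 3, 3, 3], matching
# get_split_computed_tokens(np.array([13]), 1, world, 1)[0, 0].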
1 change: 1 addition & 0 deletions vllm/v1/attention/backends/utils.py
@@ -84,6 +84,7 @@ class CommonAttentionMetadata:
# Needed by custom mask calc for context parallelism
query_positions: Optional[np.ndarray] = None
cp_kv_recover_idx: Optional[torch.Tensor] = None
num_computed_tokens_of_cp_dcp: Optional[np.ndarray] = None

def slice_query_start_locs(
query_start_loc: torch.Tensor,
33 changes: 7 additions & 26 deletions vllm/v1/worker/block_table.py
@@ -78,7 +78,7 @@ def swap_row(self, src: int, tgt: int) -> None:
self.block_table.np[src_tgt] = self.block_table.np[tgt_src]

def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
positions: np.ndarray, cp_kv_cache_interleave_size: int = 1) -> None:
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
@@ -100,10 +100,12 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
# tokens.
virtual_block_offsets = positions % virtual_block_size
self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank
mask = (virtual_block_offsets %
mask = (virtual_block_offsets // cp_kv_cache_interleave_size %
(self.dcp_world_size * self.cp_world_size) == self.current_rank)
# Calculate local block_offsets
block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size)
block_offsets = virtual_block_offsets \
// (self.dcp_world_size * self.cp_world_size * cp_kv_cache_interleave_size) \
* cp_kv_cache_interleave_size + virtual_block_offsets % cp_kv_cache_interleave_size
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
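An illustrative walk-through of the arithmetic above (assumed values: block_size=4, cp_world_size=1, dcp_world_size=2, cp_kv_cache_interleave_size=2, so a virtual block holds 8 tokens):

import numpy as np

world_size, interleave = 2, 2
virtual_block_offsets = np.arange(8)
# Which rank owns each virtual offset: [0 0 1 1 0 0 1 1]
ranks = virtual_block_offsets // interleave % world_size
# Local offset inside the owning rank's block: [0 1 0 1 2 3 2 3]
local = (virtual_block_offsets // (world_size * interleave) * interleave
         + virtual_block_offsets % interleave)
# Each rank therefore fills its local block slots 0, 1, 2, 3 in order.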
@@ -147,27 +149,6 @@ def _make_buffer(self, *size: Union[int, torch.SymInt],
device=self.device,
pin_memory=self.pin_memory)

def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]:
"Splits computed token counts across dcp and sp dimensions for distributed allocation."
num_requests = len(num_computed_tokens)
num_computed_tokens_of_dcp_sp = [[
[0] * self.dcp_world_size for _ in range(self.cp_world_size)
] for _ in range(num_requests)]
total_ranks = self.cp_world_size * self.dcp_world_size
for req_idx in range(num_requests):
total_tokens = num_computed_tokens[req_idx]
if total_tokens <= 0:
continue
base = int(total_tokens) // total_ranks
remainder = int(total_tokens) % total_ranks
for rank_idx in range(total_ranks):
cp_idx = rank_idx // self.dcp_world_size
sp_idx = rank_idx % self.dcp_world_size
num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] = base
if rank_idx < remainder:
num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] += 1
return num_computed_tokens_of_dcp_sp


class MultiGroupBlockTable:
"""The BlockTables for each KV cache group."""
@@ -216,9 +197,9 @@ def swap_row(self, src: int, tgt: int) -> None:
block_table.swap_row(src, tgt)

def compute_slot_mapping(self, req_indices: np.ndarray,
positions: np.ndarray) -> None:
positions: np.ndarray, cp_kv_cache_interleave_size: int = 1) -> None:
for block_table in self.block_tables:
block_table.compute_slot_mapping(req_indices, positions)
block_table.compute_slot_mapping(req_indices, positions, cp_kv_cache_interleave_size)

def commit_block_table(self, num_reqs: int) -> None:
for block_table in self.block_tables:
13 changes: 11 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -58,7 +58,7 @@
GiB_bytes, check_use_alibi, get_dtype_size,
is_pin_memory_available,
length_from_prompt_token_ids_or_embeds, round_up,
supports_dynamo, cdiv)
supports_dynamo, cdiv, get_split_computed_tokens)
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import (
@@ -1123,7 +1123,10 @@ def _prepare_inputs(
output_idx += num_sched

self.input_batch.block_table.compute_slot_mapping(
req_indices_for_slotmapping, positions_np_for_slotmapping)
req_indices_for_slotmapping,
positions_np_for_slotmapping,
cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens_for_slotmapping)

@@ -1305,6 +1308,12 @@ def _prepare_inputs(
encoder_seq_lens=encoder_seq_lens,
query_positions=positions_np,
cp_kv_recover_idx=self.cp_kv_recover_idx,
num_computed_tokens_of_cp_dcp=get_split_computed_tokens(
self.input_batch.num_tokens[:num_reqs],
self.cp_world_size,
self.dcp_world_size,
self.parallel_config.cp_kv_cache_interleave_size,
),
)

if self.speculative_config and \