Commit c9f6cf5

Merge pull request #4 from zhangsicheng5/long_seq_dev
support cp_kv_cache_interleave_size on mla & dcp
2 parents 3f73536 + d0a5654 commit c9f6cf5

File tree

vllm/config/__init__.py
vllm/config/parallel.py
vllm/engine/arg_utils.py
vllm/utils/__init__.py
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/utils.py
vllm/v1/worker/block_table.py
vllm/v1/worker/gpu_model_runner.py

8 files changed: +67 -32 lines changed

vllm/config/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -617,6 +617,10 @@ def __post_init__(self):
                 "Modify KVEventsConfig.enable_kv_cache_events"
                 "to True to enable.")
         current_platform.check_and_update_config(self)
+        assert self.parallel_config.cp_kv_cache_interleave_size <= self.cache_config.block_size and \
+            self.cache_config.block_size % self.parallel_config.cp_kv_cache_interleave_size == 0, \
+            f"Block_size({self.cache_config.block_size}) should be greater than or equal to and " \
+            f"divisible by cp_kv_cache_interleave_size({self.parallel_config.cp_kv_cache_interleave_size})."

         # final check of cudagraph mode after platform-specific update
         if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
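
The assertion enforces that the interleave size never exceeds the KV cache block size and divides it evenly, so every block is made of whole interleave groups. A minimal standalone sketch of the same check, using an illustrative block_size of 16 (not a value from this PR):

    # Illustrative only: which interleave sizes pass the new config check for block_size=16.
    block_size = 16
    for interleave in (1, 2, 4, 8, 16, 3, 32):
        ok = interleave <= block_size and block_size % interleave == 0
        print(interleave, "valid" if ok else "invalid")
    # 1, 2, 4, 8 and 16 are valid; 3 and 32 are rejected.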

vllm/config/parallel.py

Lines changed: 9 additions & 0 deletions
@@ -195,6 +195,15 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""

+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage when using dcp or cp > 1:
+    store interleave_size tokens on (d)cp rank i, then the next interleave_size tokens on rank i+1.
+    interleave_size=1: token-level alignment, token i is stored on rank i % (d)cp_size.
+    interleave_size=block_size: block-level alignment, the block on rank i is filled first,
+    and tokens go to block j on rank i+1 only after block j on rank i is full.
+    block_size must be greater than or equal to and divisible by cp_kv_cache_interleave_size.
+    """
+
     _api_process_count: int = 1
     """
     The number of API processes initialized.
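
The docstring's interleaving rule can be made concrete: token position i lands on (d)cp rank (i // interleave_size) % world_size. A toy sketch with made-up sizes (world_size 2, 8 tokens), not code from this PR:

    # Which (d)cp rank owns each token position under the interleaving rule.
    def owner_rank(pos: int, world_size: int, interleave_size: int) -> int:
        return (pos // interleave_size) % world_size

    positions = range(8)
    # interleave_size=1 (token-level): [0, 1, 0, 1, 0, 1, 0, 1]
    print([owner_rank(p, 2, 1) for p in positions])
    # interleave_size=4 (block-level, if block_size=4): [0, 0, 0, 0, 1, 1, 1, 1]
    print([owner_rank(p, 2, 4) for p in positions])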

vllm/engine/arg_utils.py

Lines changed: 5 additions & 0 deletions
@@ -318,6 +318,7 @@ class EngineArgs:
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = \
         ParallelConfig.decode_context_parallel_size
+    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
     context_parallel_size: int = ParallelConfig.context_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: Optional[int] = None
@@ -654,6 +655,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--decode-context-parallel-size", "-dcp",
             **parallel_kwargs["decode_context_parallel_size"])
+        parallel_group.add_argument(
+            "--cp-kv-cache-interleave-size",
+            **parallel_kwargs["cp_kv_cache_interleave_size"])
         parallel_group.add_argument(
             "--context-parallel-size", "-cp",
             **parallel_kwargs["context_parallel_size"])
@@ -1338,6 +1342,7 @@ def create_engine_config(
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
+            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
             _api_process_count=self._api_process_count,
             _api_process_rank=self._api_process_rank,
         )
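
With the field and flag in place, the interleave size can be set via --cp-kv-cache-interleave-size on the CLI or through EngineArgs in Python. A hedged sketch; the parallel sizes below are placeholders chosen for illustration, not values from this PR:

    from vllm.engine.arg_utils import EngineArgs

    # create_engine_config() forwards this field into ParallelConfig (see the hunk above).
    args = EngineArgs(
        tensor_parallel_size=8,
        decode_context_parallel_size=2,
        cp_kv_cache_interleave_size=16,  # must not exceed, and must divide, block_size
    )
    print(args.cp_kv_cache_interleave_size)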

vllm/utils/__init__.py

Lines changed: 25 additions & 0 deletions
@@ -3470,3 +3470,28 @@ def length_from_prompt_token_ids_or_embeds(
             f" prompt_token_ids={prompt_token_len}"
             f" prompt_embeds={prompt_embeds_len}")
     return prompt_token_len
+
+
+def get_split_computed_tokens(
+    num_computed_tokens: np.ndarray,
+    cp_world_size: int = 1,
+    dcp_world_size: int = 1,
+    cp_kv_cache_interleave_size: int = 1
+):
+    """When using cp or dcp, the kv_cache size stored on each rank may differ;
+    use this function to compute the split of num_computed_tokens on each cp/dcp rank.
+    """
+    num_requests = len(num_computed_tokens)
+    total_world_size = cp_world_size * dcp_world_size
+    num_computed_tokens_tiled = np.tile(
+        num_computed_tokens[:, np.newaxis], (1, total_world_size)
+    )
+    rank_offsets = np.tile(np.arange(total_world_size)[np.newaxis, :], (num_requests, 1))
+    base = num_computed_tokens_tiled // cp_kv_cache_interleave_size // total_world_size \
+        * cp_kv_cache_interleave_size
+    remainder = num_computed_tokens_tiled - base * total_world_size
+    remainder = np.clip(
+        remainder - rank_offsets * cp_kv_cache_interleave_size, 0, cp_kv_cache_interleave_size
+    )
+    num_computed_tokens_of_cp_dcp = base + remainder
+    return num_computed_tokens_of_cp_dcp.reshape(-1, cp_world_size, dcp_world_size)
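
Worked example of the split: with dcp_world_size=4 and cp_kv_cache_interleave_size=2, a request with 10 computed tokens places tokens 0-1 on rank 0, 2-3 on rank 1, 4-5 on rank 2, 6-7 on rank 3, and 8-9 back on rank 0, so the per-rank counts are [4, 2, 2, 2]. The standalone sketch below reproduces the helper's arithmetic without importing vLLM (a reimplementation for illustration, not the PR's code):

    import numpy as np

    def split_tokens(num_computed_tokens, cp_ws=1, dcp_ws=1, interleave=1):
        # Same round-robin-by-interleave-group arithmetic as get_split_computed_tokens.
        ws = cp_ws * dcp_ws
        tokens = np.asarray(num_computed_tokens)[:, np.newaxis]
        ranks = np.arange(ws)[np.newaxis, :]
        base = tokens // interleave // ws * interleave
        rem = np.clip(tokens - base * ws - ranks * interleave, 0, interleave)
        return (base + rem).reshape(-1, cp_ws, dcp_ws)

    out = split_tokens([10, 5], cp_ws=1, dcp_ws=4, interleave=2)
    print(out[0, 0].tolist())  # [4, 2, 2, 2]
    print(out[1, 0].tolist())  # [2, 2, 1, 0]  (5 tokens: 0-1 -> rank 0, 2-3 -> rank 1, 4 -> rank 2)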

vllm/v1/attention/backends/mla/common.py

Lines changed: 5 additions & 4 deletions
@@ -668,10 +668,11 @@ def build(self,
             decode_threshold=self.reorder_batch_threshold)

         # Note(hc): update seq_lens of decode reqs under DCP.
-        if self.dcp_world_size > 1:
-            seq_lens[:num_decodes] = seq_lens[:num_decodes] \
-                // self.dcp_world_size + (self.dcp_rank <= \
-                (seq_lens[:num_decodes] - 1) % self.dcp_world_size)
+        if self.dcp_world_size > 1 and num_decodes > 0:
+            seq_lens[:num_decodes] = torch.tensor(
+                common_attn_metadata.num_computed_tokens_of_cp_dcp[:num_decodes, 0, self.dcp_rank],
+                dtype=seq_lens.dtype, device=seq_lens.device
+            )

         assert num_decodes + num_prefills == num_reqs
         assert num_decode_tokens + num_prefill_tokens == num_tokens
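
The replaced formula assumed token-level interleaving (interleave size 1); the new code instead reads the precomputed per-rank token counts, which also covers larger interleave sizes. A small numeric comparison under assumed toy values (the helper names below are illustrative, not from the PR):

    def old_dcp_seq_len(seq_len: int, dcp_rank: int, dcp_ws: int) -> int:
        # The removed formula: valid only for token-level interleaving.
        return seq_len // dcp_ws + (dcp_rank <= (seq_len - 1) % dcp_ws)

    def new_dcp_seq_len(seq_len: int, dcp_rank: int, dcp_ws: int, interleave: int) -> int:
        # Interleave-aware count, matching get_split_computed_tokens for cp_world_size=1.
        base = seq_len // interleave // dcp_ws * interleave
        return base + min(max(seq_len - base * dcp_ws - dcp_rank * interleave, 0), interleave)

    seq_len, dcp_ws = 10, 4
    print([old_dcp_seq_len(seq_len, r, dcp_ws) for r in range(dcp_ws)])     # [3, 3, 2, 2]
    print([new_dcp_seq_len(seq_len, r, dcp_ws, 1) for r in range(dcp_ws)])  # [3, 3, 2, 2]
    print([new_dcp_seq_len(seq_len, r, dcp_ws, 2) for r in range(dcp_ws)])  # [4, 2, 2, 2]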

vllm/v1/attention/backends/utils.py

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ class CommonAttentionMetadata:
     # Needed by custom mask calc for context parallelism
     query_positions: Optional[np.ndarray] = None
     cp_kv_recover_idx: Optional[torch.Tensor] = None
+    num_computed_tokens_of_cp_dcp: Optional[np.ndarray] = None

 def slice_query_start_locs(
     query_start_loc: torch.Tensor,

vllm/v1/worker/block_table.py

Lines changed: 7 additions & 26 deletions
@@ -78,7 +78,7 @@ def swap_row(self, src: int, tgt: int) -> None:
         self.block_table.np[src_tgt] = self.block_table.np[tgt_src]

     def compute_slot_mapping(self, req_indices: np.ndarray,
-                             positions: np.ndarray) -> None:
+                             positions: np.ndarray, cp_kv_cache_interleave_size: int = 1) -> None:
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
         # where K is the max_num_blocks_per_req and the block size is 2.
@@ -100,10 +100,12 @@ def compute_slot_mapping(self, req_indices: np.ndarray,
         # tokens.
         virtual_block_offsets = positions % virtual_block_size
         self.current_rank = self.dcp_world_size * self.cp_rank + self.dcp_rank
-        mask = (virtual_block_offsets %
+        mask = (virtual_block_offsets // cp_kv_cache_interleave_size %
                 (self.dcp_world_size * self.cp_world_size) == self.current_rank)
         # Calculate local block_offsets
-        block_offsets = virtual_block_offsets // (self.dcp_world_size * self.cp_world_size)
+        block_offsets = virtual_block_offsets \
+            // (self.dcp_world_size * self.cp_world_size * cp_kv_cache_interleave_size) \
+            * cp_kv_cache_interleave_size + virtual_block_offsets % cp_kv_cache_interleave_size
         # Calculate slot_mapping
         slot_mapping = block_numbers * self.block_size + block_offsets
         # Write final slots, use -1 for not-local
@@ -147,27 +149,6 @@ def _make_buffer(self, *size: Union[int, torch.SymInt],
                                  device=self.device,
                                  pin_memory=self.pin_memory)

-    def get_split_computed_tokens(self, num_computed_tokens: np.ndarray) -> list[list[list[int]]]:
-        "Splits computed token counts across dcp and sp dimensions for distributed allocation."
-        num_requests = len(num_computed_tokens)
-        num_computed_tokens_of_dcp_sp = [[
-            [0] * self.dcp_world_size for _ in range(self.cp_world_size)
-        ] for _ in range(num_requests)]
-        total_ranks = self.cp_world_size * self.dcp_world_size
-        for req_idx in range(num_requests):
-            total_tokens = num_computed_tokens[req_idx]
-            if total_tokens <= 0:
-                continue
-            base = int(total_tokens) // total_ranks
-            remainder = int(total_tokens) % total_ranks
-            for rank_idx in range(total_ranks):
-                cp_idx = rank_idx // self.dcp_world_size
-                sp_idx = rank_idx % self.dcp_world_size
-                num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] = base
-                if rank_idx < remainder:
-                    num_computed_tokens_of_dcp_sp[req_idx][cp_idx][sp_idx] += 1
-        return num_computed_tokens_of_dcp_sp
-

 class MultiGroupBlockTable:
     """The BlockTables for each KV cache group."""
@@ -216,9 +197,9 @@ def swap_row(self, src: int, tgt: int) -> None:
             block_table.swap_row(src, tgt)

     def compute_slot_mapping(self, req_indices: np.ndarray,
-                             positions: np.ndarray) -> None:
+                             positions: np.ndarray, cp_kv_cache_interleave_size: int = 1) -> None:
         for block_table in self.block_tables:
-            block_table.compute_slot_mapping(req_indices, positions)
+            block_table.compute_slot_mapping(req_indices, positions, cp_kv_cache_interleave_size)

     def commit_block_table(self, num_reqs: int) -> None:
         for block_table in self.block_tables:
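
To see how the reworked mask and block_offsets lay tokens out, a toy sketch with made-up sizes (2 (d)cp ranks, block_size 4, interleave size 2), mirroring the arithmetic in compute_slot_mapping:

    import numpy as np

    world_size, block_size, interleave = 2, 4, 2
    virtual_block_size = block_size * world_size
    positions = np.arange(8)
    offsets = positions % virtual_block_size

    for rank in range(world_size):
        mask = (offsets // interleave) % world_size == rank
        local = offsets // (world_size * interleave) * interleave + offsets % interleave
        print(rank, positions[mask].tolist(), local[mask].tolist())
    # rank 0 owns positions [0, 1, 4, 5] at local block offsets [0, 1, 2, 3]
    # rank 1 owns positions [2, 3, 6, 7] at local block offsets [0, 1, 2, 3]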

vllm/v1/worker/gpu_model_runner.py

Lines changed: 11 additions & 2 deletions
@@ -58,7 +58,7 @@
                         GiB_bytes, check_use_alibi, get_dtype_size,
                         is_pin_memory_available,
                         length_from_prompt_token_ids_or_embeds, round_up,
-                        supports_dynamo, cdiv)
+                        supports_dynamo, cdiv, get_split_computed_tokens)
 from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
 from vllm.v1.attention.backends.utils import (
@@ -1123,7 +1123,10 @@ def _prepare_inputs(
                 output_idx += num_sched

         self.input_batch.block_table.compute_slot_mapping(
-            req_indices_for_slotmapping, positions_np_for_slotmapping)
+            req_indices_for_slotmapping,
+            positions_np_for_slotmapping,
+            cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size,
+        )
         self.input_batch.block_table.commit_slot_mapping(
             total_num_scheduled_tokens_for_slotmapping)

@@ -1305,6 +1308,12 @@ def _prepare_inputs(
             encoder_seq_lens=encoder_seq_lens,
             query_positions=positions_np,
             cp_kv_recover_idx=self.cp_kv_recover_idx,
+            num_computed_tokens_of_cp_dcp=get_split_computed_tokens(
+                self.input_batch.num_tokens[:num_reqs],
+                self.cp_world_size,
+                self.dcp_world_size,
+                self.parallel_config.cp_kv_cache_interleave_size,
+            ),
         )

         if self.speculative_config and \
