Commit 397fd51 (1 parent: 782505e)

[DCP] Support dcp kv_cache interleave size > 1

Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>

File tree: 8 files changed (+116, -15 lines)

tests/distributed/test_context_parallel.py
Lines changed: 7 additions & 0 deletions

@@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
     dcp_size: int
+    cp_kv_cache_interleave_size: int
     eager_mode: bool
     chunked_prefill: bool

@@ -52,6 +53,7 @@ def detailed(
     tp_base: int = 4,
     pp_base: int = 1,
     dcp_base: int = 1,
+    cp_kv_cache_interleave_size: int = 1,
     multi_node_only: bool = False,
     runner: RunnerOption = "auto",
     load_format: str | None = None,

@@ -66,6 +68,7 @@ def detailed(
     tp_size=tp_base,
     pp_size=pp_multiplier * pp_base,
     dcp_size=int(dcp_multiplier * tp_base),
+    cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
     eager_mode=eager_mode_val,
     chunked_prefill=chunked_prefill_val,
 )

@@ -108,6 +111,7 @@ def _compare_cp_with_tp(
     tp_size,
     pp_size,
     dcp_size,
+    cp_kv_cache_interleave_size,
     eager_mode,
     chunked_prefill,
 ) = parallel_setup

@@ -180,6 +184,8 @@ def _compare_cp_with_tp(
     str(pp_size),
     "--decode-context-parallel-size",
     str(dcp_size),
+    "--cp-kv-cache-interleave-size",
+    str(cp_kv_cache_interleave_size),
     "--distributed-executor-backend",
     distributed_backend,
 ]

@@ -208,6 +214,7 @@ def _compare_cp_with_tp(
 "deepseek-ai/DeepSeek-V2-Lite-Chat": [
     CPTestSettings.detailed(),
     CPTestSettings.detailed(tp_base=2),
+    CPTestSettings.detailed(tp_base=2, cp_kv_cache_interleave_size=64),
 ],
 }

vllm/config/parallel.py
Lines changed: 11 additions & 0 deletions

@@ -204,6 +204,17 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""

+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage when using DCP or CP > 1:
+    store interleave_size tokens on (d)cp rank i, then the next
+    interleave_size tokens on (d)cp rank i+1.
+    interleave_size=1 gives token-level alignment: token i is stored on
+    rank i % (d)cp_size.
+    interleave_size=block_size gives block-level alignment: block j on
+    rank i+1 is filled only after block j on rank i is full.
+    block_size must be greater than or equal to, and divisible by,
+    cp_kv_cache_interleave_size.
+    """
+
     _api_process_count: int = Field(default=1, gt=0)
     """
     The number of API processes initialized.
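To make the placement rule in this docstring concrete, here is a minimal sketch (not part of the commit; rank_of_token is a hypothetical helper) of which (d)cp rank stores each token position:

    def rank_of_token(pos: int, dcp_world_size: int, interleave_size: int) -> int:
        # Store interleave_size consecutive tokens on a rank, then move on.
        return (pos // interleave_size) % dcp_world_size

    # Token-level alignment (interleave_size=1), dcp_world_size=4:
    print([rank_of_token(p, 4, 1) for p in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]
    # Block-level alignment (interleave_size=block_size=4), dcp_world_size=2:
    print([rank_of_token(p, 2, 4) for p in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]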

vllm/config/vllm.py
Lines changed: 11 additions & 0 deletions

@@ -471,6 +471,17 @@ def __post_init__(self):
         "to True to enable."
     )
     current_platform.check_and_update_config(self)
+    assert (
+        self.parallel_config.cp_kv_cache_interleave_size
+        <= self.cache_config.block_size
+        and self.cache_config.block_size
+        % self.parallel_config.cp_kv_cache_interleave_size
+        == 0
+    ), (
+        f"Block_size({self.cache_config.block_size}) should be "
+        "greater than or equal to and divisible by cp_kv_cache_interleave_size "
+        f"({self.parallel_config.cp_kv_cache_interleave_size})."
+    )

     # Do this after all the updates to compilation_config.level
     if (

vllm/engine/arg_utils.py
Lines changed: 10 additions & 0 deletions

@@ -362,6 +362,7 @@ class EngineArgs:
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
+    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: int | None = None
     data_parallel_start_rank: int | None = None

@@ -715,6 +716,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "-dcp",
         **parallel_kwargs["decode_context_parallel_size"],
     )
+    parallel_group.add_argument(
+        "--cp-kv-cache-interleave-size",
+        **parallel_kwargs["cp_kv_cache_interleave_size"],
+    )
     parallel_group.add_argument(
         "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
     )

@@ -1470,6 +1475,7 @@ def create_engine_config(
     worker_cls=self.worker_cls,
     worker_extension_cls=self.worker_extension_cls,
     decode_context_parallel_size=self.decode_context_parallel_size,
+    cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
     _api_process_count=self._api_process_count,
     _api_process_rank=self._api_process_rank,
 )

@@ -1480,6 +1486,10 @@ def create_engine_config(
     enable_chunked_prefill=self.enable_chunked_prefill,
     disable_log_stats=self.disable_log_stats,
 )
+if speculative_config is not None and self.cp_kv_cache_interleave_size != 1:
+    raise ValueError(
+        "MTP with cp_kv_cache_interleave_size > 1 is not supported now."
+    )

 # make sure num_lookahead_slots is set appropriately depending on
 # whether speculative decoding is enabled
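With the flag wired into add_cli_args, a server launch could look like this (hypothetical sizes, shown only to illustrate the new option; --block-size is set so block_size stays divisible by the interleave size, as the config assert requires):

    vllm serve deepseek-ai/DeepSeek-V2-Lite-Chat \
        --tensor-parallel-size 2 \
        --decode-context-parallel-size 2 \
        --block-size 64 \
        --cp-kv-cache-interleave-size 64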

vllm/utils/__init__.py
Lines changed: 32 additions & 0 deletions

@@ -3426,3 +3426,35 @@ def unique_filepath(fn: Callable[[int], Path]) -> Path:
         if not p.exists():
             return p
         i += 1
+
+
+def get_dcp_local_seq_lens(
+    seq_lens: torch.Tensor,
+    dcp_world_size: int = 1,
+    cp_kv_cache_interleave_size: int = 1,
+) -> torch.Tensor:
+    """When using DCP, the kv_cache size stored on each rank may differ;
+    use this function to compute the per-rank split of decode seq_lens.
+    Only DCP is handled for now; the CP case can be extended from this.
+    """
+    num_requests = seq_lens.size(0)
+    seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, dcp_world_size)
+    rank_offsets = (
+        torch.arange(dcp_world_size, device=seq_lens.device, dtype=torch.int32)
+        .unsqueeze(0)
+        .repeat(num_requests, 1)
+    )
+    base = (
+        seq_lens_tiled
+        // cp_kv_cache_interleave_size
+        // dcp_world_size
+        * cp_kv_cache_interleave_size
+    )
+    remainder = seq_lens_tiled - base * dcp_world_size
+    remainder = torch.clip(
+        remainder - rank_offsets * cp_kv_cache_interleave_size,
+        0,
+        cp_kv_cache_interleave_size,
+    )
+    dcp_local_seq_lens = base + remainder
+    return dcp_local_seq_lens
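A quick usage sketch (illustrative numbers, not part of the commit): with seq_len=10, dcp_world_size=4, and interleave size 2, positions map to ranks [0, 0, 1, 1, 2, 2, 3, 3, 0, 0], so rank 0 holds 4 tokens and ranks 1-3 hold 2 each; each rank later selects its own column of the result:

    import torch
    from vllm.utils import get_dcp_local_seq_lens

    seq_lens = torch.tensor([10, 16])
    local = get_dcp_local_seq_lens(
        seq_lens, dcp_world_size=4, cp_kv_cache_interleave_size=2
    )
    print(local)  # tensor([[4, 2, 2, 2], [4, 4, 4, 4]]); rows sum to seq_lens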

vllm/v1/attention/backends/mla/common.py
Lines changed: 0 additions & 9 deletions

@@ -749,15 +749,6 @@ def build(
             )
         )

-        # Note(hc): update seq_lens of decode reqs under DCP.
-        if self.dcp_world_size > 1:
-            assert dcp_local_seq_lens is not None
-            dcp_local_seq_lens[:num_decodes] = seq_lens[
-                :num_decodes
-            ] // self.dcp_world_size + (
-                self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
-            )
-
         assert num_decodes + num_prefills == num_reqs
         assert num_decode_tokens + num_prefill_tokens == num_tokens

vllm/v1/worker/block_table.py
Lines changed: 25 additions & 5 deletions

@@ -119,7 +119,10 @@ def swap_row(self, src: int, tgt: int) -> None:
         self.block_table.np[src_tgt] = self.block_table.np[tgt_src]

     def compute_slot_mapping(
-        self, req_indices: np.ndarray, positions: np.ndarray
+        self,
+        req_indices: np.ndarray,
+        positions: np.ndarray,
+        cp_kv_cache_interleave_size: int = 1,
     ) -> None:
         # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
         # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]

@@ -144,9 +147,19 @@ def compute_slot_mapping(
         # Use virtual_block_size for mask calculation, which marks local
         # tokens.
         virtual_block_offsets = positions % virtual_block_size
-        mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+        mask = (
+            virtual_block_offsets
+            // cp_kv_cache_interleave_size
+            % self.dcp_world_size
+            == self.dcp_rank
+        )
         # Calculate local block_offsets
-        block_offsets = virtual_block_offsets // self.dcp_world_size
+        block_offsets = (
+            virtual_block_offsets
+            // (self.dcp_world_size * cp_kv_cache_interleave_size)
+            * cp_kv_cache_interleave_size
+            + virtual_block_offsets % cp_kv_cache_interleave_size
+        )
         # Calculate slot_mapping
         slot_mapping = block_numbers * self.block_size + block_offsets
         # Write final slots, use -1 for not-local

@@ -284,10 +297,17 @@ def swap_row(self, src: int, tgt: int) -> None:
         block_table.swap_row(src, tgt)

     def compute_slot_mapping(
-        self, req_indices: np.ndarray, positions: np.ndarray
+        self,
+        req_indices: np.ndarray,
+        positions: np.ndarray,
+        cp_kv_cache_interleave_size: int = 1,
     ) -> None:
         for block_table in self.block_tables:
-            block_table.compute_slot_mapping(req_indices, positions)
+            block_table.compute_slot_mapping(
+                req_indices,
+                positions,
+                cp_kv_cache_interleave_size,
+            )

     def commit_block_table(self, num_reqs: int) -> None:
         for block_table in self.block_tables:
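To see what the new offset math does, here is a small NumPy sketch (illustrative parameters, not part of the commit) with block_size=4, dcp_world_size=2, and interleave size 2, so virtual_block_size = block_size * dcp_world_size = 8:

    import numpy as np

    block_size, world_size, interleave = 4, 2, 2
    positions = np.arange(8)
    virtual_block_offsets = positions % (block_size * world_size)
    for rank in range(world_size):
        # Mask marks the positions stored locally on this rank.
        mask = (virtual_block_offsets // interleave) % world_size == rank
        # Offset of each local token inside this rank's block.
        block_offsets = (
            virtual_block_offsets // (world_size * interleave) * interleave
            + virtual_block_offsets % interleave
        )
        print(rank, positions[mask], block_offsets[mask])
    # rank 0 stores positions [0 1 4 5] at local offsets [0 1 2 3]
    # rank 1 stores positions [2 3 6 7] at local offsets [0 1 2 3]

Each rank's interleave-sized runs of tokens pack densely into its local block, so no slots are wasted for any interleave size that divides block_size.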

vllm/v1/worker/gpu_model_runner.py
Lines changed: 20 additions & 1 deletion

@@ -35,6 +35,7 @@
 from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
+    get_dcp_group,
     get_pp_group,
     get_tp_group,
     graph_capture,

@@ -78,6 +79,7 @@
     GiB_bytes,
     cdiv,
     check_use_alibi,
+    get_dcp_local_seq_lens,
     get_dtype_size,
     is_pin_memory_available,
     length_from_prompt_token_ids_or_embeds,

@@ -256,6 +258,11 @@ def __init__(
     self.is_multimodal_pruning_enabled = False
     self.max_model_len = model_config.max_model_len
     self.dcp_world_size = self.parallel_config.decode_context_parallel_size
+    try:
+        self.dcp_rank = get_dcp_group().rank_in_group
+    except AssertionError:
+        # DCP might not be initialized in testing
+        self.dcp_rank = 0
     self.max_num_tokens = scheduler_config.max_num_batched_tokens
     self.max_num_reqs = scheduler_config.max_num_seqs

@@ -1158,7 +1165,11 @@ def _prepare_inputs(
     output_idx += num_sched

-    self.input_batch.block_table.compute_slot_mapping(req_indices, positions_np)
+    self.input_batch.block_table.compute_slot_mapping(
+        req_indices,
+        positions_np,
+        self.parallel_config.cp_kv_cache_interleave_size,
+    )
     self.input_batch.block_table.commit_slot_mapping(total_num_scheduled_tokens)

     # Prepare the attention metadata.

@@ -1276,6 +1287,14 @@ def _prepare_inputs(
         logits_indices
     )

+    # update seq_lens of decode reqs under DCP.
+    if self.dcp_world_size > 1:
+        self.dcp_local_seq_lens.gpu[:num_reqs] = get_dcp_local_seq_lens(
+            seq_lens,
+            self.dcp_world_size,
+            self.parallel_config.cp_kv_cache_interleave_size,
+        )[:, self.dcp_rank]
+
     attn_metadata: PerLayerAttnMetadata = {}
     if ubatch_slices is not None:
         attn_metadata = [dict() for _ in range(len(ubatch_slices))]
