
Commit 0a87c88

support dcp kv_cache interleave size > 1
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Parent: 650b51f

9 files changed: +104, -11 lines changed

tests/distributed/test_context_parallel.py

Lines changed: 7 additions & 0 deletions
@@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
     tp_size: int
     pp_size: int
     dcp_size: int
+    dcp_kv_cache_interleave_size: int
     eager_mode: bool
     chunked_prefill: bool

@@ -52,6 +53,7 @@ def detailed(
         tp_base: int = 4,
         pp_base: int = 1,
         dcp_base: int = 1,
+        dcp_kv_cache_interleave_size: int = 1,
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,
@@ -66,6 +68,7 @@ def detailed(
                 tp_size=tp_base,
                 pp_size=pp_multiplier * pp_base,
                 dcp_size=int(dcp_multiplier * tp_base),
+                dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
                 eager_mode=eager_mode_val,
                 chunked_prefill=chunked_prefill_val,
             )
@@ -108,6 +111,7 @@ def _compare_cp_with_tp(
         tp_size,
         pp_size,
         dcp_size,
+        dcp_kv_cache_interleave_size,
         eager_mode,
         chunked_prefill,
     ) = parallel_setup
@@ -180,6 +184,8 @@ def _compare_cp_with_tp(
         str(pp_size),
         "--decode-context-parallel-size",
         str(dcp_size),
+        "--dcp-kv-cache-interleave-size",
+        str(dcp_kv_cache_interleave_size),
         "--distributed-executor-backend",
         distributed_backend,
     ]
@@ -207,6 +213,7 @@ def _compare_cp_with_tp(
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
+        CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
     ],
     "bigcode/gpt_bigcode-santacoder": [
         CPTestSettings.detailed(),

vllm/config/parallel.py

Lines changed: 11 additions & 0 deletions
@@ -223,6 +223,17 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""

+    dcp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using dcp or cp > 1,
+    store interleave_size tokens on (d)cp i,
+    then store next interleave_size tokens on (d)cp i+1.
+    Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
+    Interleave_size=block_size: block-level align, first fill the block on first rank,
+    token is stored on rank i+1 block j after rank i block j is full.
+    Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
+    Block_size should be divisible by dcp_kv_cache_interleave_size.
+    """
+
     _api_process_count: int = Field(default=1, gt=0)
     """
     The number of API processes initialized.
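
To make the interleave semantics above concrete, here is a minimal standalone sketch (not part of the commit; the function name, rank count, and sizes are illustrative) of how a token position maps to a (d)cp rank:

def dcp_rank_of_token(position: int, dcp_world_size: int, interleave_size: int) -> int:
    # Tokens are dealt out in contiguous chunks of `interleave_size` positions;
    # chunk k goes to rank k % dcp_world_size.
    return (position // interleave_size) % dcp_world_size

# interleave_size=1: token-level round robin over ranks.
print([dcp_rank_of_token(i, 4, 1) for i in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]
# interleave_size=4: fill a 4-token chunk on one rank before moving to the next.
print([dcp_rank_of_token(i, 4, 4) for i in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]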

vllm/config/vllm.py

Lines changed: 17 additions & 0 deletions
@@ -480,6 +480,23 @@ def __post_init__(self):
         )
         current_platform.check_and_update_config(self)

+        assert (
+            self.parallel_config.dcp_kv_cache_interleave_size
+            <= self.cache_config.block_size
+            and self.cache_config.block_size
+            % self.parallel_config.dcp_kv_cache_interleave_size
+            == 0
+        ), (
+            f"Block_size({self.cache_config.block_size}) should be "
+            "greater than or equal to and divisible by dcp_kv_cache_interleave_size "
+            f"({self.parallel_config.dcp_kv_cache_interleave_size})."
+        )
+
+        assert (
+            self.parallel_config.dcp_kv_cache_interleave_size == 1
+            or self.speculative_config is None
+        ), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."
+
         # Do this after all the updates to compilation_config.mode
         if (
             envs.VLLM_USE_V1
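
For example, assuming a KV-cache block_size of 16 (an illustrative value, not taken from the diff), the first assertion accepts interleave sizes 1, 2, 4, 8, and 16 and rejects values that exceed the block size or do not divide it:

block_size = 16  # assumed value for illustration
for interleave in (1, 2, 4, 8, 16, 3, 32):
    ok = interleave <= block_size and block_size % interleave == 0
    print(interleave, "ok" if ok else "rejected")
# 1, 2, 4, 8, 16 -> ok; 3 (16 % 3 != 0) and 32 (> block_size) -> rejected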

vllm/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions
@@ -362,6 +362,7 @@ class EngineArgs:
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
+    dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: int | None = None
     data_parallel_start_rank: int | None = None
@@ -717,6 +718,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             "-dcp",
             **parallel_kwargs["decode_context_parallel_size"],
         )
+        parallel_group.add_argument(
+            "--dcp-kv-cache-interleave-size",
+            **parallel_kwargs["dcp_kv_cache_interleave_size"],
+        )
         parallel_group.add_argument(
             "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
         )
@@ -1482,6 +1487,7 @@ def create_engine_config(
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
+            dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
             _api_process_count=self._api_process_count,
             _api_process_rank=self._api_process_rank,
         )
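
The new option rides on the existing parallel-config plumbing, so it can be set on the command line as --dcp-kv-cache-interleave-size or programmatically. A minimal sketch, with assumed values (the model name and sizes are placeholders; block_size is set explicitly so the divisibility check added in vllm/config/vllm.py holds):

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",  # placeholder model
    tensor_parallel_size=8,
    decode_context_parallel_size=8,
    block_size=64,                    # must be >= and divisible by the interleave size
    dcp_kv_cache_interleave_size=64,  # block-level alignment in this example
)
# engine_args.create_engine_config() propagates this value into ParallelConfig.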

vllm/v1/attention/backends/mla/common.py

Lines changed: 0 additions & 9 deletions
@@ -774,15 +774,6 @@ def build(
             )
         )

-        # Note(hc): update seq_lens of decode reqs under DCP.
-        if self.dcp_world_size > 1:
-            assert dcp_local_seq_lens is not None
-            dcp_local_seq_lens[:num_decodes] = seq_lens[
-                :num_decodes
-            ] // self.dcp_world_size + (
-                self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
-            )
-
         assert num_decodes + num_prefills == num_reqs
         assert num_decode_tokens + num_prefill_tokens == num_tokens

vllm/v1/attention/backends/utils.py

Lines changed: 32 additions & 0 deletions
@@ -992,3 +992,35 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
         nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr  # type: ignore

     return nums_dict, batch_ptr, token_chunk_offset_ptr
+
+
+def get_dcp_local_seq_lens(
+    seq_lens: torch.Tensor,
+    dcp_world_size: int = 1,
+    dcp_kv_cache_interleave_size: int = 1,
+) -> torch.Tensor:
+    """While using dcp, kv_cache size stored on each rank may be different,
+    use this function to calculate split decode seq_lens of each dcp rank.
+    Only consider dcp now, we can extend the case of cp based on this.
+    """
+    num_requests = seq_lens.size(0)
+    seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, dcp_world_size)
+    rank_offsets = (
+        torch.arange(dcp_world_size, dtype=torch.int32)
+        .unsqueeze(0)
+        .repeat(num_requests, 1)
+    )
+    base = (
+        seq_lens_tiled
+        // dcp_kv_cache_interleave_size
+        // dcp_world_size
+        * dcp_kv_cache_interleave_size
+    )
+    remainder = seq_lens_tiled - base * dcp_world_size
+    remainder = torch.clip(
+        remainder - rank_offsets * dcp_kv_cache_interleave_size,
+        0,
+        dcp_kv_cache_interleave_size,
+    )
+    dcp_local_seq_lens = base + remainder
+    return dcp_local_seq_lens
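
A small usage sketch of the helper above, with assumed values (dcp_world_size=4, interleave size 2); each row gives how many tokens of one request each rank holds:

import torch
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens

seq_lens = torch.tensor([10, 13], dtype=torch.int32)
# Chunks of 2 tokens are dealt round-robin over 4 ranks:
#   seq_len=10 -> ranks hold [4, 2, 2, 2] tokens; seq_len=13 -> [4, 4, 3, 2].
local = get_dcp_local_seq_lens(seq_lens, dcp_world_size=4, dcp_kv_cache_interleave_size=2)
print(local)  # [[4, 2, 2, 2], [4, 4, 3, 2]]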

vllm/v1/worker/block_table.py

Lines changed: 16 additions & 2 deletions
@@ -22,6 +22,7 @@ def __init__(
         pin_memory: bool,
         device: torch.device,
         kernel_block_size: int,
+        dcp_kv_cache_interleave_size: int,
     ):
         """
         Args:
@@ -86,6 +87,7 @@ def __init__(
             # DCP might not be initialized in testing
             self.dcp_world_size = 1
             self.dcp_rank = 0
+        self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size

     def append_row(
         self,
@@ -144,9 +146,19 @@ def compute_slot_mapping(
             # Use virtual_block_size for mask calculation, which marks local
             # tokens.
             virtual_block_offsets = positions % virtual_block_size
-            mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
+            mask = (
+                virtual_block_offsets
+                // self.dcp_kv_cache_interleave_size
+                % self.dcp_world_size
+                == self.dcp_rank
+            )
             # Calculate local block_offsets
-            block_offsets = virtual_block_offsets // self.dcp_world_size
+            block_offsets = (
+                virtual_block_offsets
+                // (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
+                * self.dcp_kv_cache_interleave_size
+                + virtual_block_offsets % self.dcp_kv_cache_interleave_size
+            )
             # Calculate slot_mapping
             slot_mapping = block_numbers * self.block_size + block_offsets
             # Write final slots, use -1 for not-local
@@ -234,6 +246,7 @@ def __init__(
         block_sizes: list[int],
         kernel_block_sizes: list[int],
         num_speculative_tokens: int = 0,
+        dcp_kv_cache_interleave_size: int = 1,
     ) -> None:
         # Note(hc): each dcp rank only store
         # (max_model_len//dcp_world_size) tokens in kvcache,
@@ -263,6 +276,7 @@ def __init__(
                 pin_memory,
                 device,
                 kernel_block_size,
+                dcp_kv_cache_interleave_size,
             )
             for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
         ]
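
The new mask and block_offsets arithmetic is easiest to see in a standalone example (values assumed for illustration: 2 DCP ranks, interleave size 4, block_size 8, so virtual_block_size is 16):

import torch

dcp_world_size, interleave, block_size = 2, 4, 8
virtual_block_size = block_size * dcp_world_size  # 16 logical positions per block
positions = torch.arange(16)
vbo = positions % virtual_block_size
for rank in range(dcp_world_size):
    # A position is local to `rank` if its interleave chunk lands on that rank.
    mask = (vbo // interleave) % dcp_world_size == rank
    # Local offset: skip the chunks held by other ranks, keep the offset inside the chunk.
    local_off = vbo // (dcp_world_size * interleave) * interleave + vbo % interleave
    print(rank, positions[mask].tolist(), local_off[mask].tolist())
# rank 0 stores positions [0, 1, 2, 3, 8, 9, 10, 11] at local offsets [0, 1, 2, 3, 4, 5, 6, 7]
# rank 1 stores positions [4, 5, 6, 7, 12, 13, 14, 15] at local offsets [0, 1, 2, 3, 4, 5, 6, 7]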

vllm/v1/worker/gpu_input_batch.py

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,7 @@ def __init__(
         is_spec_decode: bool = False,
         is_pooling_model: bool = False,
         num_speculative_tokens: int = 0,
+        dcp_kv_cache_interleave_size: int = 1,
     ):
         self.is_pooling_model = is_pooling_model
         self.is_spec_decode = is_spec_decode
@@ -135,6 +136,7 @@ def __init__(
             block_sizes=block_sizes,
             kernel_block_sizes=kernel_block_sizes,
             num_speculative_tokens=num_speculative_tokens,
+            dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
         )

         # Sampling-related.

vllm/v1/worker/gpu_model_runner.py

Lines changed: 13 additions & 0 deletions
@@ -35,6 +35,7 @@
 from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group
 from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
 from vllm.distributed.parallel_state import (
+    get_dcp_group,
     get_pp_group,
     get_tp_group,
     graph_capture,
@@ -92,6 +93,7 @@
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
     create_fast_prefill_custom_backend,
+    get_dcp_local_seq_lens,
     reorder_batch_to_split_decodes_and_prefills,
     split_attn_metadata,
 )
@@ -256,6 +258,7 @@ def __init__(
         self.is_multimodal_pruning_enabled = False
         self.max_model_len = model_config.max_model_len
         self.dcp_world_size = self.parallel_config.decode_context_parallel_size
+        self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs

@@ -372,6 +375,7 @@ def __init__(
             # uses output token ids so we set this conservatively.
             logitsprocs_need_output_token_ids=bool(custom_logitsprocs),
             is_pooling_model=self.is_pooling_model,
+            dcp_kv_cache_interleave_size=self.parallel_config.dcp_kv_cache_interleave_size,
         )

         self.use_async_scheduling = self.scheduler_config.async_scheduling
@@ -1274,6 +1278,15 @@ def _prepare_inputs(
             logits_indices
         )

+        # update seq_lens of decode reqs under DCP.
+        if self.dcp_world_size > 1:
+            self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
+                self.seq_lens.cpu[:num_reqs],
+                self.dcp_world_size,
+                self.parallel_config.dcp_kv_cache_interleave_size,
+            )[:, self.dcp_rank]
+            self.dcp_local_seq_lens.copy_to_gpu(num_reqs)
+
         attn_metadata: PerLayerAttnMetadata = {}
         if ubatch_slices is not None:
             attn_metadata = [dict() for _ in range(len(ubatch_slices))]
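
Per-rank decode lengths are now computed once in the model runner rather than inside the MLA backend (see the lines removed from vllm/v1/attention/backends/mla/common.py above): the helper returns a (num_reqs, dcp_world_size) table and each rank keeps only its own column. A short illustrative sketch with assumed values:

import torch
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens

seq_lens = torch.tensor([10, 13], dtype=torch.int32)
table = get_dcp_local_seq_lens(seq_lens, dcp_world_size=4, dcp_kv_cache_interleave_size=2)
# table[i, r] is how many KV tokens of request i live on DCP rank r.
dcp_rank = 2  # assumed rank for illustration
print(table[:, dcp_rank])  # -> [2, 3]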
