Commit 6fadfb4

Merge pull request #7 from zhangsicheng5/long_seq_dev
support kv_cache interleave_size
2 parents bbc4ea4 + f3a8cce

4 files changed: +61 -0 lines changed

vllm/config/__init__.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -618,6 +618,18 @@ def __post_init__(self):
                 "to True to enable.")
         current_platform.check_and_update_config(self)
 
+        assert (
+            self.parallel_config.cp_kv_cache_interleave_size
+            <= self.cache_config.block_size
+            and self.cache_config.block_size
+            % self.parallel_config.cp_kv_cache_interleave_size
+            == 0
+        ), (
+            f"Block_size({self.cache_config.block_size}) should be "
+            "greater than or equal to and divisible by cp_kv_cache_interleave_size "
+            f"({self.parallel_config.cp_kv_cache_interleave_size})."
+        )
+
         # final check of cudagraph mode after platform-specific update
         if envs.VLLM_USE_V1 and current_platform.is_cuda_alike():
             if self.compilation_config.cudagraph_mode == CUDAGraphMode.FULL \
```
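
For reference, a minimal standalone sketch (values are illustrative, not part of the commit) of the constraint this new assert enforces: the interleave size must not exceed the block size and must divide it evenly.

```python
# Sketch of the validity check added above (illustrative values only):
# cp_kv_cache_interleave_size must be <= block_size and divide it evenly.
block_size = 16
for interleave_size in (1, 4, 16, 3, 32):
    valid = interleave_size <= block_size and block_size % interleave_size == 0
    print(f"interleave_size={interleave_size}: {'ok' if valid else 'rejected'}")
# 1: ok, 4: ok, 16: ok, 3: rejected (not a divisor), 32: rejected (> block_size)
```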

vllm/config/parallel.py

Lines changed: 11 additions & 0 deletions
```diff
@@ -195,6 +195,17 @@ class is dynamically inherited by the worker class. This is used to inject
     not change by dcp, it simply reuse the GPUs of TP group, and tp_size
     needs to be divisible by dcp_size."""
 
+    cp_kv_cache_interleave_size: int = 1
+    """Interleave size of kv_cache storage while using dcp or cp > 1:
+    store interleave_size tokens on (d)cp rank i,
+    then store the next interleave_size tokens on (d)cp rank i+1.
+    interleave_size=1: token-level alignment; token i is stored on rank i % (d)cp_size.
+    interleave_size=block_size: block-level alignment; the block on the first
+    rank is filled first, and tokens go to block j on rank i+1 only after
+    block j on rank i is full.
+    block_size should be greater than or equal to and divisible by
+    cp_kv_cache_interleave_size.
+    """
+
     _api_process_count: int = 1
     """
     The number of API processes initialized.
```
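
The placement rule described in the docstring can be checked with a small sketch (the helper below is hypothetical, not from the diff): token i lands on (d)cp rank (i // interleave_size) % world_size.

```python
# Hypothetical helper illustrating the placement rule from the docstring:
# token i is stored on (d)cp rank (i // interleave_size) % world_size.
def rank_of_token(i: int, interleave_size: int, world_size: int) -> int:
    return (i // interleave_size) % world_size

# interleave_size=1 -> token-level round-robin across 4 ranks
print([rank_of_token(i, 1, 4) for i in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]
# interleave_size=4 (e.g. block_size=4) -> fill 4 tokens per rank before moving on
print([rank_of_token(i, 4, 4) for i in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```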

vllm/engine/arg_utils.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -318,6 +318,7 @@ class EngineArgs:
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
     decode_context_parallel_size: int = \
         ParallelConfig.decode_context_parallel_size
+    cp_kv_cache_interleave_size: int = ParallelConfig.cp_kv_cache_interleave_size
     context_parallel_size: int = ParallelConfig.context_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
     data_parallel_rank: Optional[int] = None
@@ -654,6 +655,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--decode-context-parallel-size", "-dcp",
             **parallel_kwargs["decode_context_parallel_size"])
+        parallel_group.add_argument(
+            "--cp-kv-cache-interleave-size",
+            **parallel_kwargs["cp_kv_cache_interleave_size"])
         parallel_group.add_argument(
             "--context-parallel-size", "-cp",
             **parallel_kwargs["context_parallel_size"])
@@ -1338,6 +1342,7 @@ def create_engine_config(
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
             decode_context_parallel_size=self.decode_context_parallel_size,
+            cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
             _api_process_count=self._api_process_count,
             _api_process_rank=self._api_process_rank,
         )
```
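
A quick usage sketch of the new knob (model name and value are illustrative): it defaults to ParallelConfig.cp_kv_cache_interleave_size (1) and is forwarded to ParallelConfig by create_engine_config(); the CLI equivalent is the --cp-kv-cache-interleave-size flag registered above.

```python
# Sketch of setting the new EngineArgs field (illustrative model and value).
from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", cp_kv_cache_interleave_size=16)
print(args.cp_kv_cache_interleave_size)  # 16
```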

vllm/v1/attention/backends/utils.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -851,3 +851,36 @@ def __init__(self, metadata, common_attn_metadata):
         builder_cls=FastPrefillAttentionBuilder)
 
     return attn_backend
+
+
+def get_cp_local_seq_lens(
+    seq_lens: torch.Tensor,
+    cp_world_size: int = 1,
+    dcp_world_size: int = 1,
+    cp_kv_cache_interleave_size: int = 1,
+) -> torch.Tensor:
+    """While using cp or dcp, the kv_cache size stored on each rank may differ;
+    use this function to calculate the split decode seq_lens of each (d)cp rank.
+    """
+    num_requests = seq_lens.size(0)
+    total_world_size = cp_world_size * dcp_world_size
+    seq_lens_tiled = seq_lens.unsqueeze(-1).repeat(1, total_world_size)
+    rank_offsets = (
+        torch.arange(total_world_size, dtype=torch.int32)
+        .unsqueeze(0)
+        .repeat(num_requests, 1)
+    )
+    base = (
+        seq_lens_tiled
+        // cp_kv_cache_interleave_size
+        // total_world_size
+        * cp_kv_cache_interleave_size
+    )
+    remainder = seq_lens_tiled - base * total_world_size
+    remainder = torch.clip(
+        remainder - rank_offsets * cp_kv_cache_interleave_size,
+        0,
+        cp_kv_cache_interleave_size,
+    )
+    dcp_local_seq_lens = (base + remainder).reshape([-1, cp_world_size, dcp_world_size])
+    return dcp_local_seq_lens
```
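
To see what the new helper computes, here is a worked sketch (numbers illustrative) replicating its arithmetic for one request: with seq_len=10, a combined world size of 4, and interleave size 2, every rank holds 2 tokens from the one full round of 2*4=8 tokens, and the 2 leftover tokens land on rank 0, giving per-rank lengths [4, 2, 2, 2].

```python
# Standalone sketch of get_cp_local_seq_lens for one request (illustrative numbers).
import torch

seq_lens = torch.tensor([10], dtype=torch.int32)  # one request with 10 tokens
total_world_size = 4                              # cp_world_size * dcp_world_size
interleave = 2                                    # cp_kv_cache_interleave_size

# Full interleave rounds shared by every rank: 10 // 2 // 4 * 2 = 2 tokens each.
base = seq_lens.unsqueeze(-1) // interleave // total_world_size * interleave
# Leftover tokens after the full rounds: 10 - 2 * 4 = 2.
leftover = seq_lens.unsqueeze(-1) - base * total_world_size
offsets = torch.arange(total_world_size, dtype=torch.int32) * interleave
# Lower-numbered ranks absorb the leftover, at most one interleave group each.
per_rank = base + torch.clip(leftover - offsets, 0, interleave)

print(per_rank)  # tensor([[4, 2, 2, 2]], dtype=torch.int32), sums back to 10
```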
