7 changes: 7 additions & 0 deletions tests/distributed/test_context_parallel.py
@@ -30,6 +30,7 @@ class ParallelSetup(NamedTuple):
tp_size: int
pp_size: int
dcp_size: int
dcp_kv_cache_interleave_size: int
eager_mode: bool
chunked_prefill: bool

@@ -52,6 +53,7 @@ def detailed(
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
dcp_kv_cache_interleave_size: int = 1,
multi_node_only: bool = False,
runner: RunnerOption = "auto",
load_format: str | None = None,
@@ -66,6 +68,7 @@ def detailed(
tp_size=tp_base,
pp_size=pp_multiplier * pp_base,
dcp_size=int(dcp_multiplier * tp_base),
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
eager_mode=eager_mode_val,
chunked_prefill=chunked_prefill_val,
)
@@ -108,6 +111,7 @@ def _compare_cp_with_tp(
tp_size,
pp_size,
dcp_size,
dcp_kv_cache_interleave_size,
eager_mode,
chunked_prefill,
) = parallel_setup
@@ -180,6 +184,8 @@ def _compare_cp_with_tp(
str(pp_size),
"--decode-context-parallel-size",
str(dcp_size),
"--dcp-kv-cache-interleave-size",
str(dcp_kv_cache_interleave_size),
"--distributed-executor-backend",
distributed_backend,
]
@@ -207,6 +213,7 @@ def _compare_cp_with_tp(
"deepseek-ai/DeepSeek-V2-Lite-Chat": [
CPTestSettings.detailed(),
CPTestSettings.detailed(tp_base=2),
CPTestSettings.detailed(tp_base=2, dcp_kv_cache_interleave_size=64),
],
"bigcode/gpt_bigcode-santacoder": [
CPTestSettings.detailed(),
11 changes: 11 additions & 0 deletions vllm/config/parallel.py
@@ -223,6 +223,17 @@ class is dynamically inherited by the worker class. This is used to inject
not changed by DCP; it simply reuses the GPUs of the TP group, and tp_size
needs to be divisible by dcp_size."""

dcp_kv_cache_interleave_size: int = 1
"""Interleave size of kv_cache storage while using dcp or cp > 1,
store interleave_size tokens on (d)cp i,
then store next interleave_size tokens on (d)cp i+1.
Interleave_size=1: token-level align, token i is stored on rank i % (d)cp_size.
Interleave_size=block_size: block-level align, first fill the block on first rank,
token is stored on rank i+1 block j after rank i block j is full.
Block_size should be greater than or equal to dcp_kv_cache_interleave_size.
Block_size should be divisible by dcp_kv_cache_interleave_size.
"""

_api_process_count: int = Field(default=1, gt=0)
"""
The number of API processes initialized.
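For intuition about the interleave scheme documented for dcp_kv_cache_interleave_size above, here is a minimal standalone sketch; the helper name dcp_rank_for_token is hypothetical and not part of this PR:

# Minimal sketch of the token -> rank mapping described in the docstring above.
# dcp_rank_for_token is a hypothetical helper, not part of this PR.
def dcp_rank_for_token(position: int, dcp_world_size: int, interleave_size: int) -> int:
    # Tokens are grouped into runs of `interleave_size`; the runs are dealt
    # round-robin across the (d)cp ranks.
    return (position // interleave_size) % dcp_world_size

# interleave_size=1: token-level round robin, token i lands on rank i % dcp_world_size.
assert [dcp_rank_for_token(i, 4, 1) for i in range(6)] == [0, 1, 2, 3, 0, 1]
# interleave_size=4: a whole run of 4 tokens fills one rank before moving on.
assert [dcp_rank_for_token(i, 4, 4) for i in range(6)] == [0, 0, 0, 0, 1, 1]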
17 changes: 17 additions & 0 deletions vllm/config/vllm.py
@@ -480,6 +480,23 @@ def __post_init__(self):
)
current_platform.check_and_update_config(self)

assert (
self.parallel_config.dcp_kv_cache_interleave_size
<= self.cache_config.block_size
and self.cache_config.block_size
% self.parallel_config.dcp_kv_cache_interleave_size
== 0
), (
f"Block_size({self.cache_config.block_size}) should be "
"greater than or equal to and divisible by dcp_kv_cache_interleave_size "
f"({self.parallel_config.dcp_kv_cache_interleave_size})."
)

assert (
self.parallel_config.dcp_kv_cache_interleave_size == 1
or self.speculative_config is None
), "MTP with dcp_kv_cache_interleave_size > 1 is not supported now."

# Do this after all the updates to compilation_config.mode
if (
envs.VLLM_USE_V1
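For concreteness, a tiny standalone illustration of which interleave sizes the assertion above accepts for a given block size (values are illustrative):

block_size = 16
for interleave in (1, 2, 4, 8, 16):
    # Valid: the interleave size is <= block_size and divides it evenly.
    assert interleave <= block_size and block_size % interleave == 0
for interleave in (3, 32):
    # Invalid: 3 does not divide 16, and 32 exceeds the block size.
    assert not (interleave <= block_size and block_size % interleave == 0)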
6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -362,6 +362,7 @@ class EngineArgs:
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
data_parallel_size: int = ParallelConfig.data_parallel_size
data_parallel_rank: int | None = None
data_parallel_start_rank: int | None = None
@@ -717,6 +718,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"-dcp",
**parallel_kwargs["decode_context_parallel_size"],
)
parallel_group.add_argument(
"--dcp-kv-cache-interleave-size",
**parallel_kwargs["dcp_kv_cache_interleave_size"],
)
parallel_group.add_argument(
"--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"]
)
@@ -1482,6 +1487,7 @@ def create_engine_config(
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
decode_context_parallel_size=self.decode_context_parallel_size,
dcp_kv_cache_interleave_size=self.dcp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
)
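A hedged usage sketch of the new argument through EngineArgs; the model name and sizes below are illustrative, not taken from this PR:

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite-Chat",
    tensor_parallel_size=2,
    decode_context_parallel_size=2,
    # block_size must be >= and divisible by the interleave size, per the
    # check added in vllm/config/vllm.py.
    block_size=64,
    dcp_kv_cache_interleave_size=64,  # block-aligned interleaving
)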
16 changes: 13 additions & 3 deletions vllm/v1/attention/backends/flash_attn.py
@@ -40,6 +40,7 @@
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -232,6 +233,10 @@ def __init__(
self.dcp_world_size = 1
self.dcp_rank = 0

self.dcp_kv_cache_interleave_size = (
self.parallel_config.dcp_kv_cache_interleave_size
)

self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
@@ -350,8 +355,12 @@ def schedule(
- common_attn_metadata.query_start_loc_cpu[:-1]
)
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size

dcp_context_kv_lens_cpu = get_dcp_local_seq_lens(
dcp_context_kv_lens_cpu,
self.dcp_world_size,
self.dcp_rank,
self.dcp_kv_cache_interleave_size,
)
dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
@@ -437,7 +446,8 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:


class FlashAttentionImpl(AttentionImpl):
can_return_lse_for_decode: bool = True
# TODO(qcs): enable DCP when `flash_attn_varlen_func` supports ctxlen(seqused_k)=0
can_return_lse_for_decode: bool = False

def __init__(
self,
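To illustrate the metadata path above: the context KV lengths (sequence length minus this step's query tokens) are what get split per DCP rank. A small sketch with illustrative numbers, reusing the get_dcp_local_seq_lens helper imported above:

import torch

from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens

# Two decode requests with 10 and 7 total tokens, one new query token each,
# split across 2 DCP ranks with interleave size 2.
seq_lens_cpu = torch.tensor([10, 7])
query_kv_lens_cpu = torch.tensor([1, 1])
dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu  # [9, 6]
rank0_context_kv_lens = get_dcp_local_seq_lens(
    dcp_context_kv_lens_cpu,
    dcp_world_size=2,
    dcp_rank=0,
    dcp_kv_cache_interleave_size=2,
)
# Rank 0 holds tokens 0-1, 4-5, 8 of the first context and 0-1, 4-5 of the second.
assert rank0_context_kv_lens.tolist() == [5, 4]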
47 changes: 27 additions & 20 deletions vllm/v1/attention/backends/mla/common.py
@@ -222,6 +222,7 @@
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
get_dcp_local_seq_lens,
get_per_layer_parameters,
infer_global_hyperparameters,
split_decodes_and_prefills,
@@ -554,6 +555,7 @@ def __init__(
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = parallel_config.dcp_kv_cache_interleave_size

# Don't try to access the runner on AMD
if self.aot_schedule:
@@ -774,15 +776,6 @@ def build(
)
)

# Note(hc): update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
assert dcp_local_seq_lens is not None
dcp_local_seq_lens[:num_decodes] = seq_lens[
:num_decodes
] // self.dcp_world_size + (
self.dcp_rank <= (seq_lens[:num_decodes] - 1) % self.dcp_world_size
)

assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens

@@ -792,9 +785,13 @@

context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
# Note(hc): The context lengths from the perspective of dcp rank 0.
cp_context_lens_cpu = torch.ceil(
context_lens_cpu.float() / self.dcp_world_size
).int()
cp_context_lens_cpu = (
torch.ceil(
context_lens_cpu.float()
/ (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
).int()
* self.dcp_kv_cache_interleave_size
)
origin_context_lens = context_lens_cpu.tolist()
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
@@ -981,6 +978,7 @@ def reorg_kvcache(
chunk_size: int,
chunk_idx: int,
toks: int,
interleave_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Reorganize the KV cache after the CP local gather into the TP layout expected by the attention kernel.
@@ -995,27 +993,32 @@
chunked_context_metadata building.
chunk_idx: chunk idx of chunked_prefill.
toks: the number of tokens for local gather cache.
interleave_size: Interleave size of kv_cache storage.
"""
kv_c_segments = []
k_pe_segments = []
src_token_idx = 0
max_seq_len_check = 0
for cp_chunk_seq_len, origin_context_len in zip(
cp_chunk_seq_lens_lst, origin_context_lens
local_context_lens_allrank = get_dcp_local_seq_lens(
torch.Tensor(origin_context_lens),
cp_world_size,
None,
interleave_size,
)
for cp_chunk_seq_len, origin_context_len, local_context_lens in zip(
cp_chunk_seq_lens_lst, origin_context_lens, local_context_lens_allrank
):
chunk_context_len = chunk_size
if cp_chunk_seq_len != 0:
chunk_context_len = min(
chunk_context_len, origin_context_len - chunk_size * chunk_idx
)
cp_target_rank = (chunk_context_len - 1) % cp_world_size

cur_seq_len = 0
for rank in range(cp_world_size):
if rank > cp_target_rank and cp_chunk_seq_len:
real_cp_chunk_seq_len = cp_chunk_seq_len - 1
else:
real_cp_chunk_seq_len = cp_chunk_seq_len
if real_cp_chunk_seq_len:
real_cp_chunk_seq_len = local_context_lens[rank]
if real_cp_chunk_seq_len != 0:
kv_c_segment = allgatered_kv_c_normed[
rank * toks + src_token_idx : rank * toks
+ src_token_idx
Expand Down Expand Up @@ -1263,6 +1266,9 @@ def __init__(self, *args, **kwargs) -> None:
get_current_vllm_config()
)
)
self.dcp_kv_cache_interleave_size: int = (
get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
)

def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
@@ -1641,6 +1647,7 @@ def _context_parallel_compute_prefill_context(
chunk_size=prefill_metadata.chunked_context.chunk_size,
chunk_idx=i,
toks=toks,
interleave_size=self.dcp_kv_cache_interleave_size,
)

kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
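As a sanity illustration of the interleave-aligned cp_context_lens computation above, here is a small standalone check (illustrative values) that it agrees with get_dcp_local_seq_lens for rank 0 in this case:

import math

import torch

from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens

# 10 context tokens, 4 DCP ranks, interleave size 2: tokens are stored in runs
# of 2, so rank 0 holds tokens 0-1 and 8-9.
context_len, dcp_world_size, interleave = 10, 4, 2
cp_context_len = math.ceil(context_len / (dcp_world_size * interleave)) * interleave
rank0_len = int(
    get_dcp_local_seq_lens(
        torch.tensor([context_len]),
        dcp_world_size,
        dcp_rank=0,
        dcp_kv_cache_interleave_size=interleave,
    )[0]
)
# cp_context_len is an interleave-aligned upper bound on the per-rank length;
# for these numbers it equals rank 0's local context length exactly.
assert cp_context_len == rank0_len == 4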
38 changes: 38 additions & 0 deletions vllm/v1/attention/backends/utils.py
@@ -992,3 +992,41 @@ def compute_causal_conv1d_metadata(query_start_loc_p: torch.Tensor):
nums_dict[BLOCK_M]["token_chunk_offset_ptr"] = token_chunk_offset_ptr # type: ignore

return nums_dict, batch_ptr, token_chunk_offset_ptr


Review comment (Collaborator): nit: can you simplify this and have it only compute for the current dcp_rank? and pass in dcp rank

Reply (Author): Same reason as this comment: it might be more flexible to return the full seq_len split result and let each dcp_rank select its own part as needed. Since the size of seq_lens is only max_num_reqs, multiplying it by dcp_size adds little computation/storage overhead, which we think is acceptable.

def get_dcp_local_seq_lens(
seq_lens: torch.Tensor,
dcp_world_size: int = 1,
dcp_rank: int | None = None,
dcp_kv_cache_interleave_size: int = 1,
) -> torch.Tensor:
"""While using dcp, kv_cache size stored on each rank may be different,
use this function to calculate split decode seq_lens of each dcp rank.
Only consider dcp now, we can extend the case of cp based on this.
"""
num_requests = seq_lens.size(0)
if dcp_rank is None:
rank_offsets = (
torch.arange(dcp_world_size, dtype=torch.int32)
.unsqueeze(0)
.repeat(num_requests, 1)
)
else:
rank_offsets = torch.Tensor([[dcp_rank]]).to(dtype=torch.int32)
seq_lens_tiled = (
seq_lens.to(torch.int32).unsqueeze(-1).repeat(1, rank_offsets.shape[1])
)
base = (
seq_lens_tiled
// dcp_kv_cache_interleave_size
// dcp_world_size
* dcp_kv_cache_interleave_size
)
remainder = seq_lens_tiled - base * dcp_world_size
remainder = torch.clip(
remainder - rank_offsets * dcp_kv_cache_interleave_size,
0,
dcp_kv_cache_interleave_size,
)
dcp_local_seq_lens = base + remainder
return dcp_local_seq_lens.squeeze(1)
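A brief usage sketch of the helper above, with illustrative values:

import torch

# 10 decode tokens split across 4 DCP ranks with interleave size 2: runs of
# 2 tokens go to ranks 0, 1, 2, 3, then wrap back to rank 0, so rank 0 ends
# up with 4 tokens and every other rank with 2.
seq_lens = torch.tensor([10])
per_rank = get_dcp_local_seq_lens(
    seq_lens, dcp_world_size=4, dcp_rank=None, dcp_kv_cache_interleave_size=2
)
assert per_rank.tolist() == [[4, 2, 2, 2]]

# With dcp_rank given, only that rank's local lengths are returned.
rank0 = get_dcp_local_seq_lens(
    seq_lens, dcp_world_size=4, dcp_rank=0, dcp_kv_cache_interleave_size=2
)
assert rank0.tolist() == [4]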
18 changes: 16 additions & 2 deletions vllm/v1/worker/block_table.py
@@ -22,6 +22,7 @@ def __init__(
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
dcp_kv_cache_interleave_size: int,
):
"""
Args:
@@ -86,6 +87,7 @@ def __init__(
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size

def append_row(
self,
@@ -144,9 +146,19 @@ def compute_slot_mapping(
# Use virtual_block_size for mask calculation, which marks local
# tokens.
virtual_block_offsets = positions % virtual_block_size
mask = virtual_block_offsets % self.dcp_world_size == self.dcp_rank
mask = (
virtual_block_offsets
// self.dcp_kv_cache_interleave_size
% self.dcp_world_size
== self.dcp_rank
)
# Calculate local block_offsets
block_offsets = virtual_block_offsets // self.dcp_world_size
block_offsets = (
virtual_block_offsets
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
* self.dcp_kv_cache_interleave_size
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
# Write final slots, use -1 for not-local
@@ -234,6 +246,7 @@ def __init__(
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only stores
# (max_model_len//dcp_world_size) tokens in kvcache,
@@ -263,6 +276,7 @@ def __init__(
pin_memory,
device,
kernel_block_size,
dcp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]
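A standalone sketch of the interleaved locality test and local block-offset math used in compute_slot_mapping above (illustrative, not the vLLM implementation itself):

import torch

dcp_world_size, dcp_rank, interleave_size = 2, 0, 4
virtual_block_offsets = torch.arange(16)

# A token is stored on this rank iff its run of `interleave_size` tokens maps here.
mask = (virtual_block_offsets // interleave_size) % dcp_world_size == dcp_rank
# Local offset: skip over the runs owned by other ranks, keep the position
# within the run.
local_offsets = (
    virtual_block_offsets // (dcp_world_size * interleave_size) * interleave_size
    + virtual_block_offsets % interleave_size
)
assert mask.nonzero().flatten().tolist() == [0, 1, 2, 3, 8, 9, 10, 11]
assert local_offsets[mask].tolist() == [0, 1, 2, 3, 4, 5, 6, 7]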
2 changes: 2 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -83,6 +83,7 @@ def __init__(
is_spec_decode: bool = False,
is_pooling_model: bool = False,
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
):
self.is_pooling_model = is_pooling_model
self.is_spec_decode = is_spec_decode
@@ -135,6 +136,7 @@ def __init__(
block_sizes=block_sizes,
kernel_block_sizes=kernel_block_sizes,
num_speculative_tokens=num_speculative_tokens,
dcp_kv_cache_interleave_size=dcp_kv_cache_interleave_size,
)

# Sampling-related.
Expand Down