
Conversation


@zhangsicheng5 zhangsicheng5 commented Oct 13, 2025

Purpose

1. cp_kv_cache_interleave_size support

In the DCP scenario, the kv_cache is split across DCP ranks. The current implementation (#23734) splits the kv_cache in a token-level interleaved style: token index i is stored on the GPU whose dcp_rank == i % dcp_world_size.

To make PD disaggregation easier to support, we add the cp_kv_cache_interleave_size argument to control the interleave size of the kv_cache split: interleave_size tokens are stored on dcp rank i, then the next interleave_size tokens on dcp rank i+1, and so on. The default value of cp_kv_cache_interleave_size is 1, which matches the original token-level interleaved implementation. By setting cp_kv_cache_interleave_size to block_size, we can split the kv_cache in a block-level interleaved style, which makes it easy to support PD disaggregation with dcp > 1: D nodes only need to pull the corresponding kv_cache blocks, without rearranging tokens within blocks.
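As an illustration of this mapping (a minimal sketch, not the vLLM implementation; the helper name is made up):

def token_dcp_rank(token_idx: int, dcp_world_size: int, interleave_size: int = 1) -> int:
    # With interleave_size == 1 this reduces to token_idx % dcp_world_size.
    return (token_idx // interleave_size) % dcp_world_size

# Example with 8 tokens and dcp_world_size = 2:
#   interleave_size = 1 -> ranks [0, 1, 0, 1, 0, 1, 0, 1]  (token-level interleave)
#   interleave_size = 4 -> ranks [0, 0, 0, 0, 1, 1, 1, 1]  (block-level when block_size = 4)
print([token_dcp_rank(i, 2, 1) for i in range(8)])
print([token_dcp_rank(i, 2, 4) for i in range(8)])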

Only DCP with cp_kv_cache_interleave_size is supported for now, but the (P)CP case has also been considered and is easy to extend in the future.

2. Move dcp_local_seq_lens computation to utils

Move the dcp_local_seq_lens computation to utils and pass it via metadata, so other attention backends can reuse it.
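For intuition, a hedged sketch of how per-rank KV lengths can be derived from global sequence lengths under interleaved splitting; the function name and signature are illustrative, not the exact utility added in this PR:

import numpy as np

def split_seq_lens_per_rank(seq_lens: np.ndarray, dcp_world_size: int,
                            interleave_size: int = 1) -> np.ndarray:
    # Returns an array of shape (num_reqs, dcp_world_size): number of KV tokens
    # held by each DCP rank for every request.
    group = dcp_world_size * interleave_size
    base = (seq_lens // group)[:, None] * interleave_size
    remainder = (seq_lens % group)[:, None]
    ranks = np.arange(dcp_world_size)[None, :]
    extra = np.clip(remainder - ranks * interleave_size, 0, interleave_size)
    return base + extra

print(split_seq_lens_per_rank(np.array([10]), dcp_world_size=2, interleave_size=4))
# [[6 4]] -> rank 0 holds tokens 0-3 and 8-9, rank 1 holds tokens 4-7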

Test Plan

Model: DeepSeek-V2-Lite-Chat
Dataset: gsm8k

vllm serve DeepSeek-V2-Lite-Chat --gpu-memory-utilization 0.9 --tensor-parallel-size 2 --decode-context-parallel-size 2 --cp-kv-cache-interleave-size 64

Test Result

tp2 dcp2, original code

dataset version metric mode vllm-api-stream-chat
gsm8k 7cd45e accuracy gen 67.85

tp2 dcp2, interleave_size = 1

dataset version metric mode vllm-api-stream-chat
gsm8k 7cd45e accuracy gen 67.85

tp2 dcp2, interleave_size = 64

dataset version metric mode vllm-api-stream-chat
gsm8k 7cd45e accuracy gen 67.55

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the v1 label Oct 13, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for configurable interleave size for KV cache in Decode Context Parallelism (DCP), which is a nice enhancement for flexibility. The changes also include refactoring the dcp_local_seq_lens computation into a utility function. The implementation is mostly solid, but I've identified a couple of areas for improvement. One is a misleading error message in an assertion, and the other is an opportunity to refactor a new utility function for better readability and efficiency. Addressing these points will improve the code quality.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


@zhangsicheng5 zhangsicheng5 force-pushed the dev branch 4 times, most recently from aa23faa to 397fd51 on October 13, 2025 13:26
@zhangsicheng5 zhangsicheng5 force-pushed the dev branch 2 times, most recently from 656e08c to 46ec829 on October 14, 2025 02:02
Collaborator

@LucasWilkinson LucasWilkinson left a comment


Thanks for your contribution! Requesting changes to prevent merging until the test results can be obtained 👍

tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
cp_kv_cache_interleave_size: int = 1,
Collaborator

please call this dcp_kv_cache_interleave_size

Author

After prefill CP (#25852) is supported, this kv_cache_interleave_size will be used for both dcp and pcp; shall we keep this name for future use?

Collaborator

sure; by that logic we should update dcp_local_seq_lens to cp_local_seq_lens too but we can do that in the pcp PR

self,
req_indices: np.ndarray,
positions: np.ndarray,
cp_kv_cache_interleave_size: int = 1,
Collaborator

since this is a constant pass it via the init

Author

Thanks for review, we have already passed it via init

i += 1


def get_dcp_local_seq_lens(
Collaborator

we should find a better spot for this; this is too broad of a utils file for a feature specific utility

Author

We now put this function in vllm/v1/attention/backends/utils.py, the same place as the CommonAttentionMetadata.dcp_local_seq_lens definition; this should be a more appropriate spot.


# update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
    self.dcp_local_seq_lens.gpu[:num_reqs] = get_dcp_local_seq_lens(
Collaborator

I think it might actually be better to compute get_dcp_local_seq_lens using host buffers and then do a non-blocking copy to self.dcp_local_seq_lens.gpu (see: CpuGpuBuffer.copy_to_gpu)

Collaborator

(then when async scheduling is enabled it will be overlapped)
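A hedged sketch of the pattern suggested above, in plain PyTorch (buffer names and values are illustrative; vLLM's CpuGpuBuffer.copy_to_gpu is the suggested wrapper for this):

import torch

num_reqs, dcp_world_size = 4, 2
# A pinned host buffer lets the H2D copy be truly asynchronous.
cpu_buf = torch.zeros(num_reqs, dcp_world_size, dtype=torch.int32, pin_memory=True)
gpu_buf = torch.zeros(num_reqs, dcp_world_size, dtype=torch.int32, device="cuda")

# 1) Fill the host buffer with the result computed on the CPU (e.g. local seq lens).
cpu_buf.copy_(torch.tensor([[6, 4], [3, 3], [5, 4], [2, 2]], dtype=torch.int32))
# 2) Issue a non-blocking copy; it can overlap with other work on the current stream,
#    which is what makes it attractive under async scheduling.
gpu_buf.copy_(cpu_buf, non_blocking=True)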

Author

Modified as suggested, thanks for review

Contributor

@youzhedian youzhedian left a comment

LGTM.

self.max_model_len = model_config.max_model_len
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
try:
    self.dcp_rank = get_dcp_group().rank_in_group
Contributor

Would it be better to delay this to the get_dcp_local_seq_lens call?

Author

In some cases we need to know how seq_len is split globally, not only the local seq_len on the current dcp_rank. For example, in our current NPU MLA implementation we need the global seq_len split to calculate a mask for the following update_lse (if no kv_cache is stored on some (d)cp ranks, there is no need to do the corresponding update_lse). So we think it is better to return the full seq_len split result from get_dcp_local_seq_lens, and each dcp_rank can select its corresponding part as needed.
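For example (a hedged numpy sketch of the masking use case described above; values are made up):

import numpy as np

# Full per-rank split, shape (num_reqs, dcp_world_size)
local_seq_lens = np.array([[6, 4, 0, 0],   # short request: only ranks 0 and 1 hold KV
                           [5, 5, 5, 4]])  # longer request: every rank holds KV
update_lse_mask = local_seq_lens > 0       # ranks with no KV can skip update_lse
print(update_lse_mask)
# [[ True  True False False]
#  [ True  True  True  True]]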

Collaborator

nit: I think we can simplify this to:

self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group

that way we'll still get the benefit of the assert in get_dcp_group(), and if a test sets self.dcp_world_size > 1 it should be initializing the dcp group anyway

Author

better way to get dcp_rank 👍 Modified as suggested

self,
req_indices: np.ndarray,
positions: np.ndarray,
cp_kv_cache_interleave_size: int = 1,
Contributor

nit: maybe no default val is better

Author

Thanks for review, now we pass this arg via init, since it's a constant

Collaborator

@LucasWilkinson LucasWilkinson left a comment

Overall looks good to me! Thanks for the contribution, left a few nits.

return nums_dict, batch_ptr, token_chunk_offset_ptr


def get_dcp_local_seq_lens(
Collaborator

nit: can you simplify this and have it only compute for the current dcp_rank? and pass in dcp rank

Author

Same reason as in this comment: it is more flexible to return the full seq_len split result so that each dcp_rank can select its own part as needed. Since the size of seq_lens is only max_num_reqs, multiplying it by dcp_size adds little computation/storage overhead, which we think is acceptable.


Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Contributor

FENP commented Oct 17, 2025

Does dcp_kv_cache_interleave_size affect reorg_kvcache, especially the calculation of cp_target_rank?

def reorg_kvcache(
    allgatered_kv_c_normed: torch.Tensor,
    allgatered_k_pe: torch.Tensor,
    cp_chunk_seq_lens_lst: list[int],
    origin_context_lens: list[int],
    cp_world_size: int,
    sum_seq_len: int,
    max_seq_len: int,
    chunk_size: int,
    chunk_idx: int,
    toks: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    reorg kvcache after cp local gather to tp layout for attn kernel.
    Args:
        cp_chunk_seq_lens_lst: chunk context lengths under CP.
        origin_context_lens: origin full context lengths under CP.
        cp_world_size: CP size.
        sum_seq_len: the sum of cp_chunk_seq_lens_lst.
        max_seq_len: the max value of cp_chunk_seq_lens_lst.
        chunk_size: equals to max_context_chunk from
            chunked_context_metadata building.
        chunk_idx: chunk idx of chunked_prefill.
        toks: the number of tokens for local gather cache.
    """
    kv_c_segments = []
    k_pe_segments = []
    src_token_idx = 0
    max_seq_len_check = 0
    for cp_chunk_seq_len, origin_context_len in zip(
        cp_chunk_seq_lens_lst, origin_context_lens
    ):
        chunk_context_len = chunk_size
        if cp_chunk_seq_len != 0:
            chunk_context_len = min(
                chunk_context_len, origin_context_len - chunk_size * chunk_idx
            )
        cp_target_rank = (chunk_context_len - 1) % cp_world_size
        cur_seq_len = 0
        for rank in range(cp_world_size):
            if rank > cp_target_rank and cp_chunk_seq_len:
                real_cp_chunk_seq_len = cp_chunk_seq_len - 1
            else:
                real_cp_chunk_seq_len = cp_chunk_seq_len
            if real_cp_chunk_seq_len:
                kv_c_segment = allgatered_kv_c_normed[
                    rank * toks + src_token_idx : rank * toks
                    + src_token_idx
                    + real_cp_chunk_seq_len
                ]
                k_pe_segment = allgatered_k_pe[
                    rank * toks + src_token_idx : rank * toks
                    + src_token_idx
                    + real_cp_chunk_seq_len
                ]
                kv_c_segments.append(kv_c_segment)
                k_pe_segments.append(k_pe_segment)
                cur_seq_len += real_cp_chunk_seq_len
        max_seq_len_check = max(max_seq_len_check, cur_seq_len)
        src_token_idx += cp_chunk_seq_len
    reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
    reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
    assert reorganized_kv_c_normed.shape[0] == sum_seq_len
    assert reorganized_k_pe.shape[0] == sum_seq_len
    assert max_seq_len_check == max_seq_len
    return reorganized_kv_c_normed, reorganized_k_pe
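To illustrate the concern with a toy example (a hedged sketch, not the actual fix): the cp_target_rank computation above assumes token-level interleaving, i.e. that the last context token lives on rank (len - 1) % cp_world_size; with a larger interleave size, the rank holding the last context token generally changes.

def last_token_rank(context_len: int, cp_world_size: int, interleave_size: int = 1) -> int:
    # Rank that stores the last context token under interleaved splitting.
    return ((context_len - 1) // interleave_size) % cp_world_size

print(last_token_rank(10, 2, interleave_size=1))  # 1, matches (10 - 1) % 2
print(last_token_rank(10, 2, interleave_size=4))  # 0, so cp_target_rank must account for the interleave size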

Collaborator

@LucasWilkinson LucasWilkinson left a comment

Great catch @FENP; please address: #26696 (comment)


pisceskkk commented Oct 21, 2025

Does dcp_kv_cache_interleave_size affect reorg_kvcache, especially the calculation of cp_target_rank?

Sorry for the late reply. We have fixed this bug and retested the chunked prefill feature. We counted the number of tokens in the gsm8k dataset (which contains over 700 prompts with more than 100 tokens) and used --max_num_batched_tokens=50. Therefore, I believe this accuracy test validates the correctness of the current changes. If there are still any other issues, please feel free to point them out. Looking forward to your further review.


Test Results
Model: DeepSeek-V2-Lite-Chat
Dataset: gsm8k
hyperparams: tp2, dcp2, batch_size=5, max_num_batched_tokens=50, max_out_len = 1024

interleave_size = 1

dataset version metric mode vllm-api-stream-chat
gsm8kdataset - accuracy gen 67.32

interleave_size = 8

dataset version metric mode vllm-api-stream-chat
gsm8kdataset - accuracy gen 68.23

CC List:
@LucasWilkinson @FENP @youzhedian

…leave size > 1

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
@pisceskkk

I've completed the adaptation for the GQA flash_attn backend and conducted several precision tests. Although the precision tests don't show a significant impact, I find this implementation approach rather unusual and suspect there might still be underlying issues. I'd appreciate it if everyone could review the code (5d7184a), especially regarding this particular variable, dcp_context_kv_lens_cpu.

More details: in the original DCP implementation, I believe this value represents the length of the KV cache on each dcp rank. However, when the total length of the KV cache is divisible by the DCP world size, it actually becomes one greater than the actual length on each rank. I attempted to correct this discrepancy, but it resulted in completely incorrect prediction outcomes. Therefore, I retained a similar implementation, which seems somewhat abnormal.

Test results:
Qwen3/Qwen3-235B-A22B-FP8 TP8 DCP2
original code (commit 650b51f)

dataset version metric mode vllm-api-stream-chat
gsm8kdataset - accuracy gen 74.75

interleave size = 8, max_num_batched_tokens = 100

dataset version metric mode vllm-api-stream-chat
gsm8kdataset - accuracy gen 74.60

interleave size = 8, max_num_batched_tokens = 99

dataset version metric mode vllm-api-stream-chat
gsm8kdataset - accuracy gen 76.42

@LucasWilkinson
Collaborator

I've completed the adaptation for the GQA flash_attn backend and conducted several precision tests. Although the precision tests don't show a significant impact, I find this implementation approach rather unusual and suspect there might still be underlying issues. I'd appreciate it if everyone could review the code (5d7184a), especially regarding this particular variable, dcp_context_kv_lens_cpu.

Apologies for the delay! (FYI I'm on vacation, so responses will be delayed.) I agree this is very weird and should be figured out before landing. Why not use common_attn_metadata.dcp_local_seq_lens for flash_attn? might be related to the - query_kv_lens_cpu part of dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu


pisceskkk commented Oct 25, 2025

Why not use common_attn_metadata.dcp_local_seq_lens for flash_attn? might be related to the - query_kv_lens_cpu part of dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu

@LucasWilkinson yes, common_attn_metadata.dcp_local_seq_lens is about the whole seq with both context and current chunks, but we only need the context chunks here.

I have thoroughly examined the subsequent usage of dcp_context_kv_lens but was unable to dive deeper into the operator-side analysis. I believe we need to cc some experts proficient in the logic of the flash_attn operator to help explain the meaning of seqused_k and identify the root cause of the issue. I would greatly appreciate assistance from specialists in this area.

Contributor

minosfuture commented Oct 25, 2025

Why not use common_attn_metadata.dcp_local_seq_lens for flash_attn? might be related to the - query_kv_lens_cpu part of dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu

@LucasWilkinson yes, common_attn_metadata.dcp_local_seq_lens is about the whole seq with both context and current chunks, but we only need the context chunks here.

I have thoroughly examined the subsequent usage of dcp_context_kv_lens but was unable to dive deeper into the operator-side analysis. I believe we need to cc some experts proficient in the logic of the flash_attn operator to help explain the meaning of seqused_k and identify the root cause of the issue. I would greatly appreciate assistance from specialists in this area.

seqused_k is the whole kv length (context len + query len). We need to fix the usage in _forward_with_dcp.
I double-checked that we correctly use the kv len in the MLA and non-DCP attention code paths.

edit: nvm, I noticed that DCP attention is implemented by splitting the context and current chunks to avoid needing custom mask handling.

Discussed separately that the context len == 0 case needs extra care.

