[DCP] Support dcp kv_cache interleave size > 1 #26696
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically, and you can ask your reviewers to trigger select CI tests on top of that. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces support for configurable interleave size for KV cache in Decode Context Parallelism (DCP), which is a nice enhancement for flexibility. The changes also include refactoring the dcp_local_seq_lens computation into a utility function. The implementation is mostly solid, but I've identified a couple of areas for improvement. One is a misleading error message in an assertion, and the other is an opportunity to refactor a new utility function for better readability and efficiency. Addressing these points will improve the code quality.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Thanks for your contribution! Requesting changes to prevent merging until the test results can be obtained 👍
tp_base: int = 4,
pp_base: int = 1,
dcp_base: int = 1,
cp_kv_cache_interleave_size: int = 1,
please call this dcp_kv_cache_interleave_size
After prefill cp (#25852) is supported, this kv_cache_interleave_size will be used for both dcp and pcp; shall we keep this name for future use?
sure; by that logic we should update dcp_local_seq_lens to cp_local_seq_lens too but we can do that in the pcp PR
vllm/v1/worker/block_table.py
Outdated
self,
req_indices: np.ndarray,
positions: np.ndarray,
cp_kv_cache_interleave_size: int = 1,
since this is a constant, pass it via the init
Thanks for the review; we now pass it via the init.
vllm/utils/__init__.py
Outdated
    i += 1


def get_dcp_local_seq_lens(
we should find a better spot for this; this is too broad of a utils file for a feature-specific utility
We now put this function in vllm/v1/attention/backends/utils.py, the same place where CommonAttentionMetadata.dcp_local_seq_lens is defined; this should be a more appropriate spot.
vllm/v1/worker/gpu_model_runner.py
Outdated
# update seq_lens of decode reqs under DCP.
if self.dcp_world_size > 1:
    self.dcp_local_seq_lens.gpu[:num_reqs] = get_dcp_local_seq_lens(
I think it might actually be better to compute get_dcp_local_seq_lens using host buffers and then do a non-blocking copy to self.dcp_local_seq_lens.gpu (see: CpuGpuBuffer.copy_to_gpu)
(then when async scheduling is enabled it will be overlapped)
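A rough, standalone sketch of the pattern being suggested; the buffer handling here is plain torch for illustration, not vLLM's actual CpuGpuBuffer API:

```python
import torch

# Compute on the host, then issue a single non-blocking H2D copy so the copy
# can overlap with other work (e.g. when async scheduling is enabled).
num_reqs = 4
use_cuda = torch.cuda.is_available()
cpu_buf = torch.zeros(64, dtype=torch.int32, pin_memory=use_cuda)  # pinned host buffer
gpu_buf = cpu_buf.cuda() if use_cuda else cpu_buf.clone()

cpu_buf[:num_reqs] = torch.tensor([5, 7, 9, 11], dtype=torch.int32)  # host-side result
gpu_buf[:num_reqs].copy_(cpu_buf[:num_reqs], non_blocking=True)      # non-blocking copy to GPU
```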
Modified as suggested, thanks for review
LGTM.
vllm/v1/worker/gpu_model_runner.py
Outdated
self.max_model_len = model_config.max_model_len
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
try:
    self.dcp_rank = get_dcp_group().rank_in_group
Would it be better to delay this until get_dcp_local_seq_lens is called?
In some cases we might need to know how seq_len is split globally, instead of only the local seq_len on the current dcp_rank. For example, in our current NPU MLA impl we need the global seq_len split to calculate a mask for the following update_lse (if no kv_cache is stored on some (d)cp ranks, there is no need to do the corresponding update_lse). So we think it's better to return the full seq_len split result from get_dcp_local_seq_lens, and each dcp_rank can select its corresponding part as needed.
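For illustration only, a tiny sketch of the kind of mask described above, derived from the full per-rank split (the array values are made up):

```python
import numpy as np

# Full seq_len split: shape (num_reqs, dcp_world_size), kv tokens held per rank.
full_split = np.array([[6, 4, 0, 0],    # req 0: only ranks 0 and 1 hold kv
                       [3, 3, 3, 1]])   # req 1: every rank holds some kv

# Ranks that hold no kv for a request can skip the corresponding update_lse.
needs_update_lse = full_split > 0
print(needs_update_lse)
# [[ True  True False False]
#  [ True  True  True  True]]
```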
nit: I think we can simplify this to:
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
That way we'll still get the benefit of the assert in get_dcp_group(), and if a test sets self.dcp_world_size > 1 it should be initializing the dcp group anyway.
better way to get dcp_rank 👍 Modified as suggested
vllm/v1/worker/block_table.py
Outdated
self,
req_indices: np.ndarray,
positions: np.ndarray,
cp_kv_cache_interleave_size: int = 1,
nit: maybe no default val is better
Thanks for the review; we now pass this arg via the init, since it's a constant.
Overall looks good to me! Thanks for the contribution, left a few nits.
    return nums_dict, batch_ptr, token_chunk_offset_ptr


def get_dcp_local_seq_lens(
nit: can you simplify this to only compute for the current dcp_rank, and pass in the dcp rank?
Same reason as this comment: it might be more flexible to return the full seq_len split result and let each dcp_rank select its own part as needed. Since the size of seq_len is only max_num_reqs, multiplying it by dcp_size adds little computation/storage overhead, which we think should be acceptable.
Signed-off-by: zhangsicheng5 <zhangsicheng5@huawei.com>
Does vllm/vllm/v1/attention/backends/mla/common.py, lines 976 to 1042 in e20eba7, …
Great catch @FENP; please address: #26696 (comment)
Sorry for the late reply. We have fixed this bug and retested the chunked prefill feature. We counted the number of tokens of the gsm8k dataset (which contains over 700 prompts with more than 100 tokens) and used … Test Results:
interleave_size = 1
interleave_size = 8
CC List:
…leave size > 1 Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
I've completed the adaptation for the GQA flash_attn backend and conducted several precision tests. Although the precision tests don't show a significant impact, I find this implementation approach rather unusual and suspect there might still be underlying issues. I'd appreciate it if everyone could review the code (5d7184a), especially regarding this particular variable. More details: in the original DCP implementation, I believe this value represents the length of the KV cache on each dcp rank. However, when the total length of the KV cache is divisible by the DCP world size, it actually becomes one greater than the actual length on each rank. I attempted to correct this discrepancy, but it resulted in completely incorrect prediction outcomes. Therefore, I retained a similar implementation, which seems somewhat abnormal. Test results:
interleave size = 8, max_num_batched_tokens = 100
interleave size = 8, max_num_batched_tokens = 99
Apologies for the delay! (FYI, I'm on vacation, so responses will be delayed.) I agree this is very weird and should be figured out before landing. Why not use …
@LucasWilkinson yes, I have thoroughly examined the subsequent usage of …
seqused_k is the whole kv length (context len + query len). We need to fix the usage in _forward_with_dcp. edit: nvm, I noticed that dcp attention is implemented by splitting the context and the current chunk to avoid needing custom mask handling. Discussed separately: the context len == 0 case needs extra care.
Purpose
1. cp_kv_cache_interleave_size support
In the dcp scenario, kv_cache is split across dcp ranks. The current implementation (#23734) splits kv_cache in a token-level interleaved style: token_idx i is stored on the GPU whose dcp_rank == i % dcp_world_size.
To ease pd-disaggregation support, we add the cp_kv_cache_interleave_size argument to control the interleave size of the kv_cache split: store interleave_size tokens on dcp rank i, then store the next interleave_size tokens on dcp rank i+1. The default value of cp_kv_cache_interleave_size is 1, which matches the original token-level interleaved implementation. By setting cp_kv_cache_interleave_size to block_size, we can split kv_cache in a block-level interleaved style, which makes it easy to support pd disaggregation with dcp > 1: D nodes only need to pull the corresponding kv_cache blocks, without needing to rearrange tokens within blocks.
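As a quick illustration of the mapping described above (the helper name below is hypothetical; with interleave_size = 1 it reduces to the original i % dcp_world_size scheme):

```python
# Which dcp rank stores token i, for a given interleave size.
def rank_of_token(token_idx: int, dcp_world_size: int, interleave_size: int) -> int:
    return (token_idx // interleave_size) % dcp_world_size

print([rank_of_token(i, 2, 1) for i in range(8)])  # [0, 1, 0, 1, 0, 1, 0, 1]
print([rank_of_token(i, 2, 4) for i in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```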
Only dcp with cp_kv_cache_interleave_size is supported for now, but the (p)cp case is also considered and should be easy to extend in the future.
2. Move dcp_local_seq_lens computation to utils
Move the dcp_local_seq_lens computation into a utility and pass the result via metadata, so other attention backends can reuse it.
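For reference, a minimal standalone sketch of how the per-rank lengths can be derived under this interleaved split; the actual get_dcp_local_seq_lens in vllm/v1/attention/backends/utils.py may differ in signature and details:

```python
import numpy as np

def local_seq_lens_sketch(seq_lens: np.ndarray, dcp_world_size: int,
                          interleave_size: int) -> np.ndarray:
    """Return shape (num_reqs, dcp_world_size): kv tokens held by each rank."""
    seq_lens = seq_lens[:, None]                     # (num_reqs, 1)
    ranks = np.arange(dcp_world_size)[None, :]       # (1, dcp_world_size)
    group = interleave_size * dcp_world_size
    full_groups = seq_lens // group                  # complete interleave rounds
    rem = seq_lens - full_groups * group             # leftover tokens
    # Leftover tokens fill ranks in order, interleave_size tokens at a time.
    extra = np.clip(rem - ranks * interleave_size, 0, interleave_size)
    return full_groups * interleave_size + extra

print(local_seq_lens_sketch(np.array([10]), dcp_world_size=2, interleave_size=4))
# [[6 4]] -> rank 0 holds tokens 0-3 and 8-9, rank 1 holds tokens 4-7
```

Each rank can then slice its own column (e.g. [:, dcp_rank]), while the full matrix remains available for cases like the update_lse mask mentioned in the review.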
Test Plan
Model: DeepSeek-V2-Lite-Chat
Dataset: gsm8k
Test Result
tp2 dcp2, original code
tp2 dcp2, interleave_size = 1
tp2 dcp2, interleave_size = 64
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.