[TPU] kv cache update kernel doesn't need to be padded slices to multiple of num_slices_per_block #22394
Conversation
…ces_per_block Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
I remember that before the fix, it output something like … Do you know where the output is from?
def _kv_cache_update_kernel(
    # Prefetch
    slices_ref,  # [3, padded_num_slices], list of (kv_cache_start,
If we don't pad num_slices, is slices_ref.shape[1] sufficient, so that you don't need num_slices_ref?
We still need to pad to avoid recompilation, but we don't need to pad to a multiple of num_slices_per_block.
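To illustrate the distinction, a minimal sketch (hypothetical helper names, not vLLM code): padding the slice count up to one fixed bound keeps the kernel's traced input shapes constant across steps, which is what avoids XLA recompilation; the further rounding up to a multiple of num_slices_per_block is the step this PR removes.

```python
def pad_for_stable_shape(num_slices: int, fixed_bound: int) -> int:
    # Padding every batch to the same fixed bound keeps the kernel's
    # input shapes constant across steps, so XLA does not recompile.
    assert num_slices <= fixed_bound
    return fixed_bound


def pad_to_block_multiple(num_slices: int, num_slices_per_block: int) -> int:
    # The extra rounding this PR drops: the kernel can handle a final
    # partial block itself, so this step is unnecessary.
    return -(-num_slices // num_slices_per_block) * num_slices_per_block
```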
-def _get_padded_num_kv_cache_update_slices(
-        num_tokens: int, max_num_reqs: int, page_size: int,
-        num_slices_per_kv_cache_update_block: int) -> int:
+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
nit: remove "padded" from the function name, variable names, and comments since we don't need to pad the num_slices anymore?
As I replied in the previous comment, we still need to pad it.
-        num_tokens: int, max_num_reqs: int, page_size: int,
-        num_slices_per_kv_cache_update_block: int) -> int:
+def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
+                                           page_size: int) -> int:
nit: one more ask, could you add #19928 (comment) as a comment here?
Sure, it's added.
Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
@vanbasten23 it's from the XLA execution. Changing the logic in the kernel can prevent such out-of-index errors during execution.
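A pure-Python sketch of why the kernel must bound its loop by the real slice count rather than the padded shape (illustrative code, not the actual Pallas kernel):

```python
def apply_kv_slices(slices, num_slices, new_kv, kv_cache):
    # slices is padded out to a fixed length for shape stability; only
    # the first num_slices entries are real. The padded tail may hold
    # garbage offsets, so iterating past num_slices is what can produce
    # an out-of-index error during XLA execution.
    for dst, src, length in slices[:num_slices]:
        kv_cache[dst:dst + length] = new_kv[src:src + length]
```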
…iple of num_slices_per_block (vllm-project#22394) Signed-off-by: Chengji Yao <chengjiyao@gmail.com> Co-authored-by: Chengji Yao <chengjiyao@gmail.com> Signed-off-by: Paul Pak <paulpak58@gmail.com>
…iple of num_slices_per_block (vllm-project#22394) Signed-off-by: Chengji Yao <chengjiyao@gmail.com> Co-authored-by: Chengji Yao <chengjiyao@gmail.com> Signed-off-by: Diego-Castan <diego.castan@ibm.com>
…iple of num_slices_per_block (vllm-project#22394) Signed-off-by: Chengji Yao <chengjiyao@gmail.com> Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
…iple of num_slices_per_block (vllm-project#22394) Signed-off-by: Chengji Yao <chengjiyao@gmail.com> Co-authored-by: Chengji Yao <chengjiyao@gmail.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.

Purpose
kv cache update kernel doesn't need to be padded slices to multiple of num_slices_per_block
Test Plan
Test Result
passed
(Optional) Documentation Update