Skip to content

Conversation

@yaochengji
Copy link
Collaborator

@yaochengji yaochengji commented Aug 6, 2025

Essential Elements of an Effective PR Description Checklist

  • [ x] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • [ x] The test plan, such as providing test command.
  • [x ] The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

kv cache update kernel doesn't need to be padded slices to multiple of num_slices_per_block

Test Plan

pytest -s -v tests/v1/tpu/test_kv_cache_update_kernel.py

Test Result

passed

(Optional) Documentation Update

…ces_per_block

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
@gemini-code-assist
Copy link
Contributor

Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@github-actions
Copy link

github-actions bot commented Aug 6, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added v1 tpu Related to Google TPUs labels Aug 6, 2025
@yaochengji yaochengji requested review from mgoin and vanbasten23 August 7, 2025 00:43
@vanbasten23
Copy link
Collaborator

I remember before the fix, it outputs something like

+-----------+
| Anomalies |
+-----------+
---------
Anomalies
---------
+-----------+
| Anomalies |
+-----------+
---------
Anomalies
---------

Do you know where the output is from?


def _kv_cache_update_kernel(
# Prefetch
slices_ref, # [3, padded_num_slices], list of (kv_cache_start,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we don't pad the num_slices, is slices_ref.shape[1] sufficient so that you don't need num_slices_ref?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We still need pad to avoid recompilation, but don't need to pad to multiple of num_slices_per_block

def _get_padded_num_kv_cache_update_slices(
num_tokens: int, max_num_reqs: int, page_size: int,
num_slices_per_kv_cache_update_block: int) -> int:
def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove "padded" from the function name, variable names, and comments since we don't need to pad the num_slices anymore?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I replied in the previous comment, we still need to pad it.

num_tokens: int, max_num_reqs: int, page_size: int,
num_slices_per_kv_cache_update_block: int) -> int:
def _get_padded_num_kv_cache_update_slices(num_tokens: int, max_num_reqs: int,
page_size: int) -> int:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: one more ask, could you add #19928 (comment) as comment here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, it's added.

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
@yaochengji
Copy link
Collaborator Author

I remember before the fix, it outputs something like

+-----------+
| Anomalies |
+-----------+
---------
Anomalies
---------
+-----------+
| Anomalies |
+-----------+
---------
Anomalies
---------

Do you know where the output is from?

@vanbasten23 it's from the XLA execution. Changing the logic in the kernel can prevent such out-of-index error during execution.

@simon-mo simon-mo enabled auto-merge (squash) August 8, 2025 23:05
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 8, 2025
@vllm-bot vllm-bot merged commit 2a84fb4 into vllm-project:main Aug 10, 2025
45 of 51 checks passed
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…iple of num_slices_per_block (vllm-project#22394)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Paul Pak <paulpak58@gmail.com>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…iple of num_slices_per_block (vllm-project#22394)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Diego-Castan <diego.castan@ibm.com>
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
…iple of num_slices_per_block (vllm-project#22394)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
…iple of num_slices_per_block (vllm-project#22394)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
…iple of num_slices_per_block (vllm-project#22394)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
Signed-off-by: Xiao Yu <xiao.yu@amd.com>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
…iple of num_slices_per_block (vllm-project#22394)

Signed-off-by: Chengji Yao <chengjiyao@gmail.com>
Co-authored-by: Chengji Yao <chengjiyao@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ready ONLY add when PR is ready to merge/full CI is needed tpu Related to Google TPUs v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants