-
-
Couldn't load subscription status.
- Fork 10.9k
[Perf] Optimize reshape_and_cache CUDA Kernel #26021
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run You ask your reviewers to trigger select CI tests on top of Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request optimizes the reshape_and_cache CUDA kernel by altering the parallelization strategy, which results in significant performance gains for large token counts as demonstrated by the provided benchmarks. The new approach vectorizes key_cache updates. My review identified a critical issue where the complex index calculations for key_cache and value_cache could lead to integer overflow with large tensor dimensions, potentially causing memory corruption. I have provided a code suggestion to resolve this by pre-calculating strides using int64_t, which also enhances code readability.
csrc/cache_kernels.cu
Outdated
| cache_t* __restrict__ key_dst = key_cache + block_idx * num_heads * h_block_count * block_size * x | ||
| + head_idx * h_block_count * block_size * x | ||
| + h_block * block_size * x | ||
| + block_offset * x; | ||
| const int64_t tgt_value_start = block_idx * num_heads * h_block_count * x * block_size | ||
| + head_idx * h_block_count * x * block_size | ||
| + h_block * x * block_size | ||
| + block_offset; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The index calculations for key_dst and tgt_value_start are susceptible to integer overflow. The intermediate products like num_heads * h_block_count * block_size * x are computed using int types. With large tensor dimensions, this could overflow, leading to incorrect memory addressing and potential data corruption.
To prevent this and improve readability, I suggest pre-calculating the strides for each dimension using int64_t before computing the final offsets.
const int64_t key_h_block_stride = (int64_t)block_size * x;
const int64_t key_head_stride = (int64_t)h_block_count * key_h_block_stride;
const int64_t key_block_stride = (int64_t)num_heads * key_head_stride;
cache_t* __restrict__ key_dst = key_cache + block_idx * key_block_stride
+ head_idx * key_head_stride
+ h_block * key_h_block_stride
+ block_offset * x;
const int64_t val_h_block_stride = (int64_t)x * block_size;
const int64_t val_head_stride = (int64_t)h_block_count * val_h_block_stride;
const int64_t val_block_stride = (int64_t)num_heads * val_head_stride;
const int64_t tgt_value_start = block_idx * val_block_stride
+ head_idx * val_head_stride
+ h_block * val_h_block_stride
+ block_offset;
Signed-off-by: Liu-congo <1502632128@qq.com>
Purpose
FIX #25705
Method
then i viewed them in shape as:
that is to say that if we start the kernel in config as grid(num_tokens), block(num_heads * head_size // x)
we process each token's kv_cache in a block
the block size is num_heads * head_size // x
so by the threadIdx we could read a contiguous x elements in key/value
then apply vectorize op in key_cache and use naive loop method for value_cache
that's how it worked
I doubt whether there are efficient scatter kernel in cuda that can speed up the value_cache's op, cause value_cache need's to
fill in the data in a stride of block_size for every element, maybe there still potential for speeding up?
Test Plan
for time cost test, i add a benchmark_reshape_and_cache.py into benchmark/kernels, which almost copy the benchmark/kernels/benchmark_reshape_and_cache_flash.py
for accuary test, it has been implemented in tests/kernels/attention/test_cache.py::test_reshape_and_cache
the kernel implemented can passed it fully
Test Result
python benchmark/kernels/benckmark_reshape_and_cache.py --num-heads 64 --head-size 64
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.