[Kernel] Enable Hybrid Model Support in Triton Unified Attention Kernel #21197
Conversation
Code Review
This PR introduces modifications to the Triton unified attention kernel to enable support for hybrid models. A critical issue was identified in vllm/v1/attention/backends/triton_attn.py regarding the KV cache layout that needs to be addressed before merging.
    key_cache, value_cache = PagedAttention.split_kv_cache(
        kv_cache, self.num_kv_heads, self.head_size)
else:
    key_cache, value_cache = kv_cache.unbind(0)
The line key_cache, value_cache = kv_cache.unbind(0) is incorrect given the new KV cache layout (num_blocks, 2, block_size, num_kv_heads, head_size). This will lead to runtime errors. It should be key_cache, value_cache = kv_cache.unbind(1) to correctly unbind the key and value caches.
key_cache, value_cache = kv_cache.unbind(1)
key_cache, value_cache = kv_cache.unbind(1) is part of this PR. The gemini-code-assist bot seems to have missed that.
You are absolutely correct, jvlunteren! My apologies for the oversight. The change to key_cache, value_cache = kv_cache.unbind(1) is indeed present in this pull request (as shown in DIFF 4 of vllm/v1/attention/backends/triton_attn.py).
This change is crucial for correctly handling the new KV cache layout (num_blocks, 2, block_size, num_kv_heads, head_size), and I'm glad to see it's already implemented. Thank you for pointing out my mistake!
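For readers following along, here is a minimal, self-contained sketch (illustrative only; the tensor sizes are arbitrary, and this is not the kernel code itself) of why unbind(1) is the correct split for the (num_blocks, 2, block_size, num_kv_heads, head_size) layout:

import torch

# Hypothetical sizes, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 8, 16, 4, 64

# Combined KV cache in the layout discussed above:
# (num_blocks, 2, block_size, num_kv_heads, head_size)
kv_cache = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

# unbind(1) splits along the K/V dimension (size 2), yielding two views
# of shape (num_blocks, block_size, num_kv_heads, head_size).
key_cache, value_cache = kv_cache.unbind(1)
assert key_cache.shape == (num_blocks, block_size, num_kv_heads, head_size)

# unbind(0) would instead split along num_blocks, which does not match
# this layout and would produce num_blocks tensors of the wrong shape.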
This pull request has merge conflicts that must be resolved before it can be merged.
The removal of prefill support from the included split-KV attention kernel has been reverted to simplify the review process.
This pull request has merge conflicts that must be resolved before it can be merged.
The FlashInfer-style KV cache layout and the reorder_batch() function have been removed, as they are no longer needed.
# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)
How come the 3D kernel didn't need to compute max_seq_prefix_len before this change? It looks like the 2D kernel did need to.
The value max_seq_prefix_len relates to the number of tokens preceding the last query token in a Q block (for prefill, a Q block can contain multiple query tokens), and it is determined while also accounting for potential padding. Because the 3D kernel is only used for decodes, each Q block contains exactly one query token, so the number of tokens preceding that query token can be derived directly from the sequence length and max_seq_prefix_len does not need to be calculated.
When I added prefill support back to the 3D (split-KV) kernel to simplify the review process (c4f7d67), I also reintroduced the use of max_seq_prefix_len in the 3D kernel for the same purpose, making the 2D and 3D kernels very similar.
In order to synchronize the branch of the fork involved in this PR with the vLLM main branch, it was easier to temporarily remove my updates, perform the synchronization, and then reapply the updates. Because the 3D kernel on the vLLM main branch does not use max_seq_prefix_len, it looks as if this update only arrived with the last change, but as indicated above, it was already included in commit c4f7d67. In a follow-up PR, I intend to simplify the 3D kernel by removing all functionality needed to support prefills, including the use of max_seq_prefix_len.
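As a rough illustration of the distinction above (plain Python with hypothetical values, not the kernel code itself): for decode the prefix to cover follows directly from the sequence length, while for prefill the tile count is derived from max_seq_prefix_len.

def cdiv(a: int, b: int) -> int:
    """Ceiling division, mirroring the kernel's cdiv_fn."""
    return (a + b - 1) // b

TILE_SIZE = 32  # power-of-two tile size, independent of the KV block size

# Decode: every Q block holds exactly one query token, so the prefix that
# must be covered follows directly from the sequence length.
seq_len = 1000                      # hypothetical sequence length
num_tiles_decode = cdiv(seq_len, TILE_SIZE)

# Prefill: a Q block may hold several query tokens plus padding, so the
# prefix to cover is the longest prefix preceding the last query token,
# which is what max_seq_prefix_len captures.
max_seq_prefix_len = 1000           # hypothetical value
num_tiles_prefill = cdiv(max_seq_prefix_len, TILE_SIZE)

# Tiles beyond these counts can be skipped thanks to causal masking.
print(num_tiles_decode, num_tiles_prefill)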
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 32
Why do we set these to 32 by default? If I understand correctly, if we want to keep the default behaviour the same as on main we should set them to 16? I'm not saying we need to do that, just would be good to justify the change.
I'm especially surprised that the decode would benefit from using a value here bigger than the block size
The tile sizes for prefill and decode were set to 32 to avoid issues with tl.dot, which imposes certain restrictions on the shapes of the tensors being multiplied; these restrictions also depend on the data type (e.g., fp8).
I ran several experiments with a few models, and slightly larger tile sizes appeared to work better for prefill, while a tile size of 16 or 32 worked fine for decode. However, this may differ for other models.
Ideally, these tile sizes should be tuned depending on GPU and model characteristics. I will update the assignment of the default values.
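To illustrate the kind of heuristic under discussion, here is a small sketch. The per-dtype thresholds below are assumptions for illustration only, not necessarily the values this PR ends up using.

import torch

def default_tile_size(kv_cache_dtype: torch.dtype) -> int:
    """Pick a default tile size that satisfies tl.dot's minimum-shape rules.

    The thresholds are assumptions: fp8 inputs are assumed to need a larger
    minimum dimension than fp16/bf16. The actual defaults should ideally be
    tuned per GPU and model.
    """
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return 32
    return 16

# Tile sizes stay small powers of two even when hybrid models use very
# large (possibly non-power-of-two) block sizes.
assert default_tile_size(torch.bfloat16) == 16
assert default_tile_size(torch.float8_e4m3fn) == 32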
A check that was previously based on the minimum block size required by the tl.dot operation for the tensor multiplications in the attention kernel (also taking the data type into account) has been replaced by a check on the tile size, as the latter now determines the shapes of the tensors involved in the computation. The corresponding test has been updated accordingly.
The default tile sizes for prefill and decode are now assigned so that they always satisfy the shape constraints imposed by tl.dot.
LGTM
This not only enables large block size support for hybrid models, but also makes it significantly easier to tune the tile size in the future.
This PR introduces modifications and extensions to the Triton unified attention kernel (#16828 and #19152) that enable support for hybrid models such as Granite 4.0, which combines Mamba-2 and Transformer layers.
Key changes:
- adopt FlashInfer-style KV cache layout (num_blocks, 2, block_size, num_kv_heads, head_size)
- implement reorder_batch() function
- support for larger, non-power-of-two block sizes
In addition, prefill support was initially removed from the included split-KV attention kernel, dedicating that kernel solely to decode operations, since prefill typically involves enough parallelism to efficiently utilize GPU compute resources without the split-KV optimization. This removal has since been reverted to simplify the review process, as the change is orthogonal to the main focus of the PR; the removal of prefill support will be addressed in a future PR.
The FlashInfer-style KV cache layout and the reorder_batch() function have also been removed, as they are no longer needed.

Performance
This PR extends the functionality of the Triton unified attention kernel to support hybrid models. We need to ensure that these changes do not degrade performance for conventional models. This was validated by running benchmark_serving.py using the meta-llama/Llama-3.1-8B-Instruct model on an NVIDIA H100 GPU.

Current Triton unified attention kernel:
$ python benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json

============ Serving Benchmark Result ============
Successful requests:                     984
Benchmark duration (s):                  21.93
Total input tokens:                      210752
Total generated tokens:                  195369
Request throughput (req/s):              44.87
Output token throughput (tok/s):         8909.60
Total Token throughput (tok/s):          18520.72
---------------Time to First Token----------------
Mean TTFT (ms):                          3822.98
Median TTFT (ms):                        3783.49
P99 TTFT (ms):                           6847.26
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          86.86
Median TPOT (ms):                        49.16
P99 TPOT (ms):                           232.71
---------------Inter-token Latency----------------
Mean ITL (ms):                           39.06
Median ITL (ms):                         26.22
P99 ITL (ms):                            236.50
==================================================

Updated Triton unified attention kernel (this PR):

$ python benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json

============ Serving Benchmark Result ============
Successful requests:                     984
Benchmark duration (s):                  21.28
Total input tokens:                      211024
Total generated tokens:                  194422
Request throughput (req/s):              46.25
Output token throughput (tok/s):         9137.96
Total Token throughput (tok/s):          19056.22
---------------Time to First Token----------------
Mean TTFT (ms):                          3595.88
Median TTFT (ms):                        3526.50
P99 TTFT (ms):                           6669.08
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          85.63
Median TPOT (ms):                        48.81
P99 TPOT (ms):                           227.73
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.40
Median ITL (ms):                         25.91
P99 ITL (ms):                            233.44
==================================================

The above results confirm that performance remains stable, with a slight improvement observed.
Correctness
To verify correctness for conventional models, we compare lm_eval results between this PR and FlashAttention using the meta-llama/Llama-3.1-8B-Instruct model.

FlashAttention:

Updated Triton unified attention kernel (this PR):
The correctness for hybrid models is validated by comparing this PR with FlashInfer using the ibm-granite/granite-4.0-tiny-base-preview model.

FlashInfer:

Updated Triton unified attention kernel (this PR):
How were the changes realized?
Most of the changes targeted by this PR were implemented in a relatively straightforward manner. Support for larger, non-power-of-two block sizes was realized using a tiling strategy.
The current Triton unified kernel computes partial attention scores at the block level and merges them using the online-softmax approach. In this PR, the partial attention score calculation is instead performed at the tile level. These tiles have a power-of-two size that is independent of the block size and can be chosen to optimize GPU resource utilization.
For conventional models, the tile size is typically chosen to be equal to or larger than the block size. In contrast, for hybrid models with very large block sizes, a significantly smaller tile size is selected.
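For readers unfamiliar with the technique, the following is a minimal PyTorch reference (not the Triton kernel itself, and for a single query and head only) of tile-level attention with online-softmax merging, illustrating why the tile size can be chosen independently of the block size:

import torch

def attention_online_softmax(q: torch.Tensor,      # (d,)
                             k: torch.Tensor,      # (seq_len, d)
                             v: torch.Tensor,      # (seq_len, d)
                             tile_size: int = 32) -> torch.Tensor:
    """Reference tile-level attention with online-softmax merging.

    Partial scores are computed per tile and merged via a running max,
    running denominator, and rescaled accumulator, so no full score row is
    ever materialized -- the same idea the kernel applies per TILE_SIZE chunk.
    """
    d = q.shape[0]
    scale = d ** -0.5
    m = torch.tensor(float("-inf"))      # running max of scores
    l = torch.tensor(0.0)                # running softmax denominator
    acc = torch.zeros(d)                 # running weighted sum of values

    for start in range(0, k.shape[0], tile_size):
        k_tile = k[start:start + tile_size]            # (t, d)
        v_tile = v[start:start + tile_size]            # (t, d)
        scores = (k_tile @ q) * scale                  # (t,)

        m_new = torch.maximum(m, scores.max())
        p = torch.exp(scores - m_new)                  # (t,)
        correction = torch.exp(m - m_new)              # rescale old partials

        l = l * correction + p.sum()
        acc = acc * correction + p @ v_tile
        m = m_new

    return acc / l

# Matches full-softmax attention up to floating-point error:
q, k, v = torch.randn(64), torch.randn(200, 64), torch.randn(200, 64)
ref = torch.softmax((k @ q) * 64 ** -0.5, dim=0) @ v
assert torch.allclose(attention_online_softmax(q, k, v), ref, atol=1e-5)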
@tdoublep @bringlein