
Conversation

@jvlunteren (Contributor) commented Jul 18, 2025

This PR introduces modifications and extensions to the Triton unified attention kernel (#16828 and #19152) that enable support for hybrid models such as Granite 4.0, which combine Mamba-2 and Transformer layers.

Key changes:

  • adopt a FlashInfer-style KV cache layout (num_blocks, 2, block_size, num_kv_heads, head_size)

  • implement a reorder_batch() function

  • support larger, non-power-of-two block sizes

In addition, prefill support was initially removed from the included split-KV attention kernel, dedicating it solely to decode operations, since prefill typically involves sufficient parallelism to utilize GPU compute resources efficiently without the split-KV optimization. That removal has since been reverted to simplify the review process, as it is orthogonal to the main focus of the PR; prefill support will be addressed in a future PR.

The FlashInfer-style KV cache layout and the reorder_batch() function have also since been removed, as they are no longer needed.

Performance

This PR extends the functionality of the Triton unified attention kernel to support hybrid models. We need to ensure that these changes do not degrade performance for conventional models. This was validated by running benchmark_serving.py using the meta-llama/Llama-3.1-8B-Instruct model on an NVIDIA H100 GPU.

Current Triton unified attention kernel:

$ python benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
============ Serving Benchmark Result ============
Successful requests:                     984
Benchmark duration (s):                  21.93
Total input tokens:                      210752
Total generated tokens:                  195369
Request throughput (req/s):              44.87
Output token throughput (tok/s):         8909.60
Total Token throughput (tok/s):          18520.72
---------------Time to First Token----------------
Mean TTFT (ms):                          3822.98
Median TTFT (ms):                        3783.49
P99 TTFT (ms):                           6847.26
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          86.86
Median TPOT (ms):                        49.16
P99 TPOT (ms):                           232.71
---------------Inter-token Latency----------------
Mean ITL (ms):                           39.06
Median ITL (ms):                         26.22
P99 ITL (ms):                            236.50
==================================================

Updated Triton unified attention kernel (this PR):

$ python benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name sharegpt \
    --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
============ Serving Benchmark Result ============
Successful requests:                     984
Benchmark duration (s):                  21.28
Total input tokens:                      211024
Total generated tokens:                  194422
Request throughput (req/s):              46.25
Output token throughput (tok/s):         9137.96
Total Token throughput (tok/s):          19056.22
---------------Time to First Token----------------
Mean TTFT (ms):                          3595.88
Median TTFT (ms):                        3526.50
P99 TTFT (ms):                           6669.08
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          85.63
Median TPOT (ms):                        48.81
P99 TPOT (ms):                           227.73
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.40
Median ITL (ms):                         25.91
P99 ITL (ms):                            233.44
==================================================

The above results confirm that performance remains stable, with a slight improvement observed.

Correctness

To verify correctness for conventional models, we compare lm_eval results between this PR and FlashAttention using the meta-llama/Llama-3.1-8B-Instruct model.

FlashAttention:

VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.792|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  |0.770|±  |0.0188|

Updated Triton unified attention kernel (this PR):

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.782|±  |0.0185|

The correctness for hybrid models is validated by comparing this PR with FlashInfer using the ibm-granite/granite-4.0-tiny-base-preview model.

FlashInfer:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER lm_eval --model vllm --model_args pretrained=ibm-granite/granite-4.0-tiny-base-preview,enable_prefix_caching=False,enforce_eager=True,tensor_parallel_size=2 --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.604|±  |0.0219|
|     |       |strict-match    |     5|exact_match|↑  |0.576|±  |0.0221|

Updated Triton unified attention kernel (this PR):

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=ibm-granite/granite-4.0-tiny-base-preview,enable_prefix_caching=False,enforce_eager=True,tensor_parallel_size=2 --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.610|±  |0.0218|
|     |       |strict-match    |     5|exact_match|↑  |0.592|±  |0.0220|

How were the changes realized?

Most of the changes targeted by this PR were implemented in a relatively straightforward manner. Support for larger, non-power-of-two block sizes was realized using a tiling strategy.

The current Triton unified kernel computes partial attention scores at the block level and merges them using the online-softmax approach. In this PR, the partial attention score calculation is instead performed at the tile level. These tiles have a power-of-two size that is independent of the block size and can be chosen to optimize GPU resource utilization.

For conventional models, the tile size is typically chosen to be equal to or larger than the block size. In contrast, for hybrid models with very large block sizes, a significantly smaller tile size is selected.
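As a rough illustration of the tile-level approach (not the actual Triton kernel), the following NumPy sketch processes the key/value sequence for a single query in power-of-two tiles and merges the partial results with the online-softmax recurrence; all names, shapes, and the tile size here are illustrative assumptions.

import numpy as np

def tiled_attention(q, k, v, tile_size=32):
    # Single-query attention computed tile by tile with online softmax.
    # q: (head_size,), k and v: (seq_len, head_size).
    seq_len, head_size = k.shape
    scale = 1.0 / np.sqrt(head_size)
    m = -np.inf                   # running maximum of the attention scores
    l = 0.0                       # running sum of exponentials
    acc = np.zeros(head_size)     # running weighted sum of values
    for start in range(0, seq_len, tile_size):
        end = min(start + tile_size, seq_len)
        s = (k[start:end] @ q) * scale     # partial scores for this tile
        m_new = max(m, s.max())
        alpha = np.exp(m - m_new)          # rescale previously accumulated results
        p = np.exp(s - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ v[start:end]
        m = m_new
    return acc / l

# Sanity check against a direct computation; the tile size need not divide the
# sequence length, mirroring how the tile size is independent of the block size.
rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k = rng.standard_normal((100, 64))
v = rng.standard_normal((100, 64))
scores = (k @ q) / np.sqrt(64.0)
weights = np.exp(scores - scores.max())
weights /= weights.sum()
assert np.allclose(tiled_attention(q, k, v, tile_size=32), weights @ v)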

@tdoublep @bringlein

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>

@github-actions (bot) commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Jul 18, 2025

@gemini-code-assist (bot) left a comment

Code Review

This PR introduces modifications to the Triton unified attention kernel to enable support for hybrid models. A critical issue was identified in vllm/v1/attention/backends/triton_attn.py regarding the KV cache layout that needs to be addressed before merging.

    key_cache, value_cache = PagedAttention.split_kv_cache(
        kv_cache, self.num_kv_heads, self.head_size)
else:
    key_cache, value_cache = kv_cache.unbind(0)

@gemini-code-assist (bot) commented (critical)

The line key_cache, value_cache = kv_cache.unbind(0) is incorrect given the new KV cache layout (num_blocks, 2, block_size, num_kv_heads, head_size). This will lead to runtime errors. It should be key_cache, value_cache = kv_cache.unbind(1) to correctly unbind the key and value caches.

key_cache, value_cache = kv_cache.unbind(1)

@jvlunteren (Contributor, Author) commented

key_cache, value_cache = kv_cache.unbind(1) is part of this PR; gemini-code-assist seems to have missed that.

@gemini-code-assist (bot) commented

You are absolutely correct, jvlunteren! My apologies for the oversight. The change to key_cache, value_cache = kv_cache.unbind(1) is indeed present in this pull request (as shown in DIFF 4 of vllm/v1/attention/backends/triton_attn.py).

This change is crucial for correctly handling the new KV cache layout (num_blocks, 2, block_size, num_kv_heads, head_size), and I'm glad to see it's already implemented. Thank you for pointing out my mistake!
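For reference, a tiny standalone PyTorch example of what unbind along dim 1 does for a cache with this layout; the shapes below are arbitrary and purely illustrative.

import torch

# KV cache laid out as (num_blocks, 2, block_size, num_kv_heads, head_size);
# unbinding along dim 1 yields separate key and value caches.
kv_cache = torch.zeros(4, 2, 1024, 8, 128)
key_cache, value_cache = kv_cache.unbind(1)
assert key_cache.shape == (4, 1024, 8, 128)
assert value_cache.shape == (4, 1024, 8, 128)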

@mergify bot commented Jul 21, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jvlunteren.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 21, 2025
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
@mergify mergify bot removed the needs-rebase label Jul 21, 2025
jvlunteren and others added 2 commits July 21, 2025 15:40
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>

@jvlunteren (Contributor, Author) commented

The removal of prefill support from the included split-KV attention kernel has been reverted to simplify the review process.

@mergify bot commented Aug 2, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @jvlunteren.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 2, 2025
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
@mergify mergify bot removed the needs-rebase label Aug 2, 2025

@jvlunteren (Contributor, Author) commented

The FlashInfer-style KV cache layout and the reorder_batch() function have been removed, as they are no longer needed.

# calculate the number of tiles that need to be processed to
# cover the longest sequence prefix (due to causal masking, tiles beyond
# this prefix can be skipped)
num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

Member commented

How come the 3D kernel didn't need to compute max_seq_prefix_len before this change? It looks like the 2D kernel did need to.

@jvlunteren (Contributor, Author) commented Sep 17, 2025

The value max_seq_prefix_len is the number of tokens preceding the last query token in a Q block (for prefill, a Q block can contain multiple query tokens), determined while also accounting for potential padding. Because the 3D kernel is only used for decodes, each Q block contains a single query token, so the number of tokens preceding that query token can be determined directly from the sequence length and max_seq_prefix_len does not need to be calculated.

When I added prefill support back to the 3D (split-KV) kernel to simplify the review process (c4f7d67), I also reintroduced the use of max_seq_prefix_len in the 3D kernel for the same purpose, making the 2D and 3D kernels very similar.

To synchronize the fork branch for this PR with the vLLM main branch, it was easier to temporarily remove my updates, synchronize with main, and then reapply them. Because the 3D kernel in vLLM main does not use max_seq_prefix_len, it looks as if this update only happened with the last change, but as noted above it was already included in commit c4f7d67. In a follow-up PR, I intend to simplify the 3D kernel by removing all functionality needed to support prefills, including the use of max_seq_prefix_len.
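As a small standalone illustration of the quoted snippet, assuming cdiv_fn is the usual ceiling-division helper (the sequence length below is made up):

def cdiv(a: int, b: int) -> int:
    # Ceiling division, assumed to match the kernel's cdiv_fn helper.
    return (a + b - 1) // b

TILE_SIZE = 32
max_seq_prefix_len = 1000          # illustrative: longest causal prefix any query attends to
num_tiles = cdiv(max_seq_prefix_len, TILE_SIZE)
print(num_tiles)                   # 32 tiles cover the prefix; later tiles are skipped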

Comment on lines 725 to 726
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 32

Member commented

Why do we set these to 32 by default? If I understand correctly, if we want to keep the default behaviour the same as on main we should set them to 16? I'm not saying we need to do that, just would be good to justify the change.

Member commented

I'm especially surprised that the decode would benefit from using a value here bigger than the block size

@jvlunteren (Contributor, Author) commented

The tile sizes for prefill and decode were set to 32 to avoid issues with tl.dot, which imposes certain restrictions on the shapes of the tensors being multiplied; these restrictions also depend on the data type (e.g., fp8).

I ran several experiments with a few models, and it appeared that slightly larger tile sizes worked better for prefill, while a tile size of 16 or 32 worked fine for decode. However, this may be different for other models.

Ideally, these tile sizes should be tuned based on GPU and model characteristics. I will update how the default values are assigned.

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>

@jvlunteren (Contributor, Author) commented Sep 17, 2025

A check that was previously based on the minimum block size required by the tl.dot operation for tensor multiplication in the attention kernel (also taking the data type into account) has been replaced with a check on the tile size, since the tile size now determines the shapes of the tensors involved in the computation. The corresponding test, test_triton_unified_attention.py, has also been updated.

Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>

@jvlunteren (Contributor, Author) commented

The default tile sizes for prefill and decode are now assigned to always satisfy the shape constraints imposed by tl.dot. The check has been removed.
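As a rough sketch of the kind of assignment described above; the helper name and the concrete 16/32 thresholds are illustrative assumptions, not the exact values or logic used in this PR.

def default_tile_sizes(kv_cache_dtype: str) -> tuple[int, int]:
    # tl.dot imposes minimum tensor dimensions; the minimum is assumed here to be
    # larger for fp8 KV caches than for 16-bit types (illustrative values only).
    min_dot_size = 32 if "fp8" in kv_cache_dtype else 16
    tile_size_prefill = max(32, min_dot_size)   # slightly larger tiles tended to help prefill
    tile_size_decode = max(16, min_dot_size)    # 16 or 32 worked fine for decode
    return tile_size_prefill, tile_size_decode

print(default_tile_sizes("auto"))      # (32, 16)
print(default_tile_sizes("fp8_e4m3"))  # (32, 32)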

@tdoublep (Member) left a comment

LGTM

This can enable large block size support for hybrid models, but also makes it significantly easier to tune the tile size in the future.

@tdoublep tdoublep enabled auto-merge (squash) September 18, 2025 07:48
@github-actions github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Sep 18, 2025
@tdoublep tdoublep merged commit 01a583f into vllm-project:main Sep 18, 2025
42 checks passed
