Skip to content

Conversation

@bringlein
Copy link
Contributor

@bringlein bringlein commented Jul 9, 2025

This PR introduces tune-able block sizes to the unified attention kernel that enhances prefill attention performance. For now, we use simple heuristics to determine the right block sizes, but we intend to tune them more for targeted platforms in the very near future.

Performance

benchmark_latency.py

On H100, with this PR:

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 python benchmarks/benchmark_latency.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --input-len 10000 --output-len 4 \
    --batch-size 1
...
Avg latency: 0.3723039971664548 seconds
10% percentile latency: 0.3671667315065861 seconds
25% percentile latency: 0.368670291849412 seconds
50% percentile latency: 0.37277518305927515 seconds
75% percentile latency: 0.3755023997509852 seconds
90% percentile latency: 0.3784640687983483 seconds
99% percentile latency: 0.3798159270081669 seconds

on H100, current upstream using the triton_attn backend:

Avg latency: 0.6647519522656997 seconds
10% percentile latency: 0.6590904780197888 seconds
25% percentile latency: 0.6594004178186879 seconds
50% percentile latency: 0.660066019045189 seconds
75% percentile latency: 0.6703241039067507 seconds
90% percentile latency: 0.6784501124173403 seconds
99% percentile latency: 0.6801793645089492 seconds

So, this PR decreases the latency of prefill by 78%.

benchmark_serving.py

On H100, with this PR:

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 vllm serve meta-llama/Llama-3.1-8B-Instruct --disable-log-requests --no-enable-prefix-caching

python benchmarks/benchmark_serving.py --model /models/hf/meta-llama/Llama-3.1-8B-Instruct/main/ --dataset-name sharegpt --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json
...
============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  20.32     
Total input tokens:                      215196    
Total generated tokens:                  197261    
Request throughput (req/s):              49.22     
Output token throughput (tok/s):         9708.59   
Total Token throughput (tok/s):          20299.89  
---------------Time to First Token----------------
Mean TTFT (ms):                          3460.93   
Median TTFT (ms):                        3277.94   
P99 TTFT (ms):                           6428.40   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          81.76     
Median TPOT (ms):                        45.50     
P99 TPOT (ms):                           214.88    
---------------Inter-token Latency----------------
Mean ITL (ms):                           36.51     
Median ITL (ms):                         25.51     
P99 ITL (ms):                            217.66    
==================================================

before, on H100 with current upstream using the triton_attn backend:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  21.20     
Total input tokens:                      215196    
Total generated tokens:                  198133    
Request throughput (req/s):              47.16     
Output token throughput (tok/s):         9344.13   
Total Token throughput (tok/s):          19492.97  
---------------Time to First Token----------------
Mean TTFT (ms):                          3613.11   
Median TTFT (ms):                        3433.86   
P99 TTFT (ms):                           6816.68   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          86.69     
Median TPOT (ms):                        48.43     
P99 TPOT (ms):                           228.71    
---------------Inter-token Latency----------------
Mean ITL (ms):                           38.26     
Median ITL (ms):                         25.77     
P99 ITL (ms):                            233.85    
==================================================

So, also here, this PR improves throughput, TTFT, and ITL about 4%.

Correctness

With this PR on an H100:

VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.798|±  |0.0180|
|     |       |strict-match    |     5|exact_match|↑  |0.780|±  |0.0185|

More Context

The optimization allows the unified attention kernel to adapt the distribution of the work across compute units / streaming multiprocessors depending on the length of the requests in a batch.
While, depending on the use-case, the performance increases are only modest (4%), this PR enables more fine-grained tuning of the unified attention kernel in the future. E.g. already with the vary basic heuristic, the performance increase is 78% for prefill-heavy latency use cases.

CC @jvlunteren @tdoublep @SageMoore

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @bringlein, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces dynamic, tune-able block sizes to the unified 2D Triton attention kernel, allowing it to adapt its computational strategy based on input sequence lengths. This enhancement leads to substantial performance improvements, particularly for prefill operations, and establishes a foundation for future fine-grained kernel optimizations.

Highlights

  • Tune-able Block Sizes: The unified 2D Triton attention kernel (kernel_unified_attention_2d) now supports tune-able BLOCK_M (query block size) and BLOCK_N (key/value block size) parameters, allowing for more flexible work distribution.
  • Dynamic Heuristics for Block Sizes: The unified_attention function now dynamically sets BLOCK_M based on max_seqlen_q and BLOCK_N based on max_seqlen_k using simple heuristics. This enables the kernel to adapt its block sizes to different sequence lengths for improved performance.
  • Performance Enhancements: This optimization significantly improves prefill attention performance, showing up to a 78% latency reduction on H100 for specific benchmarks. It also provides modest gains (around 4%) for serving throughput and time-to-first-token (TTFT).
  • Kernel Logic Refactoring: The internal logic of kernel_unified_attention_2d has been refactored to accommodate the new block sizes. This includes deriving BLOCK_Q internally, replacing the find_seq_idx helper with an inline binary search, and adjusting tensor dimensions and loop iterations for attention scores and accumulators to align with BLOCK_Q and BLOCK_N.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces tunable block sizes to the unified attention kernel, which is a great optimization that yields impressive performance improvements. The implementation is solid. My feedback focuses on enhancing code clarity and maintainability by simplifying some conditions and correcting comments. Overall, this is a valuable contribution.

bringlein and others added 2 commits July 9, 2025 11:09
Co-authored-by: Tom Parnell <tpa@zurich.ibm.com>
Co-authored-by: Jan Van Lunteren <jvl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@bringlein bringlein force-pushed the ngl_tuneable_2d_kernel_pr branch from 4923046 to e99a139 Compare July 9, 2025 15:11
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@bringlein bringlein force-pushed the ngl_tuneable_2d_kernel_pr branch from 70146a4 to 3556c2a Compare July 9, 2025 15:15
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@github-actions
Copy link

github-actions bot commented Jul 9, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@bringlein
Copy link
Contributor Author

While preparing the PR, I merged with an old version of the kernel (see above). I fixed it now, sorry for that. Neither performance nor correctness were affected:

Updated latency (including #18100)

Avg latency: 0.36923687839880587 seconds
10% percentile latency: 0.3676671689376235 seconds
25% percentile latency: 0.3679791371105239 seconds
50% percentile latency: 0.36859156284481287 seconds
75% percentile latency: 0.36941709427628666 seconds
90% percentile latency: 0.3715606901794672 seconds
99% percentile latency: 0.3773039229772985 seconds

Updated performance (including #18100)

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  20.13     
Total input tokens:                      215196    
Total generated tokens:                  197127    
Request throughput (req/s):              49.67     
Output token throughput (tok/s):         9791.90   
Total Token throughput (tok/s):          20481.34  
---------------Time to First Token----------------
Mean TTFT (ms):                          3524.92   
Median TTFT (ms):                        3430.91   
P99 TTFT (ms):                           6502.68   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          80.98     
Median TPOT (ms):                        45.43     
P99 TPOT (ms):                           213.69    
---------------Inter-token Latency----------------
Mean ITL (ms):                           36.11     
Median ITL (ms):                         24.31     
P99 ITL (ms):                            216.08    
==================================================

Updated correctness (including #18100)

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.794|±  |0.0181|
|     |       |strict-match    |     5|exact_match|↑  |0.776|±  |0.0187|
pytest tests/kernels/attention/test_triton_unified_attention.py: 864 passed, 288 skipped in 364.42s

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@bringlein
Copy link
Contributor Author

After feedback from @SageMoore, I tried to balance the preliminary heuristics used for BLOCK_N/BLOCK_M to improve the performance for short-prefills on MI300. As stated above, this is only the first step to enable more platform-specific tuning in the near future. And despite the new heuristics, all performance and correctness numbers reported above are unchanged.

@SageMoore
Copy link
Contributor

Hi, @bringlein

We discussed this a bit offline but it looks like theres a slight regression in performance at higher qps rates. Here's some results that were collected on main and your PR with the following script.

MODEL=meta-llama/Llama-3.1-8B-Instruct
REQUEST_RATES=(20)
TOTAL_SECONDS=120
for REQUEST_RATE in "${REQUEST_RATES[@]}";
do
    NUM_PROMPTS=$(($TOTAL_SECONDS * $REQUEST_RATE))
    echo ""
    echo "===== RUNNING $MODEL FOR $NUM_PROMPTS PROMPTS WITH $REQUEST_RATE QPS ====="
    echo ""
    python3 benchmarks/benchmark_serving.py \
        --model $MODEL \
        --dataset-name random \
        --ignore-eos \
        --num-prompts $NUM_PROMPTS \
        --request-rate $REQUEST_RATE \
done

Results from main

============ Serving Benchmark Result ============ 
Successful requests: 2400 
Benchmark duration (s): 130.36 
Total input tokens: 2452447 
Total generated tokens: 307200 
Request throughput (req/s): 18.41 
Output token throughput (tok/s): 2356.54 
Total Token throughput (tok/s): 21169.32 
---------------Time to First Token---------------- 
Mean TTFT (ms): 298.33 
Median TTFT (ms): 276.30 
P99 TTFT (ms): 726.43 
-----Time per Output Token (excl. 1st token)------ 
Mean TPOT (ms): 121.86 
Median TPOT (ms): 131.66 
P99 TPOT (ms): 161.00 
---------------Inter-token Latency---------------- 
Mean ITL (ms): 121.86 
Median ITL (ms): 112.99 
P99 ITL (ms): 317.65 
==================================================

Results from this PR

============ Serving Benchmark Result ============ 
Successful requests: 2400 
Benchmark duration (s): 130.24 
Total input tokens: 2452447 
Total generated tokens: 307200 
Request throughput (req/s): 18.43 
Output token throughput (tok/s): 2358.79 
Total Token throughput (tok/s): 21189.53
---------------Time to First Token---------------- 
Mean TTFT (ms): 346.05 
Median TTFT (ms): 327.21 
P99 TTFT (ms): 801.77 
-----Time per Output Token (excl. 1st token)------ 
Mean TPOT (ms): 149.42 
Median TPOT (ms): 176.69 
P99 TPOT (ms): 199.16
---------------Inter-token Latency----------------
Mean ITL (ms): 149.42 
Median ITL (ms): 138.09 
P99 ITL (ms): 329.11 
==================================================

It's not clear to me if the performance tradeoffs are worth it in this case. Obviously the long prefill improvements are great so this may just be a tradeoff that we want to make CC: @gshtras @robertgshaw2-redhat

bringlein added a commit to foundation-model-stack/vllm-triton-backend that referenced this pull request Jul 16, 2025
This PR adds the tunable 2d attention kernel (similar to vllm-project/vllm#20690) with tuning using the micro-benchmarks already carried out for H100 and MI300.

---------

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
@mergify
Copy link

mergify bot commented Jul 17, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @bringlein.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jul 17, 2025
@github-actions
Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Oct 16, 2025
@github-actions
Copy link

This pull request has been automatically closed due to inactivity. Please feel free to reopen if you intend to continue working on it. Thank you!

@github-actions github-actions bot closed this Nov 15, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

needs-rebase stale Over 90 days of inactivity v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants