- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 10.9k
[Kernel] Optimize Prefill Attention in Unified Triton Attention Kernel #20308
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Kernel] Optimize Prefill Attention in Unified Triton Attention Kernel #20308
Conversation
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @jvlunteren, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on optimizing the prefill attention phase within the Triton Unified Attention Kernel. The core improvement involves intelligently skipping unnecessary computations for fully masked tiles, leading to enhanced efficiency, especially for processing long input sequences. This results in notable performance gains across various prompt lengths while maintaining model correctness.
Highlights
- Performance Optimization: Implemented an optimization to the Triton Unified Attention Kernel specifically targeting the prefill attention phase.
- Causal Mask Leveraging: The optimization works by leveraging the causal mask to identify and skip processing of fully masked tiles, thereby reducing redundant computations during prefill.
- Significant Speedup for Long Prompts: Achieves significant performance improvements for very long prompts, demonstrating up to a 1.75x speedup for 16,000-token inputs on NVIDIA H100 GPUs.
- General Throughput and Latency Gains: Even for typical serving workloads with shorter prompts, the PR shows measurable gains, including approximately 3% improvement in request throughput and over 5% reductions in Time to First Token (TTFT), Time per Output Token (TPOT), and Inter-token Latency (ITL).
- Correctness Verified: Correctness benchmarks using lm_evalconfirm that the optimization maintains or slightly improves model accuracy metrics, ensuring no regression in output quality.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description | 
|---|---|---|
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. | 
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. | 
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. | 
| Help | /gemini help | Displays a list of available commands. | 
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
- 
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩ 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This PR optimizes the unified triton attention kernel by reducing the number of tiles processed during the prefill phase, leading to significant performance improvements, especially for long prompts. The change is well-justified by the performance benchmarks. A suggestion has been made to improve the readability of the core calculation for better maintainability.
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
| 👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run  Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add  🚀 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good @jvlunteren. Thanks for the contribution!
| num_blocks = cdiv_fn( | ||
| tl.minimum( | ||
| context_len + q_block_local_idx * BLOCK_Q + | ||
| (BLOCK_M - 1) // num_queries_per_kv + 1, seq_len), BLOCK_SIZE) | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Worth adding a comment to explain the optimization?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice find
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
vllm-project#20308) Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
vllm-project#20308) Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
vllm-project#20308) Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com> Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
vllm-project#20308) Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
This PR introduces an optimization to the unified triton attention kernel (#16828 and #19152) that enhances prefill attention performance. The key improvement involves reducing the number of tiles processed during the prefill phase by leveraging the causal mask to skip unnecessary computations. This results in more efficient execution, particularly for long prompts.
Performance
The following results were obtained for
meta-llama/Llama-3.1-8B-Instructon an NVIDIA H100 GPU, by running$ VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 VLLM_USE_V1=1 python benchmarks/benchmark_latency.py \ --model meta-llama/Llama-3.1-8B-Instruct \ --input-len <input-length> --output-len 4 \ --batch-size <batch-size>for
Results for a batch size 1 are shown in the following graph. The input (prompt) length (in tokens) was varied in these experiments across the following values: 500, 1000, 1500, 2000, 4000, 8000, and 16000. The number of warmup iterations and measurement iterations were left at the default values of 10 and 30 respectively.
As illustrated in the graph above, this PR improves the performance of the Triton Unified Attention Kernel by approximately 1.75 times for a batch size of 1 and an input length of 16000 tokens.
Additional results were collected using
benchmark_serving.py, which only includes sequence lengths under 2000 tokens:Current triton unified attention kernel:
$ python benchmarks/benchmark_serving.py \ --model meta-llama/Llama-3.1-8B-Instruct \ --dataset-name sharegpt \ --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json ============ Serving Benchmark Result ============ Successful requests: 984 Benchmark duration (s): 22.18 Total input tokens: 210771 Total generated tokens: 195009 Request throughput (req/s): 44.37 Output token throughput (tok/s): 8793.44 Total Token throughput (tok/s): 18297.62 ---------------Time to First Token---------------- Mean TTFT (ms): 3874.12 Median TTFT (ms): 3715.54 P99 TTFT (ms): 7060.57 -----Time per Output Token (excl. 1st token)------ Mean TPOT (ms): 88.60 Median TPOT (ms): 51.50 P99 TPOT (ms): 233.82 ---------------Inter-token Latency---------------- Mean ITL (ms): 40.26 Median ITL (ms): 25.51 P99 ITL (ms): 239.07 ==================================================Updated triton unified attention kernel (this PR):
$ python benchmarks/benchmark_serving.py \ --model meta-llama/Llama-3.1-8B-Instruct \ --dataset-name sharegpt \ --dataset-path ShareGPT_V3_unfiltered_cleaned_split.json ============ Serving Benchmark Result ============ Successful requests: 984 Benchmark duration (s): 21.44 Total input tokens: 210460 Total generated tokens: 195875 Request throughput (req/s): 45.90 Output token throughput (tok/s): 9137.19 Total Token throughput (tok/s): 18954.74 ---------------Time to First Token---------------- Mean TTFT (ms): 3588.36 Median TTFT (ms): 3478.75 P99 TTFT (ms): 6540.15 -----Time per Output Token (excl. 1st token)------ Mean TPOT (ms): 83.17 Median TPOT (ms): 47.72 P99 TPOT (ms): 220.70 ---------------Inter-token Latency---------------- Mean ITL (ms): 38.12 Median ITL (ms): 25.28 P99 ITL (ms): 223.90Despite the relatively short prompt lengths used in this benchmark, the results still demonstrate a ~3% improvement in throughput, along with over 5% reductions in latency metrics (TTFT, TPOT, and ITL).
Correctness
V1
FlashAttention:Updated triton unified attention kernel (this PR):
How is this performance improvement achieved?
The triton unified attention kernel employs a loop that iteratively processes multiple tiles, computing attention locally for each tile and accumulating the results across tiles to form the final output. During prefill processing, a causal mask is applied to each tile to ensure that attention is computed only over past and current tokens. In the current implementation, up to half of the tiles may be fully masked out during processing, resulting in redundant computation and reduced efficiency. This PR addresses the issue by skipping such tiles, ensuring that only those containing unmasked tokens are processed.
cc @tdoublep