
Conversation

@tdoublep
Member

@tdoublep tdoublep commented May 14, 2025

We have observed a pretty severe (~40%) performance regression in the triton_unified_attention kernel when moving from Triton 3.2 to Triton 3.3.

After a lot of investigation, I was able to trace it to this commit to Triton, which reworked the way Triton determines which kernel arguments are constant. Before that change, Triton was (correctly) detecting that stride_k_cache_3 and stride_v_cache_3 are constant and leveraging this to generate a faster kernel. For whatever reason, the new logic doesn't do this, so we need to explicitly tell the compiler to treat these strides as constant.

This minor change recovers the performance.
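
For illustration, here is a minimal standalone sketch (a made-up kernel, not the actual vLLM one) of the distinction: an argument annotated tl.constexpr is a compile-time constant that Triton specializes the kernel on, while a tl.int64 argument remains a runtime value. When the constant stride happens to be 1, the compiler can prove that accesses are contiguous and emit wider, vectorized loads and stores:

import torch
import triton
import triton.language as tl

@triton.jit
def scaled_copy_kernel(
    src_ptr,
    dst_ptr,
    n,
    stride_src: tl.int64,      # runtime value: one compiled kernel for any stride
    stride_dst: tl.constexpr,  # compile-time constant: kernel specialized per value
    BLOCK: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # With stride_dst == 1 the compiler can prove the store is contiguous
    # and vectorize it; with a runtime stride it must stay conservative.
    x = tl.load(src_ptr + offs * stride_src, mask=mask)
    tl.store(dst_ptr + offs * stride_dst, x, mask=mask)

src = torch.arange(1024, device="cuda", dtype=torch.float32)
dst = torch.empty_like(src)
scaled_copy_kernel[(4,)](src, dst, 1024, src.stride(0), dst.stride(0), BLOCK=256)

The trade-off is that every distinct value of a tl.constexpr argument triggers a recompilation, which is why only the innermost strides (almost always 1) are marked constant here.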

There is one other small change: the TritonBackend no longer works on main because of this assert. Since the TritonBackend re-uses FlashAttentionMetadata and FlashAttentionMetadataBuilder directly (this is intentional, to reduce duplicated code), this assert doesn't really make sense unless TritonAttentionImpl inherits from FlashAttentionImpl. Happy to consider other ways to solve that, but I would like to avoid creating an entirely new TritonAttentionMetadata etc. that is just a duplicate.

cc @SageMoore @bringlein

Main branch:

++++++++++++++++++ Repetition 0 ++++++++++++++++++
============ Serving Benchmark Result ============
Successful requests:                     985       
Benchmark duration (s):                  26.18     
Total input tokens:                      209782    
Total generated tokens:                  125289    
Request throughput (req/s):              37.62     
Output token throughput (tok/s):         4785.69   
Total Token throughput (tok/s):          12798.78  
---------------Time to First Token----------------
Mean TTFT (ms):                          4889.40   
Median TTFT (ms):                        4844.86   
P99 TTFT (ms):                           8946.27   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          110.97    
Median TPOT (ms):                        53.41     
P99 TPOT (ms):                           390.86    
---------------Inter-token Latency----------------
Mean ITL (ms):                           46.36     
Median ITL (ms):                         28.78     
P99 ITL (ms):                            324.39    
==================================================

++++++++++++++++++ Repetition 1 ++++++++++++++++++
============ Serving Benchmark Result ============
Successful requests:                     985       
Benchmark duration (s):                  24.18     
Total input tokens:                      209782    
Total generated tokens:                  125289    
Request throughput (req/s):              40.74     
Output token throughput (tok/s):         5181.56   
Total Token throughput (tok/s):          13857.49  
---------------Time to First Token----------------
Mean TTFT (ms):                          4455.20   
Median TTFT (ms):                        4315.81   
P99 TTFT (ms):                           8447.30   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          103.56    
Median TPOT (ms):                        49.74     
P99 TPOT (ms):                           312.85    
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.30     
Median ITL (ms):                         26.91     
P99 ITL (ms):                            322.73    
==================================================

++++++++++++++++++ Repetition 2 ++++++++++++++++++
============ Serving Benchmark Result ============
Successful requests:                     985       
Benchmark duration (s):                  24.19     
Total input tokens:                      209782    
Total generated tokens:                  125305    
Request throughput (req/s):              40.73     
Output token throughput (tok/s):         5181.00   
Total Token throughput (tok/s):          13854.88  
---------------Time to First Token----------------
Mean TTFT (ms):                          4453.85   
Median TTFT (ms):                        4283.36   
P99 TTFT (ms):                           8430.35   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          103.24    
Median TPOT (ms):                        49.79     
P99 TPOT (ms):                           311.08    
---------------Inter-token Latency----------------
Mean ITL (ms):                           43.17     
Median ITL (ms):                         26.91     
P99 ITL (ms):                            325.18    
==================================================

With the changes from this PR:

++++++++++++++++++ Repetition 0 ++++++++++++++++++
============ Serving Benchmark Result ============
Successful requests:                     985       
Benchmark duration (s):                  18.36     
Total input tokens:                      209782    
Total generated tokens:                  125289    
Request throughput (req/s):              53.64     
Output token throughput (tok/s):         6823.19   
Total Token throughput (tok/s):          18247.83  
---------------Time to First Token----------------
Mean TTFT (ms):                          3535.61   
Median TTFT (ms):                        3447.89   
P99 TTFT (ms):                           6488.40   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          78.28     
Median TPOT (ms):                        38.20     
P99 TPOT (ms):                           231.30    
---------------Inter-token Latency----------------
Mean ITL (ms):                           32.68     
Median ITL (ms):                         20.25     
P99 ITL (ms):                            235.62    
==================================================

++++++++++++++++++ Repetition 1 ++++++++++++++++++
============ Serving Benchmark Result ============
Successful requests:                     985       
Benchmark duration (s):                  17.13     
Total input tokens:                      209782    
Total generated tokens:                  125289    
Request throughput (req/s):              57.50     
Output token throughput (tok/s):         7313.20   
Total Token throughput (tok/s):          19558.30  
---------------Time to First Token----------------
Mean TTFT (ms):                          3358.98   
Median TTFT (ms):                        3277.33   
P99 TTFT (ms):                           6301.26   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          76.65     
Median TPOT (ms):                        36.04     
P99 TPOT (ms):                           230.60    
---------------Inter-token Latency----------------
Mean ITL (ms):                           30.80     
Median ITL (ms):                         18.56     
P99 ITL (ms):                            232.08    
==================================================


++++++++++++++++++ Repetition 2 ++++++++++++++++++
============ Serving Benchmark Result ============
Successful requests:                     985       
Benchmark duration (s):                  17.37     
Total input tokens:                      209782    
Total generated tokens:                  125289    
Request throughput (req/s):              56.72     
Output token throughput (tok/s):         7214.31   
Total Token throughput (tok/s):          19293.85  
---------------Time to First Token----------------
Mean TTFT (ms):                          3571.61   
Median TTFT (ms):                        3417.45   
P99 TTFT (ms):                           6460.17   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          76.82     
Median TPOT (ms):                        36.22     
P99 TPOT (ms):                           231.26    
---------------Inter-token Latency----------------
Mean ITL (ms):                           30.74     
Median ITL (ms):                         18.33     
P99 ITL (ms):                            234.79    
==================================================

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI will run, covering a small and essential subset of tests to quickly catch errors. You can run further CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label May 14, 2025
@tdoublep tdoublep changed the title Fix performance regression for unified attention [Kernel] [V1] Fix performance regression for triton unified attention May 14, 2025
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@tdoublep tdoublep force-pushed the tpa-fix-regression branch from 196faa8 to c90dc87 Compare May 14, 2025 18:11
@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) May 14, 2025 18:30
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label May 14, 2025
Comment on lines 56 to +63
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
     stride_k_cache_2: tl.int64,  # int
-    stride_k_cache_3: tl.int64,  # int
+    stride_k_cache_3: tl.constexpr,  # int
     stride_v_cache_0: tl.int64,  # int
     stride_v_cache_1: tl.int64,  # int
     stride_v_cache_2: tl.int64,  # int
-    stride_v_cache_3: tl.int64,  # int
+    stride_v_cache_3: tl.constexpr,  # int
Member

Why not make all of the strides constexpr to be safe?

Collaborator

I would think it's best not to, to avoid over-recompilation; but for the last stride it makes sense, since it is almost always 1 (and when it is 1 we want the compiler to optimize around that, i.e. use wider loads).

Member Author

Avoiding recompilation is one reason, but I also think that for really long sequences there is a risk that the strides can overflow unless they are explicitly marked as tl.int64. That can't happen for stride_k_cache_3 and stride_v_cache_3 though, so I think we are safe to do this.
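
As a rough illustration of the overflow concern (the shapes below are made up, not taken from this PR): the leading KV-cache strides scale with block_size * num_kv_heads * head_size, so offsets computed from them can exceed the int32 range for large caches, whereas the last stride is 1 and never can:

# Hypothetical paged KV-cache shape: [num_blocks, block_size, num_kv_heads, head_size]
num_blocks, block_size, num_kv_heads, head_size = 500_000, 16, 8, 128
stride_0 = block_size * num_kv_heads * head_size  # 16_384 elements per block
max_offset = (num_blocks - 1) * stride_0          # ~8.2e9 elements
print(max_offset > 2**31 - 1)                     # True: would overflow int32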

@mgoin
Member

mgoin commented May 14, 2025

Out of curiosity, are the benchmarks you reported CUDA or ROCm?

@LucasWilkinson
Collaborator

There is one other small change: the TritonBackend no longer works on main because of this assert. Since the TritonBackend re-uses FlashAttentionMetadata and FlashAttentionMetadataBuilder directly (this is intentional, to reduce duplicated code), this assert doesn't really make sense unless TritonAttentionImpl inherits from FlashAttentionImpl. Happy to consider other ways to solve that, but I would like to avoid creating an entirely new TritonAttentionMetadata etc. that is just a duplicate.

I discussed this with @bringlein: we should avoid subclassing FlashAttention. The flash-attention API is not as stable as people think and does change from time to time, and we should be able to update FlashAttention without needing to worry about a bunch of downstream attention backends (otherwise we've kinda defeated the point of having the attention backend abstraction). E.g. right now, by using FlashAttentionMetadataBuilder directly when building the Triton metadata, you are also calling the new FA ahead-of-time scheduler kernel only to ignore the results.

A temporary workaround could be to do:

class TritonAttentionMetadata(FlashAttentionMetadata):
    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table)
        self.aot_schedule = False

this should fix the assert without having to subclass from FlashAttentionImpl

We should work on a new PR to have a separate Triton metadata, and in the process move things like make_local_attention_virtual_batches into a common area (like utils.py or local_attn_common.py), and maybe also move _get_sliding_window_configs to a common area and abstract it so we can pass in the impl class.

Member

@tlrmchlsmth tlrmchlsmth left a comment

Nice find!



-class TritonAttentionImpl(AttentionImpl):
+class TritonAttentionImpl(FlashAttentionImpl):
Collaborator

I think we should do:

class TritonAttentionMetadata(FlashAttentionMetadata):
    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table)
        self.aot_schedule = False

instead of this, so we also avoid calls to the FA AOT scheduler, unless you see an issue with this approach

Collaborator

++ on this as a temp solution

Member Author

@tdoublep tdoublep May 14, 2025

Sure, trying this now.

I will implement a clean version of TritonAttentionMetadata in a follow-on PR; I think there might be some benefits to the kernel if we do this.

Member Author

done

Member

@LucasWilkinson nice catch. Is it good to go now?

Contributor

A similar workaround is done in PR #16606.

@LucasWilkinson LucasWilkinson disabled auto-merge May 14, 2025 18:48
tdoublep and others added 2 commits May 14, 2025 16:55
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
@tdoublep tdoublep force-pushed the tpa-fix-regression branch from d5054c1 to 64e7f13 Compare May 14, 2025 21:22
@tdoublep
Member Author

Out of curiosity, are the benchmarks you reported CUDA or ROCm?

@mgoin The benchmarks above are on H100 (CUDA); I believe @bringlein has also confirmed this behaviour on MI300x.

Collaborator

@LucasWilkinson LucasWilkinson left a comment

LGTM, thanks for updating!

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) May 15, 2025 13:15
@tdoublep
Member Author

V1 test failures look unrelated (something to do with Eagle)

@DarkLight1337 DarkLight1337 added this to the v0.9.0 milestone May 15, 2025
@vllm-bot vllm-bot merged commit 01c2233 into vllm-project:main May 15, 2025
59 of 61 checks passed
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
…vllm-project#18161)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
@ekagra-ranjan
Contributor

@tdoublep - Nice job!
Which model was this benchmarked on in the PR description? Does it mean all models which use Flash attention will see a 40% improvement due to this PR?
