[Kernel] [V1] Fix performance regression for triton unified attention #18161
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a reduced set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Force-pushed from 196faa8 to c90dc87
      stride_k_cache_0: tl.int64,  # int
      stride_k_cache_1: tl.int64,  # int
      stride_k_cache_2: tl.int64,  # int
-     stride_k_cache_3: tl.int64,  # int
+     stride_k_cache_3: tl.constexpr,  # int
      stride_v_cache_0: tl.int64,  # int
      stride_v_cache_1: tl.int64,  # int
      stride_v_cache_2: tl.int64,  # int
-     stride_v_cache_3: tl.int64,  # int
+     stride_v_cache_3: tl.constexpr,  # int
Why not make all of the strides constexpr to be safe?
I would think it's best not to, in order to avoid over-recompilation, but for the last stride it makes sense since it is almost always 1 (and when it is 1 we want the compiler to optimize around this, i.e. use wider loads).
Avoiding recompilation is one reason, but I also think that for really long sequences there is a risk the strides can overflow unless they are explicitly marked as tl.int64. That can't happen for stride_k_cache_3 and stride_v_cache_3 though, so I think we are safe to do this.
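For readers less familiar with Triton argument specialization, here is a minimal, self-contained sketch of the same pattern (a toy kernel for illustration only, not the vLLM kernel; all names are made up): the row stride stays a 64-bit runtime value so it cannot overflow for very long sequences, while the last stride is a tl.constexpr so that when it equals 1 the compiler can specialize the kernel and emit wider, contiguous loads.

import torch
import triton
import triton.language as tl

@triton.jit
def scale_rows_kernel(
    x_ptr,                      # *float32 input
    out_ptr,                    # *float32 output
    n_cols,                     # number of valid columns per row
    stride_row: tl.int64,       # runtime 64-bit stride; may be huge, must not overflow
    stride_col: tl.constexpr,   # compile-time constant; almost always 1
    BLOCK: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    ptrs = x_ptr + row * stride_row + cols * stride_col
    vals = tl.load(ptrs, mask=mask, other=0.0)
    # With stride_col baked in as 1 at compile time, the compiler can prove the
    # access is contiguous and vectorize the load/store.
    tl.store(out_ptr + row * stride_row + cols * stride_col, vals * 2.0, mask=mask)

# Usage (requires a CUDA/ROCm-capable GPU):
x = torch.randn(8, 128, device="cuda")
out = torch.empty_like(x)
scale_rows_kernel[(x.shape[0],)](x, out, x.shape[1], x.stride(0), x.stride(1), BLOCK=128)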
Out of curiosity, are the benchmarks you reported CUDA or ROCm?
I discussed this with @bringlein: we should avoid subclassing FlashAttention. The flash attention API is not as stable as people think and does change from time to time. We should be able to update FlashAttention without needing to worry about a bunch of downstream attention backends (otherwise we've kind of defeated the point of having the attention backend abstraction). A temporary workaround could be the suggestion in the review comment below; that should fix the assert without having to subclass from FlashAttentionImpl. We should then work on a new PR that gives the Triton backend its own metadata and, in the process, moves things out of the shared FlashAttention code where it makes sense.
Nice find!
- class TritonAttentionImpl(AttentionImpl):
+ class TritonAttentionImpl(FlashAttentionImpl):
I think we should do:

class TritonAttentionMetadata(FlashAttentionMetadata):

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                 block_table: BlockTable):
        super().__init__(runner, kv_cache_spec, block_table)
        self.aot_schedule = False

instead of this, so we also avoid calls to the FA AOT scheduler, unless you see an issue with this approach.
++ on this as a temp solution
Sure, trying this now.
I will implement a clean version of TritonAttentionMetadata in a follow-on PR; I think there might be some benefits to the kernel if we do this.
done
@LucasWilkinson nice catch. Is it good to go now?
A similar workaround is done in PR #16606.
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Force-pushed from d5054c1 to 64e7f13
@mgoin Benchmarks above are on H100; I believe @bringlein has also confirmed this behaviour on MI300x.
LGTM, thanks for updating!
V1 test failures look unrelated (something to do with Eagle).
…vllm-project#18161)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Yuqi Zhang <yuqizhang@google.com>
@tdoublep - Nice job!
We have observed a pretty severe (~40%) performance regression for the triton_unified_attention kernel when moving from Triton 3.2 to Triton 3.3. After a lot of investigation, I was able to figure out that it came from this commit to Triton. That change reworked the way Triton determines which kernel arguments are constant. It seems that before the change, Triton was (correctly) detecting that stride_k_cache_3 and stride_v_cache_3 are constant and leveraging this to obtain a faster kernel. For whatever reason, the new logic doesn't do this, and we need to explicitly tell the compiler to interpret these strides as constant. This minor change recovers the performance.
There is one other small change: the TritonBackend no longer works on main because of this assert. Since the TritonBackend re-uses FlashAttentionMetadata and FlashAttentionMetadataBuilder directly (this is intentional, to reduce duplicate code), the assert doesn't really make sense unless TritonAttentionImpl inherits from FlashAttentionImpl. Happy to consider other ways to solve that, but I would like to avoid creating an entirely new TritonAttentionMetadata etc. that is just a duplicate.

cc @SageMoore @bringlein
Main branch:
With the changes from this PR:
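The benchmark outputs referenced by the two headings above were attached to the PR and are not reproduced here. Purely as an illustration of how such a before/after comparison could be run, here is a rough sketch; the attention-backend name string and the model are assumptions on my part, not details taken from this PR.

import os
import time

# Force the V1 Triton unified attention backend. The exact value of this
# environment variable is an assumption and may differ between vLLM versions.
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams

# Any decoder-only model works for a relative comparison; this one is a placeholder.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")
params = SamplingParams(max_tokens=256, temperature=0.0)
prompts = ["Explain paged attention in one paragraph."] * 32

start = time.perf_counter()
llm.generate(prompts, params)
print(f"elapsed: {time.perf_counter() - start:.2f} s")

# Run once on main and once with this PR checked out, then compare the timings.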