-
Notifications
You must be signed in to change notification settings - Fork 3.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Attention Runtime Error for CLIP model #17729
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4 tasks
wangyems
previously approved these changes
Sep 28, 2023
wangyems
approved these changes
Sep 28, 2023
aciddelgado
approved these changes
Sep 28, 2023
snnn
pushed a commit
that referenced
this pull request
Sep 29, 2023
### Description The condition check is not correct ``` if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT } else { // BERT } ``` Change it to ``` if (is_unidirectional_) { // GPT } else { // BERT } ``` Another walkaround is to enable fused causal attention by adding an environment variable `ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1` before running stable diffusion. ### Motivation and Context Without the fix, optimized CLIP model of stable diffusion will encounter error in running Attention node: 2023-09-24 16:15:31.206037898 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Attention node. Name:'Attention_0' Status Message: /onnxruntime_src/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu:207 bool onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::mhaImpl::is_flash_attention(int) const interface->mHasCausalMask == false was false. Note that the bug has been there for a long time. It is just surfaced since we recently added a fusion for CLIP, which will trigger the error. We will add a comprehensive test for causal attention later to avoid such corner cases.
Cherry-picked to the 1.16.1 patch release |
kleiti
pushed a commit
to kleiti/onnxruntime
that referenced
this pull request
Mar 22, 2024
### Description The condition check is not correct ``` if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT } else { // BERT } ``` Change it to ``` if (is_unidirectional_) { // GPT } else { // BERT } ``` Another walkaround is to enable fused causal attention by adding an environment variable `ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1` before running stable diffusion. ### Motivation and Context Without the fix, optimized CLIP model of stable diffusion will encounter error in running Attention node: 2023-09-24 16:15:31.206037898 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Attention node. Name:'Attention_0' Status Message: /onnxruntime_src/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu:207 bool onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::mhaImpl::is_flash_attention(int) const interface->mHasCausalMask == false was false. Note that the bug has been there for a long time. It is just surfaced since we recently added a fusion for CLIP, which will trigger the error. We will add a comprehensive test for causal attention later to avoid such corner cases.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
The condition check is not correct
Change it to
There are two walkarounds using older ORT binary <= 1.16:
(1) Enable fused causal attention by adding an environment variable
ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1
before running stable diffusion. But the fused causal attention uses fp16 in accumulation so such walkaround might bring precision loss.(2) Disable attention fusion in CLIP. However, it will bring performance loss.
Motivation and Context
Without the fix, optimized CLIP model of stable diffusion will encounter error in running Attention node:
2023-09-24 16:15:31.206037898 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Attention node. Name:'Attention_0' Status Message: /onnxruntime_src/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu:207 bool onnxruntime::contrib::cuda::FusedMHARunnerFP16v2::mhaImpl::is_flash_attention(int) const interface->mHasCausalMask == false was false.
Note that the bug has been there for a long time. It is just surfaced since we recently added a fusion for CLIP, which will trigger the error.
We will add a comprehensive test for causal attention later to avoid such corner cases.