FlashAttention actually does not support attention mask #116
Comments
Hey @HJoonKwon! Damn, very good find, thank you! I guess this does matter in the compiled forward, where we are padding inputs to static dimensions. We'd need to run the benchmarks, but maybe avoiding the call to …
@Phil26AT Great! Thank you again for your great work. It inspired me a lot.
On the topic of FlashAttention: you link to FlashAttention rather than FlashAttention-2 here. FlashAttention: https://arxiv.org/abs/2205.14135
Thanks for your great work!
I'm just curious whether your code here actually uses flash attention when the mask is not `None`. My guess is that it falls back to memory-efficient attention instead, since the PyTorch flash attention kernel does not support an attention mask. In addition, if memory-efficient attention were used, the call to `half()` would not have been needed when the mask is not `None`. Thank you!
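For reference, a minimal sketch of what I mean (assuming PyTorch >= 2.0 on a CUDA device; the shapes are arbitrary): with an attention mask, `scaled_dot_product_attention` still accepts fp32 inputs, which already rules out the flash kernel, since flash requires fp16/bf16 and no arbitrary mask.

```python
# Sketch: SDPA with a mask still runs in fp32, so the flash kernel
# (fp16/bf16 only, no arbitrary mask) cannot be the one handling it.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 512, 64, device="cuda")  # fp32 on purpose
k = torch.randn(1, 8, 512, 64, device="cuda")
v = torch.randn(1, 8, 512, 64, device="cuda")
mask = torch.ones(1, 8, 512, 512, device="cuda", dtype=torch.bool)

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
print(out.dtype, out.shape)  # runs without any cast to half()
```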
++ I did some experiments. Even if `sdp_flash` is enabled, it is not executed when the mask is not `None`. If we force PyTorch to use the flash kernel only, it raises an error like the one below, while the memory-efficient kernel does not.