Fix flash attention bugs with Mistral and Falcon #27625
Conversation
@@ -838,7 +838,7 @@ def forward(
     attention_mask is not None
     and hasattr(self.config, "_flash_attn_2_enabled")
     and self.config._flash_attn_2_enabled
-    and past_key_values is not None
+    and use_cache
In the first autoregressive pass, past_key_values is None, so the previous condition skipped the check entirely. In the following passes, we may have masks such as

[[1, 1, 0, 0, 1],
 [1, 1, 1, 1, 1]]

and is_padding_right is then wrongfully evaluated to False. Gating on use_cache instead runs the check on the first pass, while the padding is still visible at the end of the mask.
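For context, here is a minimal self-contained sketch of the padding-side guard this condition gates; the function name, the fa2_enabled flag, and the error message are illustrative, not the actual Mistral code:

```python
import torch

def check_padding_side(attention_mask: torch.Tensor, use_cache: bool, fa2_enabled: bool) -> None:
    """Illustrative version of the guard now gated on use_cache."""
    if attention_mask is not None and fa2_enabled and use_cache:
        batch_size = attention_mask.shape[0]
        # Right padding means the last position of at least one sequence is masked out.
        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
        if is_padding_right:
            raise ValueError(
                "Batched generation with padding_side='right' is not supported "
                "with Flash Attention 2; use left padding instead."
            )

# With the old gate (past_key_values is not None) the check never ran on the first
# pass; gated on use_cache it runs before generated tokens extend the mask.
```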
Clean! Thanks!
Thanks for fixing!
model.save_pretrained(tmpdirname)

dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
# NOTE: Mistral apparently does not support right padding + use_cache with FA2.
It works, but you'll get terrible results because the cache will cut the non-padded values first.
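To make the padding distinction concrete, here is a small sketch with hypothetical shapes, reusing the dummy tensor names from the test snippet above:

```python
import torch

# Hypothetical example: batch of 2, sequence length 5, first sequence padded by 2 tokens.
dummy_input = torch.randint(0, 100, (2, 5))
dummy_attention_mask = torch.ones_like(dummy_input)

# Left padding: the masked positions sit at the start, so a rolling cache evicts
# padding first and the FA2 guard passes.
dummy_attention_mask[0, :2] = 0

# Right padding would instead be `dummy_attention_mask[0, -2:] = 0`; with use_cache
# the FA2 path now raises for that combination.
```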
Thanks for adding!
I just have a small Q about the tests
Hi @amyeroberts, thank you for the review! I added it by mistake.
@fxmarty Thanks for clarifying. Tbh, I'm still a bit confused by the tests - it's not clear to me how this explicitly tests for the cache.
@amyeroberts Good catch indeed... I just checked, we are going through transformers/src/transformers/generation/utils.py (line 1602 in f93c1e9), which sets and uses it.
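Put differently, a hedged sketch (reusing the dummy tensors from the test snippet above and assuming some causal LM already loaded with FA2) of why the cache path is exercised: generate propagates use_cache into every forward call, so the guard above runs at each decoding step.

```python
# Sketch only: `model` is any FA2-enabled causal LM; generate() forwards
# use_cache (True by default in GenerationConfig) into the model's forward,
# so the padding-side guard runs even though the test never touches
# past_key_values directly.
outputs = model.generate(
    dummy_input,
    attention_mask=dummy_attention_mask,
    max_new_tokens=5,
    use_cache=True,
)
```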
@fxmarty OK - thanks for explaining! As a follow-up, could you add
For sure @amyeroberts, I will ping you there. Sorry, I should have waited before merging.
@fxmarty No worries! It doesn't affect the functionality of this PR, so it's fine to be done separately :)
This PR fixes some important bugs in the Mistral and Falcon integration.
#26933 broke flash attention for Falcon due to the modification of the layout, and Falcon with FA2 is not really usable on main due to an error in the shape (currently `[batch_size, num_head, seqlen, head_dim]` instead of the required `[batch_size, seqlen, num_head, head_dim]`).

The following tests were not passing:
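For reference, a minimal sketch (illustrative tensor names and sizes, not the actual Falcon modeling code) of the layout fix described above:

```python
import torch

batch_size, num_heads, seqlen, head_dim = 2, 8, 16, 64

# Layout currently produced on main for Falcon (per the description above):
query_states = torch.randn(batch_size, num_heads, seqlen, head_dim)

# Flash Attention 2 kernels expect [batch_size, seqlen, num_heads, head_dim],
# so the head and sequence dimensions must be swapped before the FA2 call:
query_states = query_states.transpose(1, 2)
assert query_states.shape == (batch_size, seqlen, num_heads, head_dim)
```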