[Whisper, Bart, MBart] Add Flash Attention 2 #27203
Conversation
Amazing piece of work! 🔥
Main comment is about the tests - I think some might be indexing on outputs when it should be using outputs_fa.
tests/test_modeling_common.py
output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
This... isn't that close. I can see it's the tolerance used elsewhere, but it seems like quite a big difference.
Yeah, Flash Attention does lead to noticeably different results though. I think 0.04 is good enough tbh.
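For context, the end-to-end comparison looks roughly like the sketch below (just an illustration, not the actual test: the checkpoint and the use_flash_attention_2 flag are assumptions on my side, and FA2 needs fp16/bf16 weights on a CUDA device). Note that for encoder-decoder models the last hidden states live under decoder_hidden_states:

import torch
from transformers import BartForConditionalGeneration

# Sketch only: checkpoint and flag are illustrative assumptions.
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base", torch_dtype=torch.bfloat16
).to("cuda")
model_fa = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base", torch_dtype=torch.bfloat16, use_flash_attention_2=True
).to("cuda")

dummy_input = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
dummy_attention_mask = torch.ones_like(dummy_input)

with torch.no_grad():
    output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
    output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)

# for encoder-decoder models the final hidden states are under decoder_hidden_states
logits = output.decoder_hidden_states[-1]
logits_fa = output_fa.decoder_hidden_states[-1]  # index the FA2 outputs, not `output`
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)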
input_dtype = query_states.dtype
if input_dtype == torch.float32:
    # Handle the case where the model is quantized
    if hasattr(self.config, "_pre_quantization_dtype"):
We need to have access to the config here
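For context, the continuation of this block (quoted from memory from the decoder-only FA2 implementations, so treat the exact names as approximate) is what needs the config: it picks the dtype to cast fp32 states down to before calling flash attention:

        target_dtype = self.config._pre_quantization_dtype
    else:
        target_dtype = self.q_proj.weight.dtype

    # flash attention only supports fp16/bf16, so cast the fp32 states down
    query_states = query_states.to(target_dtype)
    key_states = key_states.to(target_dtype)
    value_states = value_states.to(target_dtype)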
Great. I appreciate the # Copied from statements, which make the code simpler to review.
Very clean, it's OK for me to merge.
class BartEncoderLayer(nn.Module):
    def __init__(self, config: BartConfig):
        super().__init__()
        self.embed_dim = config.d_model
        self.self_attn = BartAttention(
            attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
We should eventually move this to an enum to be cleaner (out of scope for this PR).
(brainstorming, still out of scope) It would be cleaner to eventually have the config return the appropriate attention name for all models:
self.self_attn = BART_ATTENTION_CLASSES[config.attention_type](
    ...
)

with

class PreTrainedConfig():
    ...

    @property
    def attention_type(self):
        return AttentionTypes.FA2 if getattr(self, "_flash_attn_2_enabled", False) else AttentionTypes.DEFAULT
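where AttentionTypes would be a small string enum (hypothetical, not in the codebase yet), e.g.:

from enum import Enum

class AttentionTypes(str, Enum):
    # the values here are just placeholders for the brainstorm
    DEFAULT = "default"
    FA2 = "flash_attention_2"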
Yeah, attention_type as a property is a good idea I think! We should then probably also allow users to change it even after the model has been loaded.
tests/test_modeling_common.py
# make sure that all models have at least 40 position ids
if hasattr(config, "max_position_embeddings"):
    config.max_position_embeddings = 40
Why that minimum?
OK, I ran some more tests and it should be good now. I'm getting some flaky behavior with the flash attention tests on my RTX 4090 (it's especially extreme for Whisper). We should maybe think about how we can make them more robust now that we've added some more models (cc @younesbelkada).
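One option, just a sketch and not necessarily the fix we want, would be to retry the FA2 equivalence checks a couple of times before failing; the helper below is hypothetical and not part of the library:

import functools

def retry_flaky(max_attempts=3):
    # Hypothetical helper: re-run a test a few times and only fail if every attempt fails.
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(*args, **kwargs):
            last_error = None
            for _ in range(max_attempts):
                try:
                    return test_fn(*args, **kwargs)
                except AssertionError as err:
                    last_error = err
            raise last_error
        return wrapper
    return decorator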
* add whisper fa2
* correct
* change all
* correct
* correct
* fix more
* fix more
* fix more
* fix more
* fix more
* fix more
* Apply suggestions from code review
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* fix more
* fix more
* fix more
* fix more
* fix more
---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
This PR adds Flash Attention for Whisper, Bart & MBart.
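For reference, enabling it from the user side should look roughly like this (the checkpoint is just an example, and use_flash_attention_2 is the flag already used by the decoder-only models; FA2 requires flash-attn to be installed and fp16/bf16 weights on GPU):

import torch
from transformers import WhisperForConditionalGeneration

# Illustrative sketch: checkpoint name is an example; flag name as used for the
# existing decoder-only FA2 models.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-large-v2",
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
).to("cuda")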
For Flash Attention, Whisper depends quite a bit on Bart and MBart, as do 20+ other model architectures.
As this is the first PR that adds Flash Attention 2 to an encoder-decoder model, I wanted to make sure it's done for the two template models (Bart and MBart) as well, so that Whisper (and all other encoder-decoder models that follow) don't lose their "# Copied from" statements.
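For anyone not familiar with the mechanism: these are marker comments that the repo consistency check uses to keep copied code in sync across models, e.g. (target path quoted from memory, the exact class may differ):

# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Whisper
class WhisperAttention(nn.Module):
    ...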
Note that while this PR changes 27 files, only 4 files are really relevant to review, because all other files are just consequences of the "# Copied from" mechanism:
The following three files fully implement Flash Attention 2:
The test file is restructured so that Flash Attention 2 tests can run nicely for different kinds of models (audio & NLP, as well as decoder-only and encoder-decoder).
I ran the following tests to make sure everything works as expected:
as well as:
All tests that also pass on "main" pass here as well. The only failures are related to disk offloading, which should be fixed in #27204.
There are some "error not raised" failures for flash attn and Mistral, but they are also present on "main" and seem to be related to #27125 (cc @younesbelkada); I'd suggest also fixing those in a separate PR.
Other CI test failures are unrelated.