Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Whisper, Bart, MBart] Add Flash Attention 2 #27203

Merged
merged 18 commits into from
Nov 1, 2023
Merged

Conversation

patrickvonplaten
Copy link
Contributor

@patrickvonplaten patrickvonplaten commented Nov 1, 2023

What does this PR do?

This PR adds Flash Attention for Whisper, Bart & MBart.

Whisper depends on Bart and MBart quite a bit for Flash Attention like 20+ other model architectures.
As this is the first PR that adds Flash Attention 2 to a encoder-decoder model, I wanted to make sure it's done for the two template models (Bart and MBart) as well so that Whisper (and all other encoder-decoder models that follow) don't loose their "# Copied from" statements.

Note that while this PR changes 27 files, only 4 files are really relevant to review because all other files are just consequences of the "# Copied from mechanism":
The following there files fully implement Flash Attention 2:

  • src/transformers/models/bart/modeling_bart.py
  • src/transformers/models/mbart/modeling_mbart.py
  • src/transformers/models/whisper/modeling_whisper.py
    The test files is restructured so that Flash Attention 2 tests can nicely run for different kinds of models (audio & nlp as well as decoder-only and encoder-decoder).
  • tests/test_modeling_common.py

I ran the following tests to make sure everything works as expected:

CUDA_VISIBLE_DEVICES="0" RUN_SLOW=1 pytest tests/models/whisper/test_modeling_whisper.py
CUDA_VISIBLE_DEVICES="0" RUN_SLOW=1 pytest tests/models/mbart/test_modeling_mbart.py
CUDA_VISIBLE_DEVICES="0" RUN_SLOW=1 pytest tests/models/bart/test_modeling_bart.py

as well as:

RUN_SLOW=1 pytest -m flash_attn_test tests

All tests pass that also pass on "main". The only failures are related to disk offloading which should be fixed in: #27204

There are some "error not raised" failures for flash attn and mistral, but they are also present in "main" and seem to be related to this PR: #27125 (cc @younesbelkada), I'd suggest to also fix those in another PR.

Other CI test failures are unrelated.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Nov 1, 2023

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten patrickvonplaten changed the title [Whisper] Add Flash Attention 2 [Whisper, Bart, MBart] Add Flash Attention 2 Nov 1, 2023
Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing piece of work! 🔥

Main comment is about the tests - I think some might be indexing on outputs when it should be using outputs_fa

src/transformers/models/bart/modeling_bart.py Show resolved Hide resolved
src/transformers/models/bart/modeling_bart.py Outdated Show resolved Hide resolved
src/transformers/models/bart/modeling_bart.py Outdated Show resolved Hide resolved
src/transformers/models/bart/modeling_bart.py Outdated Show resolved Hide resolved
tests/test_modeling_common.py Outdated Show resolved Hide resolved
tests/test_modeling_common.py Outdated Show resolved Hide resolved
tests/test_modeling_common.py Outdated Show resolved Hide resolved
tests/test_modeling_common.py Outdated Show resolved Hide resolved

output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
logits = output.hidden_states[-1]
self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This... isn't that close. I can see it's the tolerance used elsewhere but seems like quite a big difference

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, Flash attention leads to very much different results though. I think 0.04 is good enough tbh

tests/test_modeling_common.py Outdated Show resolved Hide resolved
input_dtype = query_states.dtype
if input_dtype == torch.float32:
# Handle the case where the model is quantized
if hasattr(self.config, "_pre_quantization_dtype"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to have access to the config here

patrickvonplaten and others added 4 commits November 1, 2023 18:50
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great. I appreciate the # Copied from statements which make the code simpler to review.

Very clean, it's ok for me to merge

class BartEncoderLayer(nn.Module):
def __init__(self, config: BartConfig):
super().__init__()
self.embed_dim = config.d_model
self.self_attn = BartAttention(
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should eventually move this to an enum to be clenaer (out of scope for this PR)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(brainstorming, still out of scope) It would be cleaner to eventually have the config return the appropriate attention name for all models:

self.self_attn = BART_ATTENTION_CLASSES[config.attention_type](
...
)

with

class PreTrainedConfig():
    ...

    @property
    def attention_type(self):
        return AttentionTypes.FA2 if getattr(self, "_flash_attn_2_enabled", False) else AttentionTypes.DEFAULT

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah attention_type as a property is a good idea I think! We should then probably also allow users to change it even after the model was loaded

Comment on lines 3092 to 3094
# make sure that all models have at least 40 position ids
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = 40
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why that minimum?

@patrickvonplaten
Copy link
Contributor Author

Ok ran some more tests and it should be good now. I'm getting some flaky behavior with the flash attention tests on my RTX 4090 (especially extreme for Whisper). We should maybe think about how we can make them more robust now that we've added some more models (cc @younesbelkada)

@patrickvonplaten patrickvonplaten merged commit af3de8d into main Nov 1, 2023
2 of 3 checks passed
@patrickvonplaten patrickvonplaten deleted the fa2_whisper branch November 1, 2023 20:03
ydshieh added a commit that referenced this pull request Nov 2, 2023
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* add whisper fa2

* correct

* change all

* correct

* correct

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix more

* fix more

* fix more

* fix more

* fix more

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants