
Add Flash Attention 2 support to Bark #27364

Merged: 12 commits into huggingface:main, Nov 8, 2023

Conversation

ylacombe (Contributor) commented Nov 8, 2023

What does this PR do?

Following a recent series of PRs and issues to improve Bark, this PR adds FA2 (Flash Attention 2) support to Bark. The Bark self-attention class supports both causal and non-causal attention, but the changes are otherwise minimal.
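For context, a minimal usage sketch of loading Bark with Flash Attention 2 (not code from this PR; the `suno/bark-small` checkpoint, fp16 dtype, and the `use_flash_attention_2` loading flag are assumptions about how FA2 was enabled in Transformers at the time):

```python
# Hedged sketch: load Bark with Flash Attention 2 enabled and generate speech.
# Assumes a CUDA GPU and the flash-attn package are installed.
import torch
from transformers import AutoProcessor, BarkModel

processor = AutoProcessor.from_pretrained("suno/bark-small")
model = BarkModel.from_pretrained(
    "suno/bark-small",
    torch_dtype=torch.float16,   # FA2 kernels require half precision
    use_flash_attention_2=True,  # assumed flag for enabling FA2 at load time
).to("cuda")

inputs = processor("Hello, this is Bark running with Flash Attention 2.")

with torch.no_grad():
    audio = model.generate(input_ids=inputs["input_ids"].to("cuda"))
```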

I've also taken the opportunity to switch to _prepare_4d_attention_mask instead of manually creating the 4d attention mask.
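As a rough illustration of what that helper does (behaviour as understood from `transformers.modeling_attn_mask_utils`, not code from this PR): it expands a 2D padding mask into the additive 4D mask that the attention layers consume.

```python
# Illustration only: turn a [batch, seq_len] padding mask into the additive
# [batch, 1, tgt_len, src_len] mask expected by the attention layers.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

mask_2d = torch.tensor([[1, 1, 1, 0]])  # last position is padding
mask_4d = _prepare_4d_attention_mask(mask_2d, dtype=torch.float32)

print(mask_4d.shape)  # torch.Size([1, 1, 4, 4])
# attended positions are 0.0; the padded position holds a large negative value
```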

Benchmarks are currently running to measure the speed/memory gains!

cc @sanchit-gandhi and @amyeroberts

HuggingFaceDocBuilderDev commented Nov 8, 2023

The documentation is not available anymore as the PR was closed or merged.

sanchit-gandhi (Contributor) left a comment:

Very clean - thanks @ylacombe for adding this! Keen to see what kind of performance gain we get from this

cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
):
"""
If you don't know about Flash Attention, check out the official repository of flash attention:
sanchit-gandhi (Contributor):

Could be worth explaining quickly why we override this method in the docstring!
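For illustration, a hedged sketch of what such an override with an explanatory docstring could look like (the sub-config attribute names are assumptions based on Bark's composite structure, the method is meant to sit on the Bark pre-trained model class so it is not standalone-runnable, and the PR's final code may differ):

```python
@classmethod
def _check_and_enable_flash_attn_2(
    cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
):
    """
    The base implementation only flags the top-level config. Bark is composed of
    sub-models (semantic, coarse acoustics, fine acoustics), each driven by its own
    sub-config, so we override it to propagate the Flash Attention 2 flag to every
    sub-config as well.
    """
    config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)

    # assumed sub-config attribute names
    config.semantic_config._flash_attn_2_enabled = config._flash_attn_2_enabled
    config.coarse_acoustics_config._flash_attn_2_enabled = config._flash_attn_2_enabled
    config.fine_acoustics_config._flash_attn_2_enabled = config._flash_attn_2_enabled
    return config
```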

tests/models/bark/test_modeling_bark.py (outdated review thread, resolved)

dummy_attention_mask = inputs_dict.get("attention_mask", None)

if dummy_attention_mask is not None:
sanchit-gandhi (Contributor):

What's the motivation behind overriding the attention mask here?

ylacombe (Contributor, Author):

Making sure that at least one of the input ids is masked!
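Concretely, a hedged sketch of that override (the exact indices are an assumption; the point is simply to guarantee at least one padded position so the Flash Attention padding path is exercised):

```python
# Sketch only: make sure the dummy mask actually masks something, so the test
# hits the padded code path of Flash Attention rather than a fully dense mask.
if dummy_attention_mask is not None:
    dummy_attention_mask[:, 1:] = 1  # attend to every position except...
    dummy_attention_mask[:, :1] = 0  # ...the first one, which stays masked
```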


logits = (
outputs.hidden_states[-1]
if not model.config.is_encoder_decoder
sanchit-gandhi (Contributor):

We know bark is not encoder-decoder -> could we simplify the tests to reflect this?

ylacombe (Contributor, Author):

nice catch!

else outputs_fa.decoder_hidden_states[-1]
)

assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
sanchit-gandhi (Contributor):

Pretty high tolerance! We've compared the audio outputs qualitatively with / without flash attention and they match?

ylacombe (Contributor, Author):

Thanks for the quick review! I've actually copied and modified a test from the general test suite, so I haven't changed anything -> the tolerance and the attention-mask overriding are the same as in the original test!

Collaborator:

I had the same comment on tolerance for FA2 tests :D 0.04 was agreed as being acceptable

tests/models/bark/test_modeling_bark.py (outdated review thread, resolved)
ylacombe and others added 3 commits November 8, 2023 13:56
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
ylacombe (Contributor, Author) commented Nov 8, 2023

Thanks for the quick review, I've addressed your comments 🤗

amyeroberts (Collaborator) left a comment:

Very nice - thanks for adding!

ylacombe and others added 2 commits November 8, 2023 16:05
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
ylacombe (Contributor, Author) commented Nov 8, 2023

Merging! Thanks for the quick reviews!

ylacombe merged commit a5bee89 into huggingface:main on Nov 8, 2023
18 checks passed
ArthurZucker (Collaborator) left a comment:

Thanks for adding this! I usually also request adding a section to the README and updating the list of models that support Flash Attention in the docs, as was done for similar changes.

else:
    present = None

attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
ArthurZucker (Collaborator):

Here self.dropout is a module, not a float. The docstring of _flash_attention_forward does not match this and is not restrictive enough.

ArthurZucker (Collaborator):

It might work but I'd rather we standardize!
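One hypothetical way to standardize (the helper below is not from the PR; it only illustrates always handing `_flash_attention_forward` a float, whichever representation the attention module stores):

```python
import torch.nn as nn

def dropout_prob(dropout) -> float:
    """Hypothetical helper: return the dropout probability as a float, whether
    `dropout` is an nn.Dropout module or already a plain float."""
    return dropout.p if isinstance(dropout, nn.Dropout) else float(dropout)

# e.g. dropout=dropout_prob(self.dropout) in the _flash_attention_forward call
```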

ylacombe mentioned this pull request on Nov 9, 2023
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* change handmade attention mask to _prepare_4d_attention_mask

* add flashattention2 support in Bark

* add flashattention2 tests on BarkSemanticModel

* make style

* fix flashattention and tests + make style

* fix memory leak and allow Bark to pass flash attention to sub-models

* make style

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unecessary code from tests + justify overriding

* Update tests/models/bark/test_modeling_bark.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>