-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Flash Attention 2 support to Bark #27364
Conversation
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very clean - thanks @ylacombe for adding this! Keen to see what kind of performance gain we get from this
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None | ||
): | ||
""" | ||
If you don't know about Flash Attention, check out the official repository of flash attention: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could be worth explaining quickly why we override this method in the docstring!
|
||
dummy_attention_mask = inputs_dict.get("attention_mask", None) | ||
|
||
if dummy_attention_mask is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the motivation behind overriding the attention mask here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Making sure that at least one of the input ids is masked !
|
||
logits = ( | ||
outputs.hidden_states[-1] | ||
if not model.config.is_encoder_decoder |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We know bark is not encoder-decoder -> could we simplify the tests to reflect this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice catch!
else outputs_fa.decoder_hidden_states[-1] | ||
) | ||
|
||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pretty high tolerance! We've compared the audio outputs qualitatively with / without flash attention and they match?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the quick review, I've actually copied out and modified a test that is in the general suite, so I haven't change anything -> tolerance and attention mask overriding are the same than the original test!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had the same comment on tolerance for FA2 tests :D 0.04 was agreed as being acceptable
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
…ers into bark-flashattention-2
Thanks for the quick review, I've addressed your comments 🤗 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice - thanks for adding!
else outputs_fa.decoder_hidden_states[-1] | ||
) | ||
|
||
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had the same comment on tolerance for FA2 tests :D 0.04 was agreed as being acceptable
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Merging ! thanks for the quick reviews! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for adding this! I usually also request to add a section in the readme, and update the flash attention list of models that are supported here and the readme like this change.
else: | ||
present = None | ||
|
||
attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here self.dropout is a module not a float. The doc of the _flash_attention_forward
does not match and is not restrictive enough
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might work but I'd rather we standardize!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
* change handmade attention mask to _prepare_4d_attention_mask * add flashattention2 support in Bark * add flashattention2 tests on BarkSemanticModel * make style * fix flashattention and tests + make style * fix memory leak and allow Bark to pass flash attention to sub-models * make style * Apply suggestions from code review Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> * remove unecessary code from tests + justify overriding * Update tests/models/bark/test_modeling_bark.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * make style --------- Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
What does this PR do?
Following a recent series of PRs and issues to improve Bark, this PR aims to add FA2 support to Bark. Bark self-attention class supports both causal and non-causal attention but otherwise changes are minimal.
I've also taken the opportunity to switch to
_prepare_4d_attention_mask
instead of manually creating the 4d attention mask.Benchmarks are currently running at the moment to measure speed/memory gains!
cc @sanchit-gandhi and @amyeroberts