Add Flash Attention 2 support to Bark #27364
Merged
Changes from 7 commits (12 commits in total):
- 0eadfab (ylacombe): change handmade attention mask to _prepare_4d_attention_mask
- 0fcff47 (ylacombe): add flashattention2 support in Bark
- 7ca710c (ylacombe): add flashattention2 tests on BarkSemanticModel
- 10d81ba (ylacombe): make style
- 32fb57d (ylacombe): fix flashattention and tests + make style
- c2ff5f4 (ylacombe): fix memory leak and allow Bark to pass flash attention to sub-models
- ef106a4 (ylacombe): make style
- 425d41d (ylacombe): Apply suggestions from code review
- 049c2e9 (ylacombe): remove unecessary code from tests + justify overriding
- c6a34cf (ylacombe): Merge branch 'bark-flashattention-2' of github.com:ylacombe/transform…
- 653fa13 (ylacombe): Update tests/models/bark/test_modeling_bark.py
- 5f76f13 (ylacombe): make style
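For orientation before the diff: once merged, enabling the new path is a one-flag change at load time. A minimal usage sketch, assuming the `use_flash_attention_2` loading flag from the Transformers version this PR targets and the `suno/bark-small` checkpoint (neither appears in this diff), plus a CUDA device, half precision, and the `flash-attn` package installed:

```python
import torch
from transformers import AutoProcessor, BarkModel

# Assumed flag name for this era of transformers; later versions use
# attn_implementation="flash_attention_2" instead.
model = BarkModel.from_pretrained(
    "suno/bark-small",  # hypothetical checkpoint choice
    torch_dtype=torch.float16,
    use_flash_attention_2=True,
).to("cuda")

processor = AutoProcessor.from_pretrained("suno/bark-small")
inputs = processor("Flash attention makes Bark generation faster.").to("cuda")
audio = model.generate(**inputs)
```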
Diff for src/transformers/models/bark/modeling_bark.py:

```diff
@@ -26,12 +26,14 @@
     BarkEosPrioritizerLogitsProcessor,
     SuppressTokensLogitsProcessor,
 )
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
 from ...modeling_utils import PreTrainedModel, get_parameter_device
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_accelerate_available,
+    is_flash_attn_2_available,
     logging,
 )
 from ..auto import AutoModel
```
```diff
@@ -49,6 +51,11 @@
 )
 
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+
+
 logger = logging.get_logger(__name__)
```
```diff
@@ -62,6 +69,19 @@
 ]
 
 
+# Copied from transformers.models.llama.modeling_llama._get_unpad_data
+def _get_unpad_data(attention_mask):
+    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+    max_seqlen_in_batch = seqlens_in_batch.max().item()
+    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
+    return (
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )
+
+
 class BarkSelfAttention(nn.Module):
     # adapted from GPTNeoSelfAttention and Bark code
     # BarkSelfAttention can have two attention types, i.e. full attention or causal attention
```
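To make the helper concrete, here is a small worked example (not part of the diff) of what `_get_unpad_data` computes for a toy right-padded batch. These are exactly the index tensors `flash_attn_varlen_func` consumes:

```python
import torch
import torch.nn.functional as F

# Toy batch of 2 sequences, real lengths 3 and 2, right-padded to 4 tokens.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)  # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
# indices == tensor([0, 1, 2, 4, 5]): flat positions of the real (non-pad) tokens
max_seqlen_in_batch = seqlens_in_batch.max().item()  # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 3, 5]): cumulative sequence boundaries in the packed layout
```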
```diff
@@ -187,6 +207,177 @@ def forward(
         return outputs
 
 
+class BarkSelfFlashAttention2(BarkSelfAttention):
+    """
+    Bark flash attention module. This module inherits from `BarkSelfAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def _split_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Splits hidden_size dim into attn_head_size and num_heads
+        """
+        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim - (batch, seq_length, head, head_features)
+        return tensor
+
+    def _merge_heads(self, tensor, num_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden_size
+        """
+        # re-assemble all head outputs side by side
+        # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
+        tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
+        return tensor
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        past_key_values=None,
+        head_mask=None,
+        use_cache=False,
+        output_attentions=False,
+    ):
+        batch_size, query_len, _ = hidden_states.size()
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        if past_key_values is not None:
+            # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
+            past_key = past_key_values[0].transpose(1, 2)
+            past_value = past_key_values[1].transpose(1, 2)
+            # and merge on seq_length
+            key = torch.cat((past_key, key), dim=1)
+            value = torch.cat((past_value, value), dim=1)
+
+        if use_cache is True:
+            # (batch, head, seq_length, head_features)
+            present = (key.transpose(1, 2), value.transpose(1, 2))
+        else:
+            present = None
+
+        attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            attn_weights = None
+            outputs += (attn_weights,)
+
+        return outputs
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
+    def _flash_attention_forward(
+        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+        first unpads the input, then computes the attention scores and pads the final attention scores.
+
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`float`, *optional*):
+                Attention dropout probability
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        # Contains at least one padding token in the sequence
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=self.is_causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
+            )
+
+        return attn_output
+
+    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
+    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q,
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
+
+BARK_ATTENTION_CLASSES = {
+    "default": BarkSelfAttention,
+    "flash_attention_2": BarkSelfFlashAttention2,
+}
 
 
 class BarkLayerNorm(nn.Module):
     """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""
```
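Why `_split_heads` and `_merge_heads` are overridden at all: the parent `BarkSelfAttention` transposes to `(batch, num_heads, seq_len, head_dim)` for matmul attention, whereas the flash-attention kernels expect heads to stay in dimension 2. A standalone sketch of the two layouts (made-up dimensions, for illustration only):

```python
import torch

batch, seq_len, num_heads, head_dim = 2, 16, 12, 64
hidden_states = torch.randn(batch, seq_len, num_heads * head_dim)

# Parent-class layout for matmul attention: (batch, num_heads, seq_len, head_dim)
default_layout = hidden_states.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

# Overridden _split_heads: same view, no transpose -> (batch, seq_len, num_heads, head_dim)
flash_layout = hidden_states.view(batch, seq_len, num_heads, head_dim)

print(default_layout.shape)  # torch.Size([2, 12, 16, 64])
print(flash_layout.shape)    # torch.Size([2, 16, 12, 64])
```

This also explains why the cache tuple (`present`) transposes back to `(batch, head, seq_length, head_features)` before being stored: it keeps the cached keys and values layout-compatible with the default attention class.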
```diff
@@ -229,7 +420,8 @@ def __init__(self, config, is_causal=False):
         self.layernorm_1 = nn.LayerNorm(config.hidden_size)
         self.layernorm_2 = nn.LayerNorm(config.hidden_size)
 
-        self.attn = BarkSelfAttention(config, is_causal=is_causal)
+        attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
+        self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)
 
         self.mlp = BarkMLP(config)
```
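The dictionary dispatch above is the same pattern other transformers models adopted for FA2: the block never branches on attention logic itself, it just picks a class at construction time. A self-contained sketch of the mechanism (mock classes and a mocked config, illustrative only):

```python
from types import SimpleNamespace

class BarkSelfAttention:  # stand-in for the real module
    pass

class BarkSelfFlashAttention2(BarkSelfAttention):  # stand-in
    pass

BARK_ATTENTION_CLASSES = {
    "default": BarkSelfAttention,
    "flash_attention_2": BarkSelfFlashAttention2,
}

config = SimpleNamespace(_flash_attn_2_enabled=True)
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
assert BARK_ATTENTION_CLASSES[attn_type] is BarkSelfFlashAttention2
```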
```diff
@@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel):
 
     config_class = BarkConfig
     supports_gradient_checkpointing = False
+    _supports_flash_attn_2 = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
```
```diff
@@ -596,21 +789,13 @@ def forward(
         if attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            # We create a 3D attention mask from a 2D tensor mask.
-            # Sizes are [batch_size, 1, 1, to_seq_length]
-            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
-            # this attention mask is more simple than the triangular masking of causal attention
-            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask[:, None, None, :]
-
-            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
-            # masked positions, this operation will create a tensor which is 0.0 for
-            # positions we want to attend and the dtype's smallest value for masked positions.
-            # Since we are adding it to the raw scores before the softmax, this is
-            # effectively the same as removing these entirely.
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                attention_mask = attention_mask.view(batch_size, -1)
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
```
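The two branches produce very different mask objects. The flash-attention path keeps the 2D padding mask as-is, and only when there is actual padding, since passing `None` lets the kernel skip the unpad/repad round-trip entirely; the default path builds an additive 4D mask. A plain-torch sketch of what the non-FA2 branch computes (equivalent in effect to `_prepare_4d_attention_mask(mask, dtype, tgt_len=1)`, which is an internal helper; not part of the diff):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # one sequence with one pad token
dtype = torch.float16

# FA2 branch: keep the 2D mask only when it actually masks something.
fa2_mask = attention_mask if 0 in attention_mask else None  # kept here, it has a 0

# Default branch: expand to [bsz, 1, tgt_len, src_len] and make it additive,
# 0 where attention is allowed, dtype-min where it is masked.
expanded = attention_mask[:, None, None, :].to(dtype)
additive = (1.0 - expanded) * torch.finfo(dtype).min
# additive[..., :3] == 0, additive[..., 3] == -65504 (fp16 lowest)
```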
```diff
@@ -1233,10 +1418,12 @@ def forward(
         if attention_mask is not None:
             if batch_size <= 0:
                 raise ValueError("batch_size has to be defined and > 0")
-            attention_mask = attention_mask.view(batch_size, -1)
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
-            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+            if getattr(self.config, "_flash_attn_2_enabled", False):
+                attention_mask = attention_mask if 0 in attention_mask else None
+            else:
+                # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
+                # from_seq_length is 1 to easily broadcast
+                attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)
 
         head_mask = self.get_head_mask(head_mask, self.config.num_layers)
```
```diff
@@ -1669,3 +1856,28 @@ def generate(
             return audio, output_lengths
 
         return audio
+
+    @classmethod
+    def _check_and_enable_flash_attn_2(
+        cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
+    ):
+        """
+        If you don't know about Flash Attention, check out the official repository of flash attention:
+        https://github.com/Dao-AILab/flash-attention
+
+        For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
+        specific section of the documentation to learn more about it:
+        https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+
+        The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
+        half precision and not run on CPU.
+
+        If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the
+        model can initialize the correct attention module
+        """
+        config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)
+
+        config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
+        return config
```

Inline review comment on `_check_and_enable_flash_attn_2`: "Could be worth explaining quickly why we override this method in the docstring!" (addressed by commit 049c2e9, "remove unecessary code from tests + justify overriding").
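Why the override matters: `BarkModel` instantiates its sub-models from `semantic_config`, `coarse_acoustics_config`, and `fine_acoustics_config`, so a flag set only on the top-level config would never reach the attention layers. A hypothetical check of the propagation (the private `_flash_attn_2_enabled` attribute and the `use_flash_attention_2` flag are both specific to this era of transformers; running this requires `flash-attn` and a supported GPU):

```python
import torch
from transformers import BarkModel

model = BarkModel.from_pretrained(
    "suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True
)

for name in ("semantic_config", "coarse_acoustics_config", "fine_acoustics_config"):
    sub_config = getattr(model.config, name)
    # Each sub-config should now carry the flag set by the override above.
    print(name, getattr(sub_config, "_flash_attn_2_enabled", False))  # expect True
```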
Review discussion on `dropout=self.dropout` in `BarkSelfFlashAttention2.forward`:

- Reviewer: Here `self.dropout` is a module, not a float. The docstring of `_flash_attention_forward` does not match and is not restrictive enough.
- Reviewer: It might work, but I'd rather we standardize!
- ylacombe: It's actually a float here: https://github.com/ylacombe/transformers/blob/3258ff93304078b9e27d752e6c19d3813f664855/src/transformers/models/bark/modeling_bark.py#L93-L94 !
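The distinction the thread hinges on: `flash_attn_func` and `flash_attn_varlen_func` take `dropout_p` as a plain float probability, so forwarding an `nn.Dropout` module would break. A small sketch of the two shapes the attribute could take (illustrative names only; per the linked lines, Bark stores the float):

```python
import torch.nn as nn

attention_dropout = 0.1  # value a config would typically carry

# What Bark's attention stores (per the linked lines): a float,
# safe to pass straight through to flash_attn's dropout_p argument.
dropout_as_float = attention_dropout

# What the reviewer suspected: a module. Calling it applies dropout to a
# tensor; it is not a probability and cannot be passed as dropout_p.
dropout_as_module = nn.Dropout(attention_dropout)

print(type(dropout_as_float))   # <class 'float'>
print(type(dropout_as_module))  # <class 'torch.nn.modules.dropout.Dropout'>
```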