Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Flash Attention 2 support to Bark #27364

Merged
merged 12 commits into from
Nov 8, 2023
252 changes: 232 additions & 20 deletions src/transformers/models/bark/modeling_bark.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@
BarkEosPrioritizerLogitsProcessor,
SuppressTokensLogitsProcessor,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, get_parameter_device
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_flash_attn_2_available,
logging,
)
from ..auto import AutoModel
Expand All @@ -49,6 +51,11 @@
)


if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


logger = logging.get_logger(__name__)


Expand All @@ -62,6 +69,19 @@
]


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)


class BarkSelfAttention(nn.Module):
# adapted from GPTNeoSelfAttention and Bark code
# BarkSelfAttention can have two attention type, i.e full attention or causal attention
Expand Down Expand Up @@ -187,6 +207,177 @@ def forward(
return outputs


class BarkSelfFlashAttention2(BarkSelfAttention):
"""
Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features)
return tensor

def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
# re-assemble all head outputs side by side
# (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
return tensor

def forward(
self,
hidden_states,
attention_mask=None,
past_key_values=None,
head_mask=None,
use_cache=False,
output_attentions=False,
):
batch_size, query_len, _ = hidden_states.size()

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if past_key_values is not None:
# (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
past_key = past_key_values[0].transpose(1, 2)
past_value = past_key_values[1].transpose(1, 2)
# and merge on seq_length
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)

if use_cache is True:
# (batch, head, seq_length, head_features)
present = (key.transpose(1, 2), value.transpose(1, 2))
else:
present = None

attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here self.dropout is a module not a float. The doc of the _flash_attention_forward does not match and is not restrictive enough

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might work but I'd rather we standardize!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
if output_attentions:
attn_weights = None
outputs += (attn_weights,)

return outputs

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)

return attn_output

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)


BARK_ATTENTION_CLASSES = {
"default": BarkSelfAttention,
"flash_attention_2": BarkSelfFlashAttention2,
}


class BarkLayerNorm(nn.Module):
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""

Expand Down Expand Up @@ -229,7 +420,8 @@ def __init__(self, config, is_causal=False):
self.layernorm_1 = nn.LayerNorm(config.hidden_size)
self.layernorm_2 = nn.LayerNorm(config.hidden_size)

self.attn = BarkSelfAttention(config, is_causal=is_causal)
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)

self.mlp = BarkMLP(config)

Expand Down Expand Up @@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel):

config_class = BarkConfig
supports_gradient_checkpointing = False
_supports_flash_attn_2 = True

def _init_weights(self, module):
"""Initialize the weights."""
Expand Down Expand Up @@ -596,21 +789,13 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
# from_seq_length is 1 to easily broadcast
attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down Expand Up @@ -1233,10 +1418,12 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
# from_seq_length is 1 to easily broadcast
attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

head_mask = self.get_head_mask(head_mask, self.config.num_layers)

Expand Down Expand Up @@ -1669,3 +1856,28 @@ def generate(
return audio, output_lengths

return audio

@classmethod
def _check_and_enable_flash_attn_2(
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
):
"""
If you don't know about Flash Attention, check out the official repository of flash attention:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be worth explaining quickly why we override this method in the docstring!

https://github.com/Dao-AILab/flash-attention

For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models

The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
half precision and not ran on CPU.

If all checks pass, the method will create an attribute in the config `_flash_attn_2_enabled` so that the model
can initialize the correct attention module
"""
config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)

config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
return config
Loading
Loading