Add Flash Attention 2 support to Bark (#27364)
* change handmade attention mask to _prepare_4d_attention_mask

* add flashattention2 support in Bark

* add flashattention2 tests on BarkSemanticModel

* make style

* fix flashattention and tests + make style

* fix memory leak and allow Bark to pass flash attention to sub-models

* make style

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove unnecessary code from tests + justify overriding

* Update tests/models/bark/test_modeling_bark.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
3 people authored Nov 8, 2023

1 parent ef71673 commit a5bee89
Showing 2 changed files with 355 additions and 20 deletions.
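
For context, a minimal usage sketch of the feature this commit adds. It assumes a CUDA device, the flash-attn package installed, and the `use_flash_attention_2` loading flag exposed by `from_pretrained` in this release; the checkpoint name is just an example.

import torch
from transformers import AutoProcessor, BarkModel

device = "cuda"
processor = AutoProcessor.from_pretrained("suno/bark-small")
# Flash Attention 2 requires half precision and a CUDA device.
model = BarkModel.from_pretrained(
    "suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True
).to(device)

inputs = processor("Hello, this is a test of Bark with Flash Attention 2.").to(device)
speech = model.generate(**inputs)  # waveform tensor, same API as with the default attention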
256 changes: 236 additions & 20 deletions src/transformers/models/bark/modeling_bark.py
@@ -26,12 +26,14 @@
BarkEosPrioritizerLogitsProcessor,
SuppressTokensLogitsProcessor,
)
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import CausalLMOutputWithPast, MaskedLMOutput
from ...modeling_utils import PreTrainedModel, get_parameter_device
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_flash_attn_2_available,
logging,
)
from ..auto import AutoModel
@@ -49,6 +51,11 @@
)


if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


logger = logging.get_logger(__name__)


@@ -62,6 +69,19 @@
]


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
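
To make the helper concrete, a small worked example (illustrative, not part of the diff):

import torch

mask = torch.tensor([[1, 1, 0], [1, 1, 1]])  # row lengths 2 and 3
indices, cu_seqlens, max_seqlen = _get_unpad_data(mask)
# indices    -> tensor([0, 1, 3, 4, 5])               flat positions of non-padding tokens
# cu_seqlens -> tensor([0, 2, 5], dtype=torch.int32)  cumulative lengths, zero-prefixed
# max_seqlen -> 3
# flash_attn_varlen_func consumes exactly this format: a flat token tensor plus cu_seqlens.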


class BarkSelfAttention(nn.Module):
# adapted from GPTNeoSelfAttention and Bark code
# BarkSelfAttention supports two attention types, i.e. full attention or causal attention
@@ -187,6 +207,177 @@ def forward(
return outputs


class BarkSelfFlashAttention2(BarkSelfAttention):
"""
Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""

def _split_heads(self, tensor, num_heads, attn_head_size):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
# Flash attention requires the input to have the shape
# batch_size x seq_length x num_heads x head_dim - (batch, seq_length, head, head_features)
return tensor

def _merge_heads(self, tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
# re-assemble all head outputs side by side
# (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
return tensor

def forward(
self,
hidden_states,
attention_mask=None,
past_key_values=None,
head_mask=None,
use_cache=False,
output_attentions=False,
):
batch_size, query_len, _ = hidden_states.size()

# calculate query, key, values for all heads in batch and move head forward to be the batch dim
query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)

query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)

if past_key_values is not None:
# (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
past_key = past_key_values[0].transpose(1, 2)
past_value = past_key_values[1].transpose(1, 2)
# and merge on seq_length
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)

if use_cache is True:
# (batch, head, seq_length, head_features)
present = (key.transpose(1, 2), value.transpose(1, 2))
else:
present = None

attn_output = self._flash_attention_forward(query, key, value, attention_mask, query_len, dropout=self.dropout)

attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)
if output_attentions:
attn_weights = None
outputs += (attn_weights,)

return outputs

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
first unpads the input, then computes the attention scores, and finally pads the output back.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
"""
# Contains at least one padding token in the sequence (the mask is set to None otherwise)
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=self.is_causal,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal
)

return attn_output
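
As a shape reference for the two paths above (a sketch; the flash-attn kernels require fp16/bf16 CUDA tensors):

# Dense path (attention_mask is None): everything stays (batch, seq_len, num_heads, head_dim):
#   out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # e.g. (2, 10, 8, 64) in and out
# Varlen path (padding present): tokens are flattened to (total_tokens, num_heads, head_dim),
#   cu_seqlens_q / cu_seqlens_k mark each row's boundaries, and pad_input scatters the flat
#   output back to (batch, query_length, num_heads, head_dim).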

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, which is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
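
A quick map of the three query-length branches above (comments only):

# query_length == kv_seq_len : prefill with no cache; queries reuse the key unpadding (indices_k).
# query_length == 1          : single-token decoding; each row keeps exactly one query, so
#                              cu_seqlens_q is simply arange(batch_size + 1).
# otherwise                  : a cached prefix is present; slicing the mask with [:, -query_length:]
#                              assumes left padding, as noted in the code.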


BARK_ATTENTION_CLASSES = {
"default": BarkSelfAttention,
"flash_attention_2": BarkSelfFlashAttention2,
}


class BarkLayerNorm(nn.Module):
"""LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False."""

@@ -229,7 +420,8 @@ def __init__(self, config, is_causal=False):
self.layernorm_1 = nn.LayerNorm(config.hidden_size)
self.layernorm_2 = nn.LayerNorm(config.hidden_size)

self.attn = BarkSelfAttention(config, is_causal=is_causal)
attn_type = "flash_attention_2" if getattr(config, "_flash_attn_2_enabled", False) else "default"
self.attn = BARK_ATTENTION_CLASSES[attn_type](config, is_causal=is_causal)

self.mlp = BarkMLP(config)

@@ -277,6 +469,7 @@ class BarkPreTrainedModel(PreTrainedModel):

config_class = BarkConfig
supports_gradient_checkpointing = False
_supports_flash_attn_2 = True

def _init_weights(self, module):
"""Initialize the weights."""
@@ -596,21 +789,13 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
# from_seq_length is 1 to easily broadcast
attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
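
For intuition, the `_prepare_4d_attention_mask` call replaces the hand-made masking deleted above with a shared helper; its effect for this case can be sketched by hand (the real implementation lives in `transformers.modeling_attn_mask_utils`):

import torch

mask = torch.tensor([[1, 1, 0]])                     # (bsz, to_seq_length)
expanded = mask[:, None, None, :].to(torch.float16)  # (bsz, 1, tgt_len=1, to_seq_length)
additive = (1.0 - expanded) * torch.finfo(torch.float16).min
# additive -> [[[[0., 0., -65504.]]]]: 0 where we attend, dtype-min where masked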
@@ -1233,10 +1418,12 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
if getattr(self.config, "_flash_attn_2_enabled", False):
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
# from_seq_length is 1 to easily broadcast
attention_mask = _prepare_4d_attention_mask(attention_mask, input_embeds.dtype, tgt_len=1)

head_mask = self.get_head_mask(head_mask, self.config.num_layers)

@@ -1669,3 +1856,32 @@ def generate(
return audio, output_lengths

return audio

@classmethod
def _check_and_enable_flash_attn_2(
cls, config, torch_dtype: Optional[torch.dtype] = None, device_map: Optional[Union[str, Dict[str, int]]] = None
):
"""
`_check_and_enable_flash_attn_2` originally doesn't expand flash attention enabling to the model
sub-configurations. We override the original method to make sure that Bark sub-models use Flash Attention
if necessary.
If you don't know about Flash Attention, check out the official repository of flash attention:
https://github.com/Dao-AILab/flash-attention
To use Flash Attention 1.0, you can do so directly via the `BetterTransformer` API; have a look at this
specific section of the documentation to learn more about it:
https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
The method checks if the current setup is compatible with Flash Attention, as it requires the model to be in
half precision and not run on CPU.
If all checks pass, the method will set the `_flash_attn_2_enabled` attribute on the config so that the model
can initialize the correct attention module.
"""
config = super()._check_and_enable_flash_attn_2(config, torch_dtype, device_map)

config.semantic_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.coarse_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
config.fine_acoustics_config._flash_attn_2_enabled = getattr(config, "_flash_attn_2_enabled", False)
return config
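
A quick way to observe the override's effect (a sketch; assumes a setup that passes the flash-attention checks and the `use_flash_attention_2` loading flag of this release):

model = BarkModel.from_pretrained(
    "suno/bark-small", torch_dtype=torch.float16, use_flash_attention_2=True
)
# The flag is mirrored onto every sub-model config, so each sub-model builds
# BarkSelfFlashAttention2 instead of BarkSelfAttention:
for sub_config in (
    model.config.semantic_config,
    model.config.coarse_acoustics_config,
    model.config.fine_acoustics_config,
):
    assert getattr(sub_config, "_flash_attn_2_enabled", False)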
