Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FA2] Add flash attention for opt #26414

Merged
merged 16 commits into from
Nov 23, 2023
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ We natively support Flash Attention 2 for the following models:

- Llama
- Falcon
- OPT

You can request to add Flash Attention 2 support for more models by opening an issue on GitHub, and even open a Pull Request to integrate the changes. The supported models can be used for inference and training, including training with padding tokens - *which is currently not supported for `BetterTransformer` API below.*

Expand Down
240 changes: 232 additions & 8 deletions src/transformers/models/opt/modeling_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
Expand All @@ -32,12 +33,18 @@
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_available,
logging,
replace_return_docstrings,
)
from .configuration_opt import OPTConfig


if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
Expand All @@ -63,6 +70,19 @@
]


# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(padding_mask):
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
Expand Down Expand Up @@ -160,6 +180,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

Expand Down Expand Up @@ -273,17 +294,213 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value


class OptFlashAttention2(OPTAttention):
"""
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
The only required change would be on the forward pass where it needs to correctly call the public API of flash
attention and deal with padding tokens in case the input contains any of them.
"""

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None

bsz, _, _ = hidden_states.size()

# get query proj
query_states = self.q_proj(hidden_states)
# get key, value proj
if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
elif past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)

query_length = query_states.shape[1]
tgt_len = key_states.shape[-2]

# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)

attn_dropout = self.dropout if self.training else 0.0

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once(
"The input hidden states seems to be silently casted in float32, this might be related to"
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
" float16."
)

query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)

attn_output = self._flash_attention_forward(
query_states, key_states, value_states, padding_mask, query_length, dropout=attn_dropout
)

attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
attn_output = self.out_proj(attn_weights_reshaped)

if not output_attentions:
attn_weights_reshaped = None

return attn_output, attn_weights_reshaped, past_key_value

def _flash_attention_forward(
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
first unpad the input, then computes the attention scores and pad the final attention scores.

Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
padding_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`int`, *optional*):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
"""
# we check if padding_mask contains all ones, in that case we don't use it. This is to make sure the torch.fx
# tests pass for the relevant models
if padding_mask.sum().item() != padding_mask.numel():
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
Copy link
Contributor Author

@susnato susnato Sep 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that if we check if attention_mask is ones here, the torch.fx tests pass.

query_states, key_states, value_states, padding_mask, query_length
)

cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=True,
)

attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True
)

return attn_output

# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length):
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape

key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
padding_mask = padding_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)


class OPTDecoderLayer(nn.Module):
def __init__(self, config: OPTConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=config.enable_bias,
)

if not getattr(config, "_flash_attn_2_enabled", False):
self.self_attn = OPTAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=config.enable_bias,
)
else:
self.self_attn = OptFlashAttention2(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
bias=config.enable_bias,
)

self.do_layer_norm_before = config.do_layer_norm_before
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
Expand All @@ -303,6 +520,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Expand Down Expand Up @@ -333,6 +551,7 @@ def forward(
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
padding_mask=padding_mask,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -399,6 +618,7 @@ class OPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True

def _init_weights(self, module):
std = self.config.init_std
Expand Down Expand Up @@ -647,6 +867,9 @@ def forward(
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
f"{mask_seq_length} (sum of the lengths of current and past inputs)"
)

padding_mask = attention_mask

causal_attention_mask = self._prepare_decoder_attention_mask(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case we need to pass padding_mask = attention_mask and we check for the ones array in the _flash_attention_forward. If we use any if-else condition here it gives an error for the torch.fx tests.

attention_mask, input_shape, inputs_embeds, past_key_values_length
)
Expand Down Expand Up @@ -695,7 +918,7 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return module(*inputs, output_attentions, None, padding_mask=padding_mask)

return custom_forward

Expand All @@ -714,6 +937,7 @@ def custom_forward(*inputs):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)

hidden_states = layer_outputs[0]
Expand Down