-
Notifications
You must be signed in to change notification settings - Fork 27.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FA2
] Add flash attention for opt
#26414
Merged
Merged
Changes from 6 commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
626276f
added flash attention for opt
susnato 7800457
added to list
susnato 689f599
fix use cache (#3)
younesbelkada 74e9687
style fix
susnato cf923e8
fix text
susnato db8cf07
test fix2
susnato af67e0f
reverted until 689f599
susnato edf1610
torch fx tests are working now!
susnato f34b680
small fix
susnato 90be210
Merge branch 'main' into flash_attn_opt
susnato 5cec2ad
added TODO docstring
susnato 2bde7cf
Merge branch 'main' into flash_attn_opt
susnato db18d65
Merge branch 'main' into flash_attn_opt
susnato adfbb69
changes
susnato 10ab9b3
Merge branch 'main' into flash_attn_opt
susnato 7d4c688
comments and .md file modification
susnato File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
from typing import List, Optional, Tuple, Union | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
import torch.utils.checkpoint | ||
from torch import nn | ||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | ||
|
@@ -32,12 +33,18 @@ | |
add_code_sample_docstrings, | ||
add_start_docstrings, | ||
add_start_docstrings_to_model_forward, | ||
is_flash_attn_available, | ||
logging, | ||
replace_return_docstrings, | ||
) | ||
from .configuration_opt import OPTConfig | ||
|
||
|
||
if is_flash_attn_available(): | ||
from flash_attn import flash_attn_func, flash_attn_varlen_func | ||
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa | ||
|
||
|
||
logger = logging.get_logger(__name__) | ||
|
||
_CHECKPOINT_FOR_DOC = "facebook/opt-350m" | ||
|
@@ -63,6 +70,19 @@ | |
] | ||
|
||
|
||
# Copied from transformers.models.llama.modeling_llama._get_unpad_data | ||
def _get_unpad_data(padding_mask): | ||
seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) | ||
indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() | ||
max_seqlen_in_batch = seqlens_in_batch.max().item() | ||
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) | ||
return ( | ||
indices, | ||
cu_seqlens, | ||
max_seqlen_in_batch, | ||
) | ||
|
||
|
||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask | ||
def _make_causal_mask( | ||
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 | ||
|
@@ -160,6 +180,7 @@ def forward( | |
attention_mask: Optional[torch.Tensor] = None, | ||
layer_head_mask: Optional[torch.Tensor] = None, | ||
output_attentions: bool = False, | ||
padding_mask: Optional[torch.LongTensor] = None, | ||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||
"""Input shape: Batch x Time x Channel""" | ||
|
||
|
@@ -273,17 +294,213 @@ def forward( | |
return attn_output, attn_weights_reshaped, past_key_value | ||
|
||
|
||
class OptFlashAttention2(OPTAttention): | ||
""" | ||
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. | ||
The only required change would be on the forward pass where it needs to correctly call the public API of flash | ||
attention and deal with padding tokens in case the input contains any of them. | ||
""" | ||
|
||
def forward( | ||
self, | ||
hidden_states: torch.Tensor, | ||
key_value_states: Optional[torch.Tensor] = None, | ||
past_key_value: Optional[Tuple[torch.Tensor]] = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
layer_head_mask: Optional[torch.Tensor] = None, | ||
output_attentions: bool = False, | ||
padding_mask: Optional[torch.LongTensor] = None, | ||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: | ||
"""Input shape: Batch x Time x Channel""" | ||
|
||
# if key_value_states are provided this layer is used as a cross-attention layer | ||
# for the decoder | ||
is_cross_attention = key_value_states is not None | ||
|
||
bsz, _, _ = hidden_states.size() | ||
|
||
# get query proj | ||
query_states = self.q_proj(hidden_states) | ||
# get key, value proj | ||
if is_cross_attention and past_key_value is not None: | ||
# reuse k,v, cross_attentions | ||
key_states = past_key_value[0] | ||
value_states = past_key_value[1] | ||
elif is_cross_attention: | ||
# cross_attentions | ||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz) | ||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz) | ||
elif past_key_value is not None: | ||
# reuse k, v, self_attention | ||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) | ||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) | ||
key_states = torch.cat([past_key_value[0], key_states], dim=2) | ||
value_states = torch.cat([past_key_value[1], value_states], dim=2) | ||
else: | ||
# self_attention | ||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) | ||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) | ||
|
||
if self.is_decoder: | ||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. | ||
# Further calls to cross_attention layer can then reuse all cross-attention | ||
# key/value_states (first "if" case) | ||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of | ||
# all previous decoder key/value_states. Further calls to uni-directional self-attention | ||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) | ||
# if encoder bi-directional self-attention `past_key_value` is always `None` | ||
past_key_value = (key_states, value_states) | ||
|
||
query_length = query_states.shape[1] | ||
tgt_len = key_states.shape[-2] | ||
|
||
# Flash attention requires the input to have the shape | ||
# batch_size x seq_length x head_dim x hidden_dim | ||
query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) | ||
key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) | ||
value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) | ||
|
||
attn_dropout = self.dropout if self.training else 0.0 | ||
|
||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons | ||
# therefore the input hidden states gets silently casted in float32. Hence, we need | ||
# cast them back in float16 just to be sure everything works as expected. | ||
input_dtype = query_states.dtype | ||
if input_dtype == torch.float32: | ||
logger.warning_once( | ||
"The input hidden states seems to be silently casted in float32, this might be related to" | ||
" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" | ||
" float16." | ||
) | ||
|
||
query_states = query_states.to(torch.float16) | ||
key_states = key_states.to(torch.float16) | ||
value_states = value_states.to(torch.float16) | ||
|
||
attn_output = self._flash_attention_forward( | ||
query_states, key_states, value_states, padding_mask, query_length, dropout=attn_dropout | ||
) | ||
|
||
attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) | ||
attn_output = self.out_proj(attn_weights_reshaped) | ||
|
||
if not output_attentions: | ||
attn_weights_reshaped = None | ||
|
||
return attn_output, attn_weights_reshaped, past_key_value | ||
|
||
def _flash_attention_forward( | ||
self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None | ||
): | ||
""" | ||
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token | ||
first unpad the input, then computes the attention scores and pad the final attention scores. | ||
|
||
Args: | ||
query_states (`torch.Tensor`): | ||
Input query states to be passed to Flash Attention API | ||
key_states (`torch.Tensor`): | ||
Input key states to be passed to Flash Attention API | ||
value_states (`torch.Tensor`): | ||
Input value states to be passed to Flash Attention API | ||
padding_mask (`torch.Tensor`): | ||
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the | ||
position of padding tokens and 1 for the position of non-padding tokens. | ||
dropout (`int`, *optional*): | ||
Attention dropout | ||
softmax_scale (`float`, *optional*): | ||
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) | ||
""" | ||
# we check if padding_mask contains all ones, in that case we don't use it. This is to make sure the torch.fx | ||
# tests pass for the relevant models | ||
if padding_mask.sum().item() != padding_mask.numel(): | ||
batch_size = query_states.shape[0] | ||
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( | ||
query_states, key_states, value_states, padding_mask, query_length | ||
) | ||
|
||
cu_seqlens_q, cu_seqlens_k = cu_seq_lens | ||
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens | ||
|
||
attn_output_unpad = flash_attn_varlen_func( | ||
query_states, | ||
key_states, | ||
value_states, | ||
cu_seqlens_q=cu_seqlens_q, | ||
cu_seqlens_k=cu_seqlens_k, | ||
max_seqlen_q=max_seqlen_in_batch_q, | ||
max_seqlen_k=max_seqlen_in_batch_k, | ||
dropout_p=dropout, | ||
softmax_scale=softmax_scale, | ||
causal=True, | ||
) | ||
|
||
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) | ||
else: | ||
attn_output = flash_attn_func( | ||
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True | ||
) | ||
|
||
return attn_output | ||
|
||
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input | ||
def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): | ||
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) | ||
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape | ||
|
||
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) | ||
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) | ||
if query_length == kv_seq_len: | ||
query_layer = index_first_axis( | ||
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k | ||
) | ||
cu_seqlens_q = cu_seqlens_k | ||
max_seqlen_in_batch_q = max_seqlen_in_batch_k | ||
indices_q = indices_k | ||
elif query_length == 1: | ||
max_seqlen_in_batch_q = 1 | ||
cu_seqlens_q = torch.arange( | ||
batch_size + 1, dtype=torch.int32, device=query_layer.device | ||
) # There is a memcpy here, that is very bad. | ||
indices_q = cu_seqlens_q[:-1] | ||
query_layer = query_layer.squeeze(1) | ||
else: | ||
# The -q_len: slice assumes left padding. | ||
padding_mask = padding_mask[:, -query_length:] | ||
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) | ||
|
||
return ( | ||
query_layer, | ||
key_layer, | ||
value_layer, | ||
indices_q, | ||
(cu_seqlens_q, cu_seqlens_k), | ||
(max_seqlen_in_batch_q, max_seqlen_in_batch_k), | ||
) | ||
|
||
|
||
class OPTDecoderLayer(nn.Module): | ||
def __init__(self, config: OPTConfig): | ||
super().__init__() | ||
self.embed_dim = config.hidden_size | ||
self.self_attn = OPTAttention( | ||
embed_dim=self.embed_dim, | ||
num_heads=config.num_attention_heads, | ||
dropout=config.attention_dropout, | ||
is_decoder=True, | ||
bias=config.enable_bias, | ||
) | ||
|
||
if not getattr(config, "_flash_attn_2_enabled", False): | ||
self.self_attn = OPTAttention( | ||
embed_dim=self.embed_dim, | ||
num_heads=config.num_attention_heads, | ||
dropout=config.attention_dropout, | ||
is_decoder=True, | ||
bias=config.enable_bias, | ||
) | ||
else: | ||
self.self_attn = OptFlashAttention2( | ||
embed_dim=self.embed_dim, | ||
num_heads=config.num_attention_heads, | ||
dropout=config.attention_dropout, | ||
is_decoder=True, | ||
bias=config.enable_bias, | ||
) | ||
|
||
self.do_layer_norm_before = config.do_layer_norm_before | ||
self.dropout = config.dropout | ||
self.activation_fn = ACT2FN[config.activation_function] | ||
|
@@ -303,6 +520,7 @@ def forward( | |
past_key_value: Optional[Tuple[torch.Tensor]] = None, | ||
output_attentions: Optional[bool] = False, | ||
use_cache: Optional[bool] = False, | ||
padding_mask: Optional[torch.LongTensor] = None, | ||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: | ||
""" | ||
Args: | ||
|
@@ -333,6 +551,7 @@ def forward( | |
attention_mask=attention_mask, | ||
layer_head_mask=layer_head_mask, | ||
output_attentions=output_attentions, | ||
padding_mask=padding_mask, | ||
) | ||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) | ||
hidden_states = residual + hidden_states | ||
|
@@ -399,6 +618,7 @@ class OPTPreTrainedModel(PreTrainedModel): | |
base_model_prefix = "model" | ||
supports_gradient_checkpointing = True | ||
_no_split_modules = ["OPTDecoderLayer"] | ||
_supports_flash_attn_2 = True | ||
|
||
def _init_weights(self, module): | ||
std = self.config.init_std | ||
|
@@ -647,6 +867,9 @@ def forward( | |
f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " | ||
f"{mask_seq_length} (sum of the lengths of current and past inputs)" | ||
) | ||
|
||
padding_mask = attention_mask | ||
|
||
causal_attention_mask = self._prepare_decoder_attention_mask( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In that case we need to pass |
||
attention_mask, input_shape, inputs_embeds, past_key_values_length | ||
) | ||
|
@@ -695,7 +918,7 @@ def forward( | |
def create_custom_forward(module): | ||
def custom_forward(*inputs): | ||
# None for past_key_value | ||
return module(*inputs, output_attentions, None) | ||
return module(*inputs, output_attentions, None, padding_mask=padding_mask) | ||
|
||
return custom_forward | ||
|
||
|
@@ -714,6 +937,7 @@ def custom_forward(*inputs): | |
past_key_value=past_key_value, | ||
output_attentions=output_attentions, | ||
use_cache=use_cache, | ||
padding_mask=padding_mask, | ||
) | ||
|
||
hidden_states = layer_outputs[0] | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that if we check if
attention_mask
is ones here, thetorch.fx
tests pass.