From 3bc50d81e6c70d63e59d635106bac6a561b47681 Mon Sep 17 00:00:00 2001 From: Susnato Dhar Date: Thu, 23 Nov 2023 15:46:51 +0530 Subject: [PATCH] [`FA2`] Add flash attention for opt (#26414) * added flash attention for opt * added to list * fix use cache (#3) * style fix * fix text * test fix2 * reverted until 689f599 * torch fx tests are working now! * small fix * added TODO docstring * changes * comments and .md file modification --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> --- docs/source/en/model_doc/opt.md | 49 ++++ src/transformers/models/opt/modeling_opt.py | 304 ++++++++++++++++++-- 2 files changed, 323 insertions(+), 30 deletions(-) diff --git a/docs/source/en/model_doc/opt.md b/docs/source/en/model_doc/opt.md index 68da201f99bc72..3da7b22fab747d 100644 --- a/docs/source/en/model_doc/opt.md +++ b/docs/source/en/model_doc/opt.md @@ -62,6 +62,55 @@ The resource should ideally demonstrate something new instead of duplicating an - A blog post on [How 🤗 Accelerate runs very large models thanks to PyTorch](https://huggingface.co/blog/accelerate-large-models) with OPT. + +## Combining OPT and Flash Attention 2 + +First, make sure to install the latest version of Flash Attention 2 to include the sliding window attention feature. + +```bash +pip install -U flash-attn --no-build-isolation +``` + +Make also sure that you have a hardware that is compatible with Flash-Attention 2. Read more about it in the official documentation of flash-attn repository. Make also sure to load your model in half-precision (e.g. `torch.float16``) + +To load and run a model using Flash Attention 2, refer to the snippet below: + +```python +>>> import torch +>>> from transformers import OPTForCausalLM, GPT2Tokenizer +>>> device = "cuda" # the device to load the model onto + +>>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16, use_flash_attention_2=True) +>>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m") + +>>> prompt = ("A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the " + "Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived " + "there?") + +>>> model_inputs = tokenizer([prompt], return_tensors="pt").to(device) +>>> model.to(device) + +>>> generated_ids = model.generate(**model_inputs, max_new_tokens=30, do_sample=False) +>>> tokenizer.batch_decode(generated_ids)[0] +'A chat between a curious human and the Statue of Liberty.\n\nHuman: What is your name?\nStatue: I am the Statue of Liberty.\nHuman: Where do you live?\nStatue: New York City.\nHuman: How long have you lived there?\nStatue: I have lived here for about a year.\nHuman: What is your favorite place to eat?\nStatue: I love' +``` + +### Expected speedups + +Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `facebook/opt-2.7b` checkpoint and the Flash Attention 2 version of the model using two different sequence lengths. + +
+ +
+ +Below is an expected speedup diagram that compares pure inference time between the native implementation in transformers using `facebook/opt-350m` checkpoint and the Flash Attention 2 version of the model using two different sequence lengths. + +
+ +
+ + + ## OPTConfig [[autodoc]] OPTConfig diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 5b56ff9e618ddf..2192f327bc49f9 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -16,6 +16,7 @@ from typing import List, Optional, Tuple, Union import torch +import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -33,12 +34,18 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_flash_attn_2_available, logging, replace_return_docstrings, ) from .configuration_opt import OPTConfig +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "facebook/opt-350m" @@ -64,6 +71,19 @@ ] +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + class OPTLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -93,30 +113,49 @@ class OPTAttention(nn.Module): def __init__( self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, + config: OPTConfig, is_decoder: bool = False, - bias: bool = True, + **kwargs, ): super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads + self.config = config + + def _handle_deprecated_argument(config_arg_name, config, fn_arg_name, kwargs): + """ + If a the deprecated argument `fn_arg_name` is passed, raise a deprecation + warning and return that value, otherwise take the equivalent config.config_arg_name + """ + val = None + if fn_arg_name in kwargs: + logging.warning( + "Passing in {} to {self.__class__.__name__} is deprecated and won't be supported from v4.38." + " Please set it in the config instead" + ) + val = kwargs.pop(fn_arg_name) + else: + val = getattr(config, config_arg_name) + return val - if (self.head_dim * num_heads) != self.embed_dim: + self.embed_dim = _handle_deprecated_argument("hidden_size", config, "embed_dim", kwargs) + self.num_heads = _handle_deprecated_argument("num_attention_heads", config, "num_heads", kwargs) + self.dropout = _handle_deprecated_argument("attention_dropout", config, "dropout", kwargs) + self.enable_bias = _handle_deprecated_argument("enable_bias", config, "bias", kwargs) + + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.embed_dim: raise ValueError( f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" - f" and `num_heads`: {num_heads})." + f" and `num_heads`: {self.num_heads})." ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -242,17 +281,210 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +class OptFlashAttention2(OPTAttention): + """ + OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of flash + attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, _, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + class OPTDecoderLayer(nn.Module): def __init__(self, config: OPTConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = OPTAttention( - embed_dim=self.embed_dim, - num_heads=config.num_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - bias=config.enable_bias, - ) + + if not getattr(config, "_flash_attn_2_enabled", False): + self.self_attn = OPTAttention(config=config, is_decoder=True) + else: + self.self_attn = OptFlashAttention2(config=config, is_decoder=True) + self.do_layer_norm_before = config.do_layer_norm_before self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -368,6 +600,7 @@ class OPTPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["OPTDecoderLayer"] + _supports_flash_attn_2 = True def _init_weights(self, module): std = self.config.init_std @@ -581,16 +814,27 @@ def forward( mask_seq_length = past_key_values_length + seq_length # embed positions - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - elif attention_mask.shape[1] != mask_seq_length: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{mask_seq_length} (sum of the lengths of current and past inputs)" + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if attention_mask is None + else attention_mask ) - causal_attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + else: + # 4d mask is passed through the layers + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + elif attention_mask.shape[1] != mask_seq_length: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{mask_seq_length} (sum of the lengths of current and past inputs)" + ) + causal_attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + pos_embeds = self.embed_positions(attention_mask, past_key_values_length) if self.project_in is not None: