diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index a3adf99d4cb8cc..dcba7b79f9194a 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -19,7 +19,7 @@
 # limitations under the License.
 """ PyTorch LLaMA model."""
 import math
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
@@ -323,6 +323,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        flash_kwargs: Optional[Dict] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()

@@ -478,6 +479,7 @@ def forward(
         output_attentions: bool = False,
         use_cache: bool = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        flash_kwargs: Optional[Dict] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         # LlamaFlashAttention attention does not support output_attentions
         output_attentions = False
@@ -519,9 +521,12 @@ def forward(
         # when training.
         dropout_rate = 0.0  # if not self.training else self.attn_dropout

-        # contains at least one padding token
-        if padding_mask is not None:
-            indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
+        # contains at least one masked token
+        if flash_kwargs["masking"]:
+            indices_k = flash_kwargs["indices_k"]
+            cu_seqlens_k = flash_kwargs["cu_seqlens_k"]
+            max_seqlen_in_batch_k = flash_kwargs["max_seqlen_in_batch_k"]
+
             key_states = index_first_axis(rearrange(key_states, "b s ... -> (b s) ..."), indices_k)
             value_states = index_first_axis(rearrange(value_states, "b s ... -> (b s) ..."), indices_k)

@@ -533,11 +538,9 @@ def forward(
                 indices_q = indices_k
             elif q_len == 1:
                 max_seqlen_in_batch_q = 1
-                cu_seqlens_q = torch.arange(
-                    bsz + 1, dtype=torch.int32, device=query_states.device
-                )  # There is a memcpy here, that is very bad.
-                indices_q = cu_seqlens_q[:-1]
-                query_states = query_states.squeeze(1)
+                cu_seqlens_q = flash_kwargs["cu_seqlens_q"]
+                indices_q = flash_kwargs["indices_q"]
+                query_states = query_states.squeeze(1)  # [batch_size, 1, num_heads, head_dim] -> [batch_size, num_heads, head_dim]
             else:
                 # The -q_len: slice assumes left padding.
                 padding_mask = padding_mask[:, -q_len:]
@@ -591,6 +594,7 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         padding_mask: Optional[torch.LongTensor] = None,
+        flash_kwargs: Optional[Dict] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@@ -619,6 +623,7 @@ def forward(
             output_attentions=output_attentions,
             use_cache=use_cache,
             padding_mask=padding_mask,
+            flash_kwargs=flash_kwargs,
         )
         hidden_states = residual + hidden_states

@@ -770,6 +775,7 @@ def __init__(self, config: LlamaConfig):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

         self.gradient_checkpointing = False
+        self._flash = getattr(config, "_flash_attn_2_enabled", False)
         # Initialize weights and apply final processing
         self.post_init()

@@ -864,9 +870,26 @@ def forward(
         else:
             padding_mask = None

-        attention_mask = self._prepare_decoder_attention_mask(
-            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        )
+
+        flash_kwargs = None
+        if not self._flash:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
+        else:
+            flash_kwargs = {}
+            flash_kwargs["masking"] = padding_mask is not None
+
+            if padding_mask is not None:
+                indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask)
+                flash_kwargs["indices_k"] = indices_k
+                flash_kwargs["cu_seqlens_k"] = cu_seqlens_k
+                flash_kwargs["max_seqlen_in_batch_k"] = max_seqlen_in_batch_k
+                if seq_length == 1:
+                    flash_kwargs["cu_seqlens_q"] = torch.arange(
+                        batch_size + 1, dtype=torch.int32, device=inputs_embeds.device
+                    )  # There is a memcpy here, that is very bad. At least it now happens only once per forward.
+                    flash_kwargs["indices_q"] = flash_kwargs["cu_seqlens_q"][:-1]

         hidden_states = inputs_embeds

@@ -909,6 +932,7 @@ def custom_forward(*inputs):
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                     padding_mask=padding_mask,
+                    flash_kwargs=flash_kwargs,
                 )

             hidden_states = layer_outputs[0]
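
For readers who want the gist of the change without tracing the diff: the unpadding metadata that flash attention needs (token indices, cumulative sequence lengths, max sequence length in the batch) used to be derived from the padding mask inside every attention layer; the patch computes it once in `LlamaModel.forward` and threads it to each layer through a `flash_kwargs` dict. The sketch below illustrates that pattern in isolation. It is a minimal approximation, not the library code: `get_unpad_data` is a local stand-in for the `_get_unpad_data` helper referenced in the diff, and the loop merely stands in for the stack of decoder layers.

```python
import torch
import torch.nn.functional as F


def get_unpad_data(padding_mask: torch.Tensor):
    """Local stand-in for modeling_llama._get_unpad_data, written for illustration only."""
    # padding_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max())
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch


batch_size, seq_length, num_layers = 2, 5, 3
padding_mask = torch.tensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])

# Computed once per forward pass (the diff does this in LlamaModel.forward):
flash_kwargs = {"masking": padding_mask is not None}
if padding_mask is not None:
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(padding_mask)
    flash_kwargs.update(
        indices_k=indices_k,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_in_batch_k=max_seqlen_in_batch_k,
    )
    if seq_length == 1:
        # Single-token decode: one query per sequence, so cu_seqlens_q is just 0..batch_size.
        flash_kwargs["cu_seqlens_q"] = torch.arange(batch_size + 1, dtype=torch.int32)
        flash_kwargs["indices_q"] = flash_kwargs["cu_seqlens_q"][:-1]

# Each decoder layer then reuses the shared values instead of recomputing them:
for _ in range(num_layers):
    if flash_kwargs["masking"]:
        indices_k = flash_kwargs["indices_k"]
        cu_seqlens_k = flash_kwargs["cu_seqlens_k"]
        max_seqlen_in_batch_k = flash_kwargs["max_seqlen_in_batch_k"]
        # ... gather/unpad key and value states with indices_k and call the attention kernel ...

print(flash_kwargs["cu_seqlens_k"])  # tensor([0, 3, 8], dtype=torch.int32)
```

The same reasoning applies to the single-token decode branch: the `torch.arange` that builds `cu_seqlens_q` (the source of the memcpy called out in the comments) now runs once per forward pass rather than once per layer.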