From 7e107b0348998440dcfbac58a4ef38d76216ba3a Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:36:47 +0100 Subject: [PATCH] Better SDPA unmasking implementation (#29318) * better unmask imple * comment * typo * bug report pytorch * cleanup * fix import * add back example * retrigger ci * come on --- src/transformers/modeling_attn_mask_utils.py | 69 ++++--------------- .../models/falcon/modeling_falcon.py | 16 ++--- .../models/gemma/modeling_gemma.py | 11 ++- .../gpt_bigcode/modeling_gpt_bigcode.py | 39 +++++------ .../models/llama/modeling_llama.py | 11 ++- 5 files changed, 54 insertions(+), 92 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 1a2c0db7bb140c..faae0d763f4e59 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -187,7 +187,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] @staticmethod def _unmask_unattended( - expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float] + expanded_mask: torch.FloatTensor, + min_dtype: float, ): # fmt: off """ @@ -200,13 +201,7 @@ def _unmask_unattended( The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias. - For example, if `attention_mask` is - ``` - [[0, 0, 1], - [1, 1, 1], - [0, 1, 1]] - ``` - and `expanded_mask` is (e.g. here left-padding case) + For example, if `expanded_mask` is (e.g. here left-padding case) ``` [[[[0, 0, 0], [0, 0, 0], @@ -232,47 +227,12 @@ def _unmask_unattended( ``` """ # fmt: on + if expanded_mask.dtype == torch.bool: + raise ValueError( + "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor." + ) - # Get the index of the first non-zero value for every sample in the batch. - # In the above example, indices = [[2], [0], [1]]] - tmp = torch.arange(attention_mask.shape[1], 0, -1) - indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True) - - # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the - # expanded mask will be completely unattended. - left_masked_rows = torch.where(indices > 0)[0] - - if left_masked_rows.shape[0] == 0: - return expanded_mask - indices = indices[left_masked_rows] - - max_len = torch.max(indices) - range_tensor = torch.arange(max_len).unsqueeze(0) - range_tensor = range_tensor.repeat(indices.size(0), 1) - - # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above. 
- range_tensor[range_tensor >= indices] = 0 - - # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case - if expanded_mask.dim() == 4: - num_masks = expanded_mask.shape[1] - if num_masks == 1: - # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] - mask_slice = (left_masked_rows[:, None], 0, range_tensor) - else: - # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len] - mask_slice = ( - left_masked_rows[:, None, None], - torch.arange(num_masks)[None, :, None], - range_tensor[:, None, :], - ) - else: - # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len] - mask_slice = (left_masked_rows[:, None], range_tensor) - - expanded_mask[mask_slice] = unmasked_value - - return expanded_mask + return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)) def _prepare_4d_causal_attention_mask( @@ -406,15 +366,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa( key_value_length=key_value_length, ) - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - # - # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent - # controlflow that can not be captured properly. - # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case. - if query_length > 1 and not is_tracing: + # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + if not is_tracing and expanded_4d_mask.device.type == "cuda": expanded_4d_mask = AttentionMaskConverter._unmask_unattended( - expanded_4d_mask, attention_mask, unmasked_value=0.0 + expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min ) return expanded_4d_mask diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 7ef857748ca813..2dde8d1cac67f6 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -438,9 +438,9 @@ def forward( else: present = None - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_layer.device.type == "cuda" and attention_mask is not None: + if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None: + # For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. query_layer = query_layer.contiguous() key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() @@ -456,6 +456,7 @@ def forward( # The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1. 
is_causal=self.is_causal and attention_mask is None and query_length > 1, ) + attention_scores = None else: attention_scores = query_layer @ key_layer.transpose(-1, -2) @@ -1112,18 +1113,17 @@ def forward( if attention_mask_2d is None: attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) else: + min_dtype = torch.finfo(alibi.dtype).min attention_mask = torch.masked_fill( alibi / math.sqrt(self.config.hidden_size // self.num_heads), attention_mask < -1, - torch.finfo(alibi.dtype).min, + min_dtype, ) # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if seq_length > 1: - attention_mask = AttentionMaskConverter._unmask_unattended( - attention_mask, attention_mask_2d, unmasked_value=0.0 - ) + if seq_length > 1 and attention_mask.device.type == "cuda": + attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype) else: # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. attention_mask = _prepare_4d_causal_attention_mask( diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 13265be8f3e1e9..e78ff54be865ea 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_causal_attention_mask, ) from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast @@ -978,7 +979,11 @@ def _update_causal_mask(self, attention_mask, input_tensor): padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype) - if self.config._attn_implementation == "sdpa" and attention_mask is not None: + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + ): # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). is_tracing = ( torch.jit.is_tracing() @@ -986,10 +991,10 @@ def _update_causal_mask(self, attention_mask, input_tensor): or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) ) if not is_tracing and torch.any(attention_mask != 1): - # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
                 # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
+                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
 
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 0b8a1bbb485517..2ef46eaa9f7322 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -30,6 +30,7 @@
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -534,21 +535,16 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
             key = key.unsqueeze(1)
             value = value.unsqueeze(1)
 
-            # Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to memory-efficient backend
+            # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
             # and flash attention backend (No available kernel. Aborting execution.) from the shapes
             # query = [batch_size, num_heads, query_length, head_dim]
             # key = [batch_size, 1, past_length, head_dim]
             # value = [batch_size, 1, past_length, head_dim]
             #
-            # so we could do:
-            #
-            #   key = key.expand(-1, self.num_heads, -1, -1)
-            #   value = value.expand(-1, self.num_heads, -1, -1)
-            #
-            # However SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
-            # so we always dispatch to the math path: https://github.com/pytorch/pytorch/issues/112577.
-            # Arguably we could still do expand + contiguous when `query.device.type == "cuda"` in order to dispatch on memory-efficient
-            # backend, but it feels very hacky.
+            # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
+            if is_torch_greater_or_equal_than_2_2:
+                key = key.expand(-1, self.num_heads, -1, -1)
+                value = value.expand(-1, self.num_heads, -1, -1)
         else:
             query_length = query_shape[-1]
 
@@ -1020,6 +1016,15 @@ def forward(
             self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)
 
             if self._use_sdpa and head_mask is None and not output_attentions:
+                # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
+                dtype = self.wte.weight.dtype
+                min_dtype = torch.finfo(dtype).min
+                self_attention_mask = torch.where(
+                    self_attention_mask,
+                    torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
+                    torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device),
+                )
+
                 # output_attentions=True can not be supported when using SDPA, and we fall back on
                 # the manual implementation that requires a 4D causal mask in all cases.
                 if self.multi_query:
@@ -1027,23 +1032,13 @@ def forward(
                     # [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
                     self_attention_mask = self_attention_mask.transpose(1, 2)
 
-                if query_length > 1 and attention_mask is not None:
+                if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda":
                     # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
                     # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
                     self_attention_mask = AttentionMaskConverter._unmask_unattended(
-                        self_attention_mask, attention_mask, unmasked_value=True
+                        self_attention_mask, min_dtype=min_dtype
                     )
 
-                # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
-                dtype = self.wte.weight.dtype
-                self_attention_mask = torch.where(
-                    self_attention_mask,
-                    torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
-                    torch.full(
-                        [], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device
-                    ),
-                )
-
                 attention_mask = self_attention_mask
 
                 # If a 2D or 3D attention mask is provided for the cross-attention
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 254310d2653977..4ea8a208a92315 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -30,6 +30,7 @@
 
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -1090,7 +1091,11 @@ def _update_causal_mask(self, attention_mask, input_tensor):
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
             causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
 
-        if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+        if (
+            self.config._attn_implementation == "sdpa"
+            and attention_mask is not None
+            and attention_mask.device.type == "cuda"
+        ):
             # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
             is_tracing = (
                 torch.jit.is_tracing()
@@ -1098,10 +1103,10 @@
                 or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
             )
             if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
                 # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                 # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
+                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
 
         return causal_mask
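Note on the new `_unmask_unattended` path (an editor's illustration, not part of the patch; the tensor values below are made up): the rewritten helper multiplies the additive float mask by the complement of a "row is entirely `min_dtype`" indicator, so rows that attend to nothing (the first query rows of a left-padded batch) become all zeros and attend everywhere, while partially masked rows are untouched. Unlike the previous index-based version, this uses no data-dependent control flow. A minimal sketch:

```
import torch

# Made-up example: a [batch=1, num_masks=1, 3, 3] additive mask for a left-padded sequence,
# where 0.0 means "attend" and min_dtype means "masked".
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
expanded_mask = torch.tensor(
    [[[[min_dtype, min_dtype, min_dtype],  # fully masked row (the case that NaNs in the mem-efficient kernel)
       [min_dtype, min_dtype, min_dtype],  # fully masked row
       [min_dtype, min_dtype, 0.0]]]],     # partially masked row, left as-is
    dtype=dtype,
)

# Same expression as the new AttentionMaskConverter._unmask_unattended:
# zero out (i.e. fully unmask) every row whose entries are all equal to min_dtype.
fully_masked_rows = torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True)  # [1, 1, 3, 1] bool
unmasked = expanded_mask.mul(~fully_masked_rows)

print(unmasked)  # rows 0 and 1 are now all zeros; row 2 still carries min_dtype on its masked positions
```

In the patch this path is only taken outside of tracing and on CUDA devices, where the memory-efficient SDPA backend that produces the NaNs (pytorch#110213) can be selected.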
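A second illustration (again not part of the patch, values made up): the GPTBigCode change converts the boolean mask to an additive floating-point mask once in `forward` rather than at every layer, mapping `True` ("attend") to 0.0 and `False` ("masked") to the dtype minimum, which is the format `_unmask_unattended` expects above:

```
import torch

dtype = torch.float16
min_dtype = torch.finfo(dtype).min

# Boolean mask: True = attend, False = masked (e.g. one padded position).
bool_mask = torch.tensor([[True, True, False]])

# One-time cast to an additive float mask, mirroring the torch.where(...) added in the forward pass
# (the real code also passes device=self_attention_mask.device).
float_mask = torch.where(
    bool_mask,
    torch.full([], 0.0, dtype=dtype),
    torch.full([], min_dtype, dtype=dtype),
)
print(float_mask)  # [[0., 0., -65504.]] in float16
```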