Better SDPA unmasking implementation (#29318)
* better unmask imple

* comment

* typo

* bug report pytorch

* cleanup

* fix import

* add back example

* retrigger ci

* come on
fxmarty authored and Ita Zaporozhets committed May 14, 2024
1 parent 991e03c commit 7e107b0
Showing 5 changed files with 54 additions and 92 deletions.
69 changes: 13 additions & 56 deletions src/transformers/modeling_attn_mask_utils.py
@@ -187,7 +187,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]

@staticmethod
def _unmask_unattended(
- expanded_mask: torch.Tensor, attention_mask: torch.Tensor, unmasked_value: Union[bool, float]
+ expanded_mask: torch.FloatTensor,
+ min_dtype: float,
):
# fmt: off
"""
@@ -200,13 +201,7 @@ def _unmask_unattended(
The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
- For example, if `attention_mask` is
- ```
- [[0, 0, 1],
- [1, 1, 1],
- [0, 1, 1]]
- ```
- and `expanded_mask` is (e.g. here left-padding case)
+ For example, if `expanded_mask` is (e.g. here left-padding case)
```
[[[[0, 0, 0],
[0, 0, 0],
@@ -232,47 +227,12 @@ def _unmask_unattended(
```
"""
# fmt: on
- if expanded_mask.dtype == torch.bool:
- raise ValueError(
- "AttentionMaskConverter._unmask_unattended expects a float `expanded_mask`, got a BoolTensor."
- )
-
- # Get the index of the first non-zero value for every sample in the batch.
- # In the above example, indices = [[2], [0], [1]]]
- tmp = torch.arange(attention_mask.shape[1], 0, -1)
- indices = torch.argmax(attention_mask.cpu() * tmp, 1, keepdim=True)
-
- # Find the batch indexes that have unattended tokens on the leftmost side (e.g. [0, 0, 1, 1, 1]), for which the first rows of the
- # expanded mask will be completely unattended.
- left_masked_rows = torch.where(indices > 0)[0]
-
- if left_masked_rows.shape[0] == 0:
- return expanded_mask
- indices = indices[left_masked_rows]
-
- max_len = torch.max(indices)
- range_tensor = torch.arange(max_len).unsqueeze(0)
- range_tensor = range_tensor.repeat(indices.size(0), 1)
-
- # Avoid unmasking tokens at relevant target positions (on the row axis), by rather unmasking possibly several times the first row that should always be unmasked as we filtered out the batch above.
- range_tensor[range_tensor >= indices] = 0
-
- # TODO: we may drop support for 3D attention mask as the refactor from Patrick maybe dropped this case
- if expanded_mask.dim() == 4:
- num_masks = expanded_mask.shape[1]
- if num_masks == 1:
- # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
- mask_slice = (left_masked_rows[:, None], 0, range_tensor)
- else:
- # Broadcast [left_masked_rows, 1, 1], [1, num_masks, 1], [left_masked_rows, 1, max_len]
- mask_slice = (
- left_masked_rows[:, None, None],
- torch.arange(num_masks)[None, :, None],
- range_tensor[:, None, :],
- )
- else:
- # Broadcast [left_masked_rows, 1], [left_masked_rows, max_len]
- mask_slice = (left_masked_rows[:, None], range_tensor)
-
- expanded_mask[mask_slice] = unmasked_value
-
- return expanded_mask
+ return expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))


def _prepare_4d_causal_attention_mask(
@@ -406,15 +366,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
key_value_length=key_value_length,
)

- # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
- # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
- #
- # This fix is not applied in case we are tracing with torch.jit.trace or symbolic_trace, as _unmask_unattended has a data-dependent
- # controlflow that can not be captured properly.
- # TODO: _unmask_unattended does not work either with torch.compile when using fullgraph=True. We should find a way to detect this case.
- if query_length > 1 and not is_tracing:
+ # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ if not is_tracing and expanded_4d_mask.device.type == "cuda":
expanded_4d_mask = AttentionMaskConverter._unmask_unattended(
- expanded_4d_mask, attention_mask, unmasked_value=0.0
+ expanded_4d_mask, min_dtype=torch.finfo(inputs_embeds.dtype).min
)

return expanded_4d_mask
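For reference, a minimal sketch (toy shapes and values, not part of the commit) of what the new one-line `_unmask_unattended` does: any query row of the additive float mask that is entirely `min_dtype` (a fully unattended row, as produced by left padding) is multiplied by zero, so it attends everywhere and the memory-efficient SDPA kernel no longer returns NaNs for it:

```python
import torch

min_dtype = torch.finfo(torch.float32).min

# [batch=1, num_masks=1, tgt_len=3, src_len=3]; the first query row is fully masked
# (left-padding case), which is what triggers pytorch/pytorch#110213.
expanded_mask = torch.tensor(
    [[[[min_dtype, min_dtype, min_dtype],
       [min_dtype, 0.0,       min_dtype],
       [min_dtype, 0.0,       0.0]]]]
)

# Rows that are entirely `min_dtype` are multiplied by 0 (i.e. fully unmasked);
# every other row is left untouched.
unmasked = expanded_mask.mul(~torch.all(expanded_mask == min_dtype, dim=-1, keepdim=True))
print(unmasked[0, 0, 0])  # tensor([0., 0., 0.]) -> the padded row now attends everywhere
print(unmasked[0, 0, 1])  # unchanged: [min_dtype, 0., min_dtype]
```

Because `min_dtype` is finite rather than `-inf`, unmasking such a row is numerically harmless: its attention output is computed over all keys and corresponds to a pad position that is discarded downstream anyway.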
16 changes: 8 additions & 8 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -438,9 +438,9 @@ def forward(
else:
present = None

- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_layer.device.type == "cuda" and attention_mask is not None:
+ if self._use_sdpa and query_layer.device.type == "cuda" and attention_mask is not None:
+ # For torch<=2.1.2, SDPA with memory-efficient backend is bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
@@ -456,6 +456,7 @@ def forward(
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case query_length == 1.
is_causal=self.is_causal and attention_mask is None and query_length > 1,
)
+
attention_scores = None
else:
attention_scores = query_layer @ key_layer.transpose(-1, -2)
@@ -1112,18 +1113,17 @@ def forward(
if attention_mask_2d is None:
attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
else:
+ min_dtype = torch.finfo(alibi.dtype).min
attention_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
attention_mask < -1,
- torch.finfo(alibi.dtype).min,
+ min_dtype,
)

# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
- if seq_length > 1:
- attention_mask = AttentionMaskConverter._unmask_unattended(
- attention_mask, attention_mask_2d, unmasked_value=0.0
- )
+ if seq_length > 1 and attention_mask.device.type == "cuda":
+ attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
else:
# PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
attention_mask = _prepare_4d_causal_attention_mask(
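A rough, self-contained sketch of the updated alibi path above (toy shapes; `head_dim` stands in for `hidden_size // num_heads`; this is not Falcon's actual forward and it assumes a transformers version that contains this commit):

```python
import math
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

batch, num_heads, seq_len, head_dim = 2, 4, 4, 16
min_dtype = torch.finfo(torch.float32).min

# Additive 4D causal+padding mask: 0 where attended, min_dtype where masked.
# Sample 0 is left-padded with two tokens, so its first query rows end up fully masked.
attention_mask_2d = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]], dtype=torch.bool)
keep = torch.ones(seq_len, seq_len).tril().bool()[None, None] & attention_mask_2d[:, None, None, :]
mask_4d = torch.full((batch, 1, seq_len, seq_len), min_dtype).masked_fill(keep, 0.0)

# Fold the scaled alibi bias into the mask, re-masking padded/causal positions with min_dtype.
alibi = torch.randn(batch, num_heads, 1, seq_len).expand(batch, num_heads, seq_len, seq_len)
attention_mask = torch.masked_fill(alibi / math.sqrt(head_dim), mask_4d < -1, min_dtype)

# On CUDA, Falcon now calls the helper to reset fully masked query rows to zero so the
# memory-efficient SDPA kernel cannot produce NaNs for them.
attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
assert (attention_mask[0, 0, 0] == 0).all()
```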
11 changes: 8 additions & 3 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -27,6 +27,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
_prepare_4d_causal_attention_mask,
)
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
@@ -978,18 +979,22 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

- if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ ):
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
- # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
- causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask

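The `is_tracing` guard above exists because the unmasking relies on a data-dependent check. A small sketch of the combined condition (the helper name `_should_unmask` is made up here; the real model code inlines this logic):

```python
import torch

def _should_unmask(attention_mask: torch.Tensor, input_tensor: torch.Tensor) -> bool:
    # torch.jit.trace, torch.fx tracing and TorchDynamo cannot capture the data-dependent
    # torch.any(...) below, so the fix is applied in eager mode only.
    is_tracing = (
        torch.jit.is_tracing()
        or isinstance(input_tensor, torch.fx.Proxy)
        or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
    )
    # Only relevant on CUDA (where the memory-efficient kernel is used) and when at least
    # one position is actually padded.
    return (
        attention_mask.device.type == "cuda"
        and not is_tracing
        and bool(torch.any(attention_mask != 1))
    )
```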
39 changes: 17 additions & 22 deletions src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -30,6 +30,7 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
+ from ...pytorch_utils import is_torch_greater_or_equal_than_2_2
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
@@ -534,21 +535,16 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
key = key.unsqueeze(1)
value = value.unsqueeze(1)

- # Although these expand are not numerically useful, PyTorch 2.1 can not dispatch to memory-efficient backend
+ # Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
# and flash attention backend (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
- #
- # so we could do:
- #
- # key = key.expand(-1, self.num_heads, -1, -1)
- # value = value.expand(-1, self.num_heads, -1, -1)
- #
- # However SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # so we always dispatch to the math path: https://github.com/pytorch/pytorch/issues/112577.
- # Arguably we could still do expand + contiguous when `query.device.type == "cuda"` in order to dispatch on memory-efficient
- # backend, but it feels very hacky.
+ # torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
+ if is_torch_greater_or_equal_than_2_2:
+ key = key.expand(-1, self.num_heads, -1, -1)
+ value = value.expand(-1, self.num_heads, -1, -1)
+ else:
query_length = query_shape[-1]

@@ -1020,30 +1016,29 @@ def forward(
self_attention_mask = self_attention_mask.unsqueeze(2 if self.multi_query else 1)

if self._use_sdpa and head_mask is None and not output_attentions:
+ # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
+ dtype = self.wte.weight.dtype
+ min_dtype = torch.finfo(dtype).min
+ self_attention_mask = torch.where(
+ self_attention_mask,
+ torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
+ torch.full([], min_dtype, dtype=dtype, device=self_attention_mask.device),
+ )
+
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
if self.multi_query:
# gpt_bigcode using MQA has the bad taste to use a causal mask with shape
# [batch_size, target_length, 1, source_length], not compatible with SDPA, hence this transpose.
self_attention_mask = self_attention_mask.transpose(1, 2)

- if query_length > 1 and attention_mask is not None:
+ if query_length > 1 and attention_mask is not None and attention_mask.device.type == "cuda":
# From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
# produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
self_attention_mask = AttentionMaskConverter._unmask_unattended(
- self_attention_mask, attention_mask, unmasked_value=True
+ self_attention_mask, min_dtype=min_dtype
)

- # SDPA with a custom mask is much faster in fp16/fp32 dtype rather than bool. Cast here to floating point instead of at every layer.
- dtype = self.wte.weight.dtype
- self_attention_mask = torch.where(
- self_attention_mask,
- torch.full([], 0.0, dtype=dtype, device=self_attention_mask.device),
- torch.full(
- [], torch.finfo(self.wte.weight.dtype).min, dtype=dtype, device=self_attention_mask.device
- ),
- )
-
attention_mask = self_attention_mask

# If a 2D or 3D attention mask is provided for the cross-attention
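A standalone sketch (toy mask, not the GPTBigCode forward itself) of the bool-to-additive-float conversion that now happens once, before the unmasking step: `True` becomes `0.0` and `False` becomes the dtype minimum, which is the additive-mask convention SDPA expects:

```python
import torch

dtype = torch.float16
min_dtype = torch.finfo(dtype).min

# Bool mask: True = attend, False = masked ([batch, 1, query_length, key_length]).
self_attention_mask = torch.tensor([[[[True,  True,  False],
                                      [True,  True,  True],
                                      [False, False, False]]]])

self_attention_mask = torch.where(
    self_attention_mask,
    torch.full([], 0.0, dtype=dtype),
    torch.full([], min_dtype, dtype=dtype),
)

# The all-False row (the last one) is now entirely min_dtype; on CUDA the model then calls
# AttentionMaskConverter._unmask_unattended(self_attention_mask, min_dtype=min_dtype)
# to zero it out before handing the mask to SDPA.
```

The other half of the change, gated on `is_torch_greater_or_equal_than_2_2`, expands the singleton key/value head dimension so SDPA can dispatch to the memory-efficient or flash kernels; on older torch this was avoided because of the non-contiguous-input bug tracked in pytorch/pytorch#112577.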
11 changes: 8 additions & 3 deletions src/transformers/models/llama/modeling_llama.py
@@ -30,6 +30,7 @@

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
+ from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -1090,18 +1091,22 @@ def _update_causal_mask(self, attention_mask, input_tensor):
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

- if self.config._attn_implementation == "sdpa" and attention_mask is not None:
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ ):
# TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
is_tracing = (
torch.jit.is_tracing()
or isinstance(input_tensor, torch.fx.Proxy)
or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
)
if not is_tracing and torch.any(attention_mask != 1):
- # Attend to all tokens in masked rows from the causal_mask, for example the relevant first rows when
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
- causal_mask = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask

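A quick equivalence check (toy mask) between the inline expression removed here and the shared `AttentionMaskConverter._unmask_unattended` helper that Llama, like Gemma above, now calls; it assumes a transformers version that contains this commit:

```python
import torch
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# [batch=1, 1, tgt_len=3, src_len=3]; the first query row belongs to a left-pad token
# and is therefore fully masked.
causal_mask = torch.tensor([[[[min_dtype, min_dtype, min_dtype],
                              [min_dtype, 0.0,       min_dtype],
                              [min_dtype, 0.0,       0.0]]]])

# Old inline version (removed by this commit) vs. the new shared helper.
old = causal_mask.mul(~torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)).to(dtype)
new = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

assert torch.equal(old, new)  # only the fully masked first row is zeroed, in both cases
```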
