Fix #32390
ArthurZucker committed Aug 3, 2024
1 parent c1aa0ed commit e22d913
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -28,6 +28,7 @@

from ...activations import ACT2FN
from ...cache_utils import Cache, HybridCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -910,6 +911,17 @@ def _update_causal_mask(
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask


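For context (not part of the commit): the sketch below illustrates roughly what the added `AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)` call does to the additive SDPA mask. The standalone re-implementation of the unmasking step, as well as the shapes and values, are illustrative assumptions rather than the library code.

```python
# Conceptual sketch only (not the transformers implementation): shows the effect of
# "unmasking" fully masked rows in an additive attention mask, which the commit
# delegates to AttentionMaskConverter._unmask_unattended.
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# Additive causal mask for one sequence of length 3 with one left-padding token:
# row 0 belongs to the padding position and is fully masked (all min_dtype).
causal_mask = torch.tensor(
    [
        [min_dtype, min_dtype, min_dtype],  # fully masked row (left padding)
        [min_dtype, 0.0,       min_dtype],
        [min_dtype, 0.0,       0.0      ],
    ],
    dtype=dtype,
)[None, None]  # (batch, heads, q_len, kv_len)

# Query rows that are entirely min_dtype are zeroed out, i.e. those positions now
# attend to all keys. Their outputs are discarded anyway (they are padding), but this
# keeps F.scaled_dot_product_attention's memory-efficient kernel from producing NaNs
# for them (see pytorch/pytorch#110213, linked in the commit).
fully_masked_rows = torch.all(causal_mask == min_dtype, dim=-1, keepdim=True)
unmasked = causal_mask.mul(~fully_masked_rows)

print(unmasked[0, 0, 0])  # the previously fully masked row is now all zeros
```

The guard in the diff applies this only on the SDPA path, on CUDA, when a user-supplied attention mask is present and attentions are not being returned, i.e. exactly the case where the memory-efficient kernel can be selected.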
