
Commit

don't zero out the attention_mask when using sliding window with flash attention (#31670)

* don't zero out the attention_mask when using sliding window with flash attention

* chore: lint
winglian authored and ArthurZucker committed Jun 28, 2024
1 parent e3cb841 commit 7edc993
Showing 1 changed file with 4 additions and 1 deletion.
src/transformers/models/gemma2/modeling_gemma2.py (4 additions, 1 deletion)
@@ -602,6 +602,7 @@ def forward(
 class Gemma2DecoderLayer(nn.Module):
     def __init__(self, config: Gemma2Config, layer_idx: int):
         super().__init__()
+        self.config = config
         self.hidden_size = config.hidden_size

         self.self_attn = GEMMA2_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
@@ -625,7 +626,9 @@ def forward(
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        if self.is_sliding and attention_mask is not None:  # efficient SDPA and no padding
+        if (
+            self.config._attn_implementation != "flash_attention_2" and self.is_sliding and attention_mask is not None
+        ):  # efficient SDPA and no padding
             attention_mask = attention_mask * torch.tril(
                 torch.ones_like(attention_mask), diagonal=-self.sliding_window
             )
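Why the guard matters, as an illustrative sketch that is not part of the commit: with `flash_attention_2`, the `attention_mask` reaching the decoder layer is typically the 2D padding mask of shape `(batch, seq_len)` holding 1s for real tokens and 0s for padding, unlike the 4D mask used by the SDPA/eager paths. Multiplying that 2D mask by `torch.tril(..., diagonal=-sliding_window)` zeroes out essentially every entry, so real tokens look like padding; the added `_attn_implementation != "flash_attention_2"` check skips that step. The shapes and values below are toy assumptions chosen only to show the effect.

```python
import torch

# Toy reproduction of the masking step guarded by this commit.
# Shapes and values are illustrative assumptions, not taken from modeling_gemma2.py.
sliding_window = 4
batch, seq_len = 1, 8

# With flash_attention_2, the mask is usually a 2D padding mask: 1 = real token, 0 = padding.
padding_mask = torch.ones(batch, seq_len)

# The pre-fix code path multiplied the mask by a lower-triangular band.
band = torch.tril(torch.ones_like(padding_mask), diagonal=-sliding_window)
zeroed = padding_mask * band

print(zeroed)
# tensor([[0., 0., 0., 0., 0., 0., 0., 0.]])
# Every real token now looks like padding, i.e. the mask has been zeroed out.
# With the commit's extra check, the 2D padding mask is left intact and the
# sliding window is handled by the flash-attention path itself.
```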
