Do not drop mask with SDPA for more cases (#30311)
* overlooked

* style

* cleaner
fxmarty authored and Ita Zaporozhets committed May 14, 2024
1 parent 0fcf8b3 commit 157bf0a
8 changes: 6 additions & 2 deletions src/transformers/modeling_attn_mask_utils.py
@@ -319,8 +319,12 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     ignore_causal_mask = False
 
     if attention_mask is None:
-        if sliding_window is None or key_value_length < sliding_window:
-            ignore_causal_mask = not is_tracing
+        if (
+            not is_tracing
+            and (query_length == 1 or key_value_length == query_length)
+            and (sliding_window is None or key_value_length < sliding_window)
+        ):
+            ignore_causal_mask = True
     elif sliding_window is None or key_value_length < sliding_window:
         # 4d mask is passed through
         if len(attention_mask.shape) == 4:
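For context, the added condition means the causal mask is now dropped (letting SDPA handle causality itself) only when a single token is decoded against a cache, or when the query covers the full sequence as in a plain prefill; in every other case the explicit mask is kept. The following is a minimal sketch of that decision, not the library's code: can_ignore_causal_mask is a hypothetical standalone helper mirroring the condition in the diff.

from typing import Optional

import torch
import torch.nn.functional as F


def can_ignore_causal_mask(
    query_length: int,
    key_value_length: int,
    sliding_window: Optional[int],
    is_tracing: bool,
) -> bool:
    # Mirrors the condition added in the diff. With no explicit mask given,
    # the causal mask may be dropped only when:
    #   - query_length == 1: a single decoded token attends to every cached
    #     position, so no mask is needed at all, or
    #   - key_value_length == query_length: a plain prefill, where SDPA's
    #     is_causal=True reproduces the lower-triangular mask exactly.
    # Tracing and sliding-window attention whose cache exceeds the window
    # still require an explicit mask.
    return (
        not is_tracing
        and (query_length == 1 or key_value_length == query_length)
        and (sliding_window is None or key_value_length < sliding_window)
    )


# Prefill: queries and keys have the same length, so is_causal=True is safe.
q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
assert can_ignore_causal_mask(16, 16, None, False)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Chunked prefill against a cache: 4 new queries over 16 total positions.
# Here is_causal=True would align the causal triangle to the wrong corner,
# so the explicit 4D mask must be kept; this is the kind of case the commit
# guards against.
assert not can_ignore_causal_mask(4, 16, None, False)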
