Commit 0ce24f5

Fix Causality Handling in Flash Attention to Support Bidirectional Attention (#39707)
Fix the is_causal logic to enable bidirectional attention.

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
1 parent: 83dbebc

File tree

1 file changed: +5 additions, -3 deletions

src/transformers/integrations/flash_attention.py

Lines changed: 5 additions & 3 deletions
@@ -58,16 +58,18 @@ def flash_attention_forward(
         else:
             target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype
 
-    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
-    kwargs.pop("is_causal", None)
+    # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is presented
+    is_causal = kwargs.pop("is_causal", None)
+    if is_causal is None:
+        is_causal = module.is_causal
 
     attn_output = _flash_attention_forward(
         query,
         key,
         value,
         attention_mask,
         query_length=seq_len,
-        is_causal=module.is_causal,
+        is_causal=is_causal,
         dropout=dropout,
         softmax_scale=scaling,
         sliding_window=sliding_window,
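
In effect, the patch makes an is_causal value passed through kwargs take precedence, with module.is_causal used only as a fallback, so callers can request bidirectional (non-causal) attention. Below is a minimal sketch of that resolution order, using the hypothetical names DummyModule and resolve_is_causal for illustration; in transformers the equivalent logic sits inline in flash_attention_forward rather than in a helper.

class DummyModule:
    is_causal = True  # e.g. a decoder-style self-attention layer

def resolve_is_causal(module, **kwargs):
    # Illustrative only: mirrors the resolution order introduced by this commit.
    # An explicit is_causal kwarg wins; otherwise fall back to the module's flag.
    # (Before this commit the kwarg was popped and discarded, so module.is_causal always won.)
    is_causal = kwargs.pop("is_causal", None)
    if is_causal is None:
        is_causal = module.is_causal
    return is_causal

module = DummyModule()
print(resolve_is_causal(module))                   # True  -> falls back to module.is_causal
print(resolve_is_causal(module, is_causal=False))  # False -> caller requests bidirectional attention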
