1 file changed (+5, −3 lines)
src/transformers/integrations — 1 file changed (+5, −3 lines)
@@ -58,16 +58,18 @@ def flash_attention_forward(
5858 else :
5959 target_dtype = next (layer for layer in module .modules () if isinstance (layer , torch .nn .Linear )).weight .dtype
6060
61- # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
62- kwargs .pop ("is_causal" , None )
61+ # Instead of relying on the value set in the module directly, we use the is_causal passed in kwargs if it is present
62+ is_causal = kwargs .pop ("is_causal" , None )
63+ if is_causal is None :
64+ is_causal = module .is_causal
6365
6466 attn_output = _flash_attention_forward (
6567 query ,
6668 key ,
6769 value ,
6870 attention_mask ,
6971 query_length = seq_len ,
70- is_causal = module . is_causal ,
72+ is_causal = is_causal ,
7173 dropout = dropout ,
7274 softmax_scale = scaling ,
7375 sliding_window = sliding_window ,
You can’t perform that action at this time.
0 commit comments