Skip to content

Commit

Permalink
[FA2] Fix flash attention 2 fine-tuning with Falcon (#26852)
Browse files Browse the repository at this point in the history
fix fa2 + dropout issue
  • Loading branch information
younesbelkada authored Oct 17, 2023
1 parent 4b423e6 commit 41c42f8
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def forward(
if alibi is not None:
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")

attn_dropout = self.attention_dropout if self.training else 0.0
attn_dropout = self.config.attention_dropout if self.training else 0.0

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
Expand Down
4 changes: 4 additions & 0 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2810,6 +2810,10 @@ def test_flash_attn_2_inference(self):

self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))

# check with inference + dropout
model.train()
_ = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
Expand Down

0 comments on commit 41c42f8

Please sign in to comment.