🚨 FLAVA: Remove double softmax (#31322)
Remove double softmax
amyeroberts authored Jun 10, 2024
1 parent 8fff07d commit a4e1a1d
Showing 2 changed files with 1 addition and 3 deletions.
2 changes: 0 additions & 2 deletions src/transformers/models/flava/modeling_flava.py

@@ -472,8 +472,6 @@ def forward(
 
         # Normalize the attention scores to probabilities.
         attention_probs = nn.functional.softmax(attention_scores, dim=-1)
-        # Normalize the attention scores to probabilities.
-        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
 
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
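For context, chaining two softmax calls re-normalizes an already-normalized distribution, which flattens the attention weights toward uniform; removing the duplicated call restores the intended, sharper probabilities, and is why the expected multimodal embedding sum in the test update below shifts slightly. A minimal sketch of the effect, with illustrative scores not taken from the model:

    import torch
    import torch.nn.functional as F

    # Toy attention scores for one query over four keys (illustrative values only).
    attention_scores = torch.tensor([2.0, 1.0, 0.5, -1.0])

    # Single softmax: the intended attention probabilities (the behaviour after this commit).
    attention_probs = F.softmax(attention_scores, dim=-1)      # ~[0.61, 0.22, 0.14, 0.03]

    # Double softmax: the pre-fix behaviour, which re-normalizes the probabilities
    # and pushes the distribution toward uniform, diluting the attention.
    double_probs = F.softmax(attention_probs, dim=-1)          # ~[0.35, 0.24, 0.22, 0.20]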
2 changes: 1 addition & 1 deletion tests/models/flava/test_modeling_flava.py

@@ -1285,7 +1285,7 @@ def test_inference(self):
         # verify the embeddings
         self.assertAlmostEqual(outputs.image_embeddings.sum().item(), -1352.53540, places=4)
         self.assertAlmostEqual(outputs.text_embeddings.sum().item(), -198.98225, places=4)
-        self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.4602050, places=4)
+        self.assertAlmostEqual(outputs.multimodal_embeddings.sum().item(), -4030.4604492, places=4)
 
 
 @require_vision
