Skip to content

Commit

Permalink
Fix TF Roberta for mixed precision training (#11675)
Browse files Browse the repository at this point in the history
  • Loading branch information
jplu authored May 11, 2021
1 parent a135f59 commit d9b2862
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/models/roberta/modeling_tf_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,9 @@ def call(
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
extended_attention_mask = tf.multiply(tf.subtract(1.0, extended_attention_mask), -10000.0)
one_cst = tf.constant(1.0, dtype=embedding_output.dtype)
ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype)
extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst)

# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
Expand Down

0 comments on commit d9b2862

Please sign in to comment.