Stop confusing the TF compiler with ModelOutput objects #28712
@@ -1060,7 +1060,8 @@ def call(
     labels = labels[:, 1:]
     labels = tf.reshape(labels, (-1,))
     # Keras won't give us label smoothing for sparse CE, so we de-sparsify things here
-    one_hot_labels = tf.one_hot(labels, depth=self.config.vocab_size, dtype=tf.float32)
+    # Use relu to clamp masked labels at 0 to avoid NaN (we will be zeroing those out later anyway)
+    one_hot_labels = tf.one_hot(tf.nn.relu(labels), depth=self.config.vocab_size, dtype=tf.float32)
Reviewer: Silly question, but won't this massively increase the loss, as it will set the model as predicting token 0 for all the cases when it's -100?

Author: Yes, but the masked positions are set to 0 by the line below. In other words, it doesn't matter what value the masked labels have, as long as it's nonnegative: any non-NaN value will be masked to 0, but NaNs can survive multiplication by zero and infest the rest of the code.

Reviewer: OK - works for me!
     loss_fct = tf.keras.losses.CategoricalCrossentropy(from_logits=True, label_smoothing=0.1, reduction="none")
     masked_positions = tf.cast(tf.not_equal(labels, -100), dtype=tf.float32)
     lm_loss = loss_fct(one_hot_labels, shifted_prediction_scores)
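The exchange above hinges on a floating-point detail: multiplying a loss by a 0/1 mask removes any finite value, but NaN times zero is still NaN. A minimal plain-Python sketch (toy labels and hypothetical per-token loss values, not the actual model code) of why clamping the -100 sentinel before the lookup matters:

```python
import math

# Toy labels: -100 marks masked (ignored) positions, following the usual
# Hugging Face loss convention used in the diff above.
labels = [5, -100, 2]

# Mask: 1.0 for real tokens, 0.0 for masked positions
# (the plain-Python analogue of tf.not_equal(labels, -100)).
mask = [0.0 if lab == -100 else 1.0 for lab in labels]

# Suppose the masked position produced a NaN loss upstream (hypothetical
# values): multiplying by the 0.0 mask does NOT remove it.
losses_with_nan = [0.7, float("nan"), 1.2]
masked = [loss * m for loss, m in zip(losses_with_nan, mask)]
print(any(math.isnan(x) for x in masked))  # True: nan * 0.0 is still nan

# The fix: clamp the sentinel to a valid index first (relu(-100) == 0), so
# the masked slot yields some ordinary finite loss, which the mask zeroes.
clamped = [max(0, lab) for lab in labels]          # [5, 0, 2]
losses_after_clamp = [0.7, 3.4, 1.2]               # finite dummy at the masked slot
masked_ok = [loss * m for loss, m in zip(losses_after_clamp, mask)]
print(masked_ok)  # [0.7, 0.0, 1.2]
```

This is why the author notes that any nonnegative stand-in value is fine: the mask guarantees the masked position contributes exactly zero, provided the value being masked is finite.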
Is it necessary to do this if we still have the correct indexing logic below?
Regrettably yes - this is the bit that actually fixes the bug! The change to the loss logic turned out to be unrelated - I just spotted a NaN risk there while I was hunting for the bug. Still, I think that change is worth keeping.
OK. As it's already done in BLIP, and it's a pattern we also use for Flax models, I think it's fine to do :)