Return scalar losses instead of per-sample means #18013
Conversation
LGTM, thank you @Rocketknight1!
- I didn't see a change to the RAG model (which you modified in the last PR, and we have some test failures there).
- Do you want to work on this in a separate PR?
- Could you run (some of) the currently failing related tests before merging?
@ydshieh Will investigate RAG and make sure tests pass before merging!
@ydshieh I reverted the RAG loss function to the pre-XLA version, so hopefully those tests pass now. All other tests are passing! Do you think there's anything else I'm missing before I merge?
LGTM!
Do we have some tests for TF<>PT loss equivalence? If not, would it be hard to set them up?
Nothing I can think of, as the tests pass now 🎉. Thank you!
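As a rough sketch (not an existing test in the repository), a TF<>PT loss-equivalence check could compare a masked global mean in TF against PyTorch's default cross-entropy reduction, assuming the usual -100 ignore index for padded label positions:

```python
import numpy as np
import tensorflow as tf
import torch

np.random.seed(0)
logits = np.random.randn(2, 5, 7).astype("float32")   # (batch, seq_len, vocab)
labels = np.random.randint(0, 7, size=(2, 5))
labels[1, 3:] = -100                                   # mark some positions as ignored

# PyTorch reference: cross_entropy averages over non-ignored tokens by default
pt_loss = torch.nn.functional.cross_entropy(
    torch.tensor(logits).reshape(-1, 7),
    torch.tensor(labels).reshape(-1),
    ignore_index=-100,
)

# TF version: per-token cross-entropy, masked and averaged over active tokens
tf_labels = tf.constant(labels)
mask = tf.cast(tf_labels != -100, tf.float32)
per_token = tf.keras.losses.sparse_categorical_crossentropy(
    tf.where(tf_labels == -100, tf.zeros_like(tf_labels), tf_labels),
    tf.constant(logits),
    from_logits=True,
)
tf_loss = tf.reduce_sum(per_token * mask) / tf.reduce_sum(mask)

np.testing.assert_allclose(pt_loss.numpy(), tf_loss.numpy(), rtol=1e-5)
```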
```diff
-    reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
-    return reduced_masked_loss
+    reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+    return tf.reshape(reduced_masked_loss, (1,))
```
@Rocketknight1 Sorry to bother you, but I didn't notice this part yesterday, and I'm now wondering why we need (1,) as the returned shape?
@ydshieh Keras expects the loss to have at least one dimension, so standard methods will often fail if the loss is a pure scalar (with shape None). Everything works fine if it has shape (1,) instead.
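As a toy illustration (not the actual modelling code, and with made-up tensor values), the reshape only turns the rank-0 scalar into a rank-1 tensor with one element:

```python
import tensorflow as tf

masked_loss = tf.constant([[0.5, 0.0, 1.5], [2.0, 0.0, 0.0]])
loss_mask = tf.constant([[1.0, 0.0, 1.0], [1.0, 0.0, 0.0]])

scalar_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)  # shape ()
keras_loss = tf.reshape(scalar_loss, (1,))                           # shape (1,)
print(scalar_loss.shape, keras_loss.shape)  # () (1,)
```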
OK, I respect Keras!
* Return scalar losses instead of per-sample means
* Make loss shape (1,) instead of scalar
* Allow scalar losses in test_loss_computation
* Allow scalar losses in test_loss_computation
* Allow scalar losses in test_loss_computation
* Remove XLA loss function for RAG
This updates the TF XLA-compatible losses to return scalars instead of per-sample means. As @ydshieh pointed out, per-sample means give too much weight to samples with fewer masked positions. The new approach should match PyTorch losses exactly (up to floating-point error).
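A toy sketch (not code from this PR) of why the per-sample mean over-weights samples with few unmasked tokens, compared with the single global mean used here:

```python
import tensorflow as tf

# Two samples: one with three active (unmasked) tokens, one with a single active token.
per_token_loss = tf.constant([[1.0, 1.0, 1.0], [4.0, 0.0, 0.0]])
mask = tf.constant([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])

# Old behaviour: mean per sample, then mean over the batch.
per_sample = tf.reduce_sum(per_token_loss * mask, axis=1) / tf.reduce_sum(mask, axis=1)
old_loss = tf.reduce_mean(per_sample)  # (1.0 + 4.0) / 2 = 2.5

# New behaviour: one global mean over all active tokens, matching PyTorch's
# default cross-entropy reduction.
new_loss = tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)  # 7.0 / 4 = 1.75

print(float(old_loss), float(new_loss))
```

The sample with a single active token contributes as much as the fully active sample under the old reduction, which is exactly the skew the global mean removes.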
TODO: