XLA train step fixes #17973
Conversation
The documentation is not available anymore as the PR was closed or merged.
I'd be interested in having @ydshieh's review as well.
@Rocketknight1 Great for more XLA compatibility!
Haven't checked the changes in test files yet, but left a few comments/questions in modeling_tf_utils.py
```python
unmasked_loss = loss_fn(labels, logits)
loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
loss_denominator = tf.reduce_sum(loss_mask, axis=1)
# Masked positions will have a loss of NaN because -100 is not a valid label
```
- Is this documented somewhere in the TF docs? I don't remember this behavior.
- I think putting this line below `# make sure only labels that are not equal to -100 affect the loss` would be better.
I don't know if it's documented, but any cross-entropy loss will yield `NaN` or `-inf` when the model's probability for the true label is 0 (because `log(0)` is undefined). Since labels < 0 do not correspond to a valid category, the loss in those cases will always be one of those values.
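For illustration (a standalone snippet, not code from the PR), `tf.math.multiply_no_nan` is what lets those positions be zeroed out even when their loss is NaN or inf:

```python
import tensorflow as tf

# multiply_no_nan(x, y) returns 0 wherever y is 0, even if x is NaN or inf
# at those positions, so masked positions cannot poison the loss.
unmasked_loss = tf.constant([2.3, float("nan"), 0.7, float("inf")])
loss_mask = tf.constant([1.0, 0.0, 1.0, 0.0])
print(tf.math.multiply_no_nan(unmasked_loss, loss_mask))  # [2.3, 0.0, 0.7, 0.0]
```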
```python
# Masked positions will have a loss of NaN because -100 is not a valid label
masked_loss = tf.math.multiply_no_nan(unmasked_loss, loss_mask)
reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
return reduced_masked_loss
```
I believe this is not equivalent to the previous computation.

- previous: we compute the loss for each actual token.
- this PR: along the batch dimension (i.e. for each sequence), the loss is averaged over the actual tokens in that sequence.
- Returning `masked_loss` should be fine (as we get `0` for inactive tokens), but I don't know why we don't return a scalar loss obtained by averaging over all active tokens - this is what is done in `GPT2LMHeadModel`, for example.
I was checking the docs now, and PT only outputs one number for the batch -- so it makes sense to include the sum here 👍
(we would have to update the TF docstring)
(BTW, PT outputs the average instead of the sum, as I see in `GPT2LMHeadModel` - `CrossEntropyLoss` has `reduction='mean'` by default.)
It will be impossible to keep the old behaviour with XLA, because the number of 'active' tokens will change in each batch.

We could return a scalar number, but in Keras it's nice to return a vector of per-sample losses, because this means the user can use the `sample_weight` argument to `fit()` if they want to. I think that's fairly uncommon though, so if we want to stick with a scalar, that's fine!
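To illustrate the `sample_weight` mechanism with plain Keras (a toy sketch, not a `transformers` model):

```python
import tensorflow as tf

# When the loss produces one value per sample, Keras multiplies each
# sample's loss by its weight before reducing to a scalar.
model = tf.keras.Sequential([tf.keras.layers.Dense(3)])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
x = tf.random.normal((8, 4))
y = tf.random.uniform((8,), maxval=3, dtype=tf.int32)
weights = tf.constant([1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.0, 0.0])  # per-sample weights
model.fit(x, y, sample_weight=weights, epochs=1, verbose=0)
```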
```diff
@@ -251,7 +255,7 @@ class TFSequenceClassificationLoss:
     """

     def hf_compute_loss(self, labels, logits):
-        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
+        if logits.shape.rank == 1 or logits.shape[1] == 1:
```
Is this change really necessary?

- I have experienced a few times that `shape` does not work in graph mode, and when that happens, `shape_list` makes those tests pass.
- I heard @gante mention some issues with `shape_list` + XLA, but I didn't check (and have completely forgotten what was wrong).
I also don't remember the exact issue, other than it often causes XLA compilation to fail :)
Leaving `tf.shape()` here did cause XLA compilation to fail. I could use `shape_list`, but I think `.shape` and `.shape.rank` are fine!
I didn't see `tf.shape` used in the previous version.

Regarding `.shape` vs. `shape_list`, I am OK as long as things work. I'm just confused because most of the time (in other places) I see `.shape` fail while `shape_list` works (with graph mode / symbolic tensors, etc.).
Ah, sorry, I should explain! `shape_list` uses a combination of `.shape` and `tf.shape()` to build the list. I think using `tf.shape()` here confuses XLA, because it looks like a conditional that depends on the specific data you input, and those are forbidden. I'm not 100% sure of the exact rules it uses, but all I can tell you is that it failed before and it works like this!
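For context, a simplified sketch of what `shape_list` does (the real helper in `modeling_tf_utils.py` handles a few more cases):

```python
import tensorflow as tf

def shape_list(tensor):
    # Use the static dimension where it is known, and fall back to a
    # dynamic tf.shape() op wherever the dimension is None.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]
```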
```python
# Masked positions will have a loss of NaN because -100 and -1 are not valid labels
masked_loss = tf.math.multiply_no_nan(unmasked_loss, loss_mask)
reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
return reduced_masked_loss
```
Same question as above regarding the averaging along dim 1 in `TFCausalLanguageModelingLoss`.
```python
if isinstance(x, dict):
    x = x.copy()
if isinstance(y, dict):
    y = y.copy()
```
nice~
XLA <3
```python
# Just zero out samples where label is -100, no reduction
masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)

return masked_ns_loss
```
Should this one be reduced as well?
`ns_loss` is calculated per sample rather than per position, so `masked_ns_loss` is a vector of shape `(num_samples,)`. We could reduce that to a scalar if we want, though!
If I understand correctly, this changes the loss returned by TensorFlow models from a matrix to a vector, which is obviously breaking. While I don't know how many TensorFlow users rely on the current structure of the loss, we at least need a flag (probably `use_xla=False`) to enable the previous behavior for users who relied on it. Could you confirm first that my understanding is correct?

I believe both the previous and the current versions return a vector; the difference is in the size.
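A toy sketch of the two sizes (hypothetical numbers, not code from the PR):

```python
import tensorflow as tf

labels = tf.constant([[1, -100, 2], [3, 4, -100]])  # batch of 2, seq len 3
per_token_loss = tf.ones((2, 3))
active = labels != -100

# Previous behavior: one value per unmasked token, shape (num_unmasked_tokens,),
# which varies from batch to batch.
legacy = tf.boolean_mask(per_token_loss, active)  # shape (4,) for this batch

# This PR: one value per sample, shape (batch_size,), always static.
mask = tf.cast(active, tf.float32)
per_sample = tf.reduce_sum(per_token_loss * mask, axis=1) / tf.reduce_sum(mask, axis=1)  # shape (2,)
```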
```diff
@@ -236,6 +236,9 @@ class PretrainedConfig(PushToHubMixin):

         use_bfloat16 (`bool`, *optional*, defaults to `False`):
             Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
+        tf_legacy_loss (`bool`, *optional*, defaults to `False`):
+            Whether or not the model should use legacy TensorFlow losses. Legacy losses have variable output
+            shapes and may not be XLA-compatible.
```
Maybe add that this parameter is here for backward compatibility but will be removed in v5?
Done!
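For reference, opting back into the old behavior would then look something like this (a sketch; `bert-base-uncased` is just a placeholder checkpoint):

```python
from transformers import AutoConfig, TFAutoModelForSequenceClassification

# Set the new flag on the config to restore the legacy (variable-shape) losses.
config = AutoConfig.from_pretrained("bert-base-uncased", tf_legacy_loss=True)
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", config=config)
```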
```diff
@@ -195,11 +195,21 @@ def hf_compute_loss(self, labels, logits):
         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
             from_logits=True, reduction=tf.keras.losses.Reduction.NONE
         )
+        if self.config.get("tf_legacy_loss", False):
```
Configs are not dicts; this will fail.

```diff
-        if self.config.get("tf_legacy_loss", False):
+        if self.config.tf_legacy_loss:
```

(Same below.)
Ah, sorry, I meant to use `getattr`, but I guess the key will always be present anyway?
Since you set it in the base class, yes :-)
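In other words (a minimal sketch, assuming the flag is set in the base `PretrainedConfig` as in this PR):

```python
from transformers import BertConfig

config = BertConfig()                     # any PretrainedConfig subclass
config.tf_legacy_loss                     # attribute access works (set in the base class)
getattr(config, "tf_legacy_loss", False)  # also works, with an explicit fallback
# config.get("tf_legacy_loss", False)    # AttributeError: configs are not dicts
```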
Thanks for working on this!
@ydshieh I was completely wrong earlier - I'll rewrite my loss functions to not depend on that behaviour, and change the loss computation tests to mask some positions to ensure that gets tested, so I don't miss anything like this in the future.
transformers/src/transformers/modeling_tf_utils.py, lines 211 to 213 in f17136c
As in an earlier comment, I think this loss value is incorrect. Imagine we have 2 sequences of length 100, where the 1st has a single active (non-masked) token and the 2nd has 20 active tokens.

In this latest version, the single token in the 1st sequence gets a weight (when computing the loss) 20 times larger than each token in the 2nd sequence, because you first average the loss along the sequence dimension. Furthermore, this doesn't correspond to PyTorch's computation, which leads to test failures (I didn't check in detail if this is the cause, but I believe it is). Q: Is there any reason we don't want to sum each token's loss value?
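To make the weighting concern concrete, a small numeric sketch (hypothetical loss values, not from the PR):

```python
import tensorflow as tf

# Two sequences of length 100: sequence 1 has a single active token
# (per-token loss 2.0), sequence 2 has 20 active tokens (loss 1.0 each).
unmasked_loss = tf.concat([tf.fill((1, 100), 2.0), tf.fill((1, 100), 1.0)], axis=0)
loss_mask = tf.concat(
    [
        tf.concat([tf.ones((1, 1)), tf.zeros((1, 99))], axis=1),   # 1 active token
        tf.concat([tf.ones((1, 20)), tf.zeros((1, 80))], axis=1),  # 20 active tokens
    ],
    axis=0,
)
masked_loss = unmasked_loss * loss_mask

# Per-sequence averaging (this PR): the lone token in sequence 1 ends up
# with 20x the weight of each token in sequence 2.
per_sequence = tf.reduce_sum(masked_loss, axis=1) / tf.reduce_sum(loss_mask, axis=1)
print(tf.reduce_mean(per_sequence).numpy())  # 1.5

# Averaging over all active tokens (PyTorch's reduction='mean'):
print((tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)).numpy())  # ~1.048
```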
Hi @ydshieh, I'm sorry, I think you're right there! Let me investigate and see if I can make a PR to weight tokens properly, which should hopefully resolve the issue.
Big +1 on @ydshieh's comment here. I think the average doesn't correctly account for different padding lengths.
@patrickvonplaten Agreed! I fixed that in #18013
```diff
@@ -236,6 +236,10 @@ class PretrainedConfig(PushToHubMixin):

         use_bfloat16 (`bool`, *optional*, defaults to `False`):
             Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
+        tf_legacy_loss (`bool`, *optional*, defaults to `False`):
```
Are we sure we want to set the default to `False`? This is breaking, no? Also, it's a somewhat hard-to-discover silent error in this case, no?
It's only very slightly breaking - anyone using Keras or a custom model will not notice any change. The existing losses return very strange shapes like vectors of shape `(num_unmasked_tokens,)` that vary in length each iteration, with no mapping from there back to the original tokens. I doubt anyone is using them directly, without computing `tf.reduce_mean()` on them.
* Copy inputs to train and test step before modifying them, as this breaks things
* Add XLA tests, fix our loss functions to be XLA-compatible
* make fixup
* Update loss computation test to expect vector of per-sample losses
* Patch loss for TFLED
* Patch loss for TFAlbert
* Add a tf_legacy_loss config flag that enables old loss functions
* Stop using config.get() because it's not a dict
* Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it
* make fixup
* Add XLA-compatible RAG loss
* Fix dtype of loss mask for TFAlbert
* Fix test for XLNet too because it overrides the default one
* make fixup
* Fix config test
* No more depending on GPU NaN behaviour
* Add test, avoid potential zero division
* Fix test item assignment
* Fix loss computation masking test
* make fixup
* Fix dtype bugs
This PR makes a bunch of changes to the TF codebase to improve XLA support, in preparation for our upcoming big TF release. The goal is to allow users to use `jit_compile` on the vast majority of our models, which should yield large performance improvements for TF. In particular:

- Fixed `train_step` and `test_step` so that any mutable Python input dicts are not modified in the step. Modifying them was a bad idea anyway, but it causes particular problems with XLA, which is very functional and hates side effects, like JAX.
- Fixed the `hf_compute_loss` functions to ensure that static shapes are maintained throughout, so that XLA compilation is possible.
- XLA tests are limited to `core` models for now and tagged as `@slow`.

Left to do:

- Fix the remaining `hf_compute_loss` functions. On a quick search it looked like there were 4-5 of these, so it shouldn't take too long. Any use of `tf.boolean_mask` is a surefire sign that XLA compilation will break, because output shapes become data-dependent.
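As an illustration of that last point (a standalone sketch, not code from the PR):

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def masked_loss_bad(per_token_loss, labels):
    # Output length = number of active tokens: a data-dependent shape,
    # which XLA cannot compile.
    return tf.boolean_mask(per_token_loss, labels != -100)

@tf.function(jit_compile=True)
def masked_loss_ok(per_token_loss, labels):
    # Static shape: zero out masked positions instead of dropping them.
    mask = tf.cast(labels != -100, per_token_loss.dtype)
    return per_token_loss * mask

labels = tf.constant([[1, -100, 2], [3, 4, -100]])
losses = tf.ones((2, 3))
masked_loss_ok(losses, labels)     # compiles and runs fine
# masked_loss_bad(losses, labels)  # fails to compile under XLA
```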