
XLA train step fixes #17973

Merged: 21 commits merged into main on Jul 1, 2022

Conversation

@Rocketknight1 (Member) commented Jun 30, 2022:

This PR makes a bunch of changes to the TF codebase to improve XLA support, in preparation for our upcoming big TF release. The goal is to allow users to use jit_compile on the vast majority of our models, which should yield large performance improvements for TF. In particular:

  • Rewrites to the train_step and test_step so that any mutable Python input dicts are not modified in the step. This was a bad idea anyway, but it causes particular problems with XLA, which, like JAX, is very functional and hates side effects.
  • Rewrites to the common hf_compute_loss functions to ensure that static shapes are maintained throughout, so that XLA compilation is possible.
  • Adds a test to ensure that we can still fit models when XLA compilation is used. XLA compilation is quite expensive, which makes this test quite slow, so it's restricted to core models for now and tagged as @slow.

Left to do:

  • Fix XLA-incompatible model-specific hf_compute_loss functions. On a quick search it looked like there were 4-5 of these, so it shouldn't take too long. Any use of tf.boolean_mask is a surefire sign that XLA compilation will break, because output shapes become data-dependent (see the sketch after this list).
  • See if there's a way to test non-core models for XLA fit support without crippling performance. (No, but we're using the XLA losses in non-XLA tests by default, so that partially tests it for all models.)
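A minimal sketch of the tf.boolean_mask problem described above (toy shapes and values, not code from this PR):

```python
import tensorflow as tf

labels = tf.constant([[2, 3, -100, -100]])  # -100 marks positions to ignore
active = labels != -100

# XLA-hostile: the output shape of tf.boolean_mask depends on how many
# labels are active, so it changes from batch to batch.
print(tf.boolean_mask(labels, active).shape)  # (2,) for this batch only

# XLA-friendly: keep every position and apply a 0/1 mask instead, so every
# tensor downstream keeps the same static shape regardless of the data.
mask = tf.cast(active, tf.float32)
print(mask.shape)  # always (1, 4)
```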

@HuggingFaceDocBuilderDev commented Jun 30, 2022:

The documentation is not available anymore as the PR was closed or merged.

@LysandreJik (Member):

I'd be interested in having @ydshieh's review as well

@LysandreJik requested a review from ydshieh on July 1, 2022 at 06:38
@ydshieh (Collaborator) left a comment:

@Rocketknight1 Great for more XLA compatibility!

Haven't checked the changes in test files yet, but left a few comments/questions in modeling_tf_utils.py

unmasked_loss = loss_fn(labels, logits)
loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
loss_denominator = tf.reduce_sum(loss_mask, axis=1)
# Masked positions will have a loss of NaN because -100 is not a valid label
@ydshieh (Collaborator) commented on the code above:

  • Is this documented somewhere in the TF docs? I don't remember this behavior.
  • I think putting this line below # make sure only labels that are not equal to -100 affect the loss would read better

@Rocketknight1 (Member, Author) commented Jul 1, 2022:

I don't know if it's documented, but any cross-entropy loss will yield NaN or -inf when the model's probability for the true label is 0 (because log(0) is undefined). Since labels < 0 do not correspond to a valid category, the loss in those cases will always be one of those values.
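To see concretely why the snippet below uses tf.math.multiply_no_nan rather than a plain multiply, a tiny self-contained demo (note that later in this thread the NaN-for-invalid-labels behaviour turns out to be GPU-specific):

```python
import tensorflow as tf

x = tf.constant([1.5, float("nan"), 2.0])
mask = tf.constant([1.0, 0.0, 1.0])

print((x * mask).numpy())                        # [1.5 nan 2. ]  plain multiply propagates NaN
print(tf.math.multiply_no_nan(x, mask).numpy())  # [1.5 0.  2. ]  a zero mask wins over NaN
```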

# Masked positions will have a loss of NaN because -100 is not a valid label
masked_loss = tf.math.multiply_no_nan(unmasked_loss, loss_mask)
reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
return reduced_masked_loss
@ydshieh (Collaborator):

I believe this is not equivalent to the previous computation.

  • previous: we compute the loss for each active token.

  • this PR: along the batch dimension (i.e. for each sequence), the loss is averaged (over the active tokens in that sequence).

    • return masked_loss should be fine (as we get 0 for inactive tokens)
    • (but I don't know why we don't return a scalar loss obtained by averaging over all active tokens - this is what is done in GPT2LMHeadModel, for example)

@gante (Member):

I was checking the docs now, and PT only outputs one number for the batch -- so it makes sense to include the sum here 👍

(we would have to update the TF docstring)

@ydshieh (Collaborator):

(BTW, PT outputs the average instead of the sum, as I see in GPT2LMHeadModel - CrossEntropyLoss has reduction='mean' by default.)

@Rocketknight1 (Member, Author):

It will be impossible to keep the old behaviour with XLA, because the number of 'active' tokens will change in each batch.

We could return a scalar number, but in Keras it's nice to return a vector of per-sample losses, because this means the user can use the sample_weight argument to fit() if they want to. I think that's fairly uncommon though, so if we want to stick with a scalar, that's fine!
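For context, a generic Keras sketch of the sample_weight mechanism being referred to (toy model and data, not transformers code): per-sample losses let fit() scale each sample's contribution.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

x = np.random.rand(8, 4).astype("float32")
y = np.random.randint(0, 2, size=(8,))
w = np.array([1.0, 1.0, 0.5, 0.5, 2.0, 2.0, 1.0, 1.0])  # per-sample weights

# Keras multiplies each sample's loss by its weight before reducing.
model.fit(x, y, sample_weight=w, epochs=1)
```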

@@ -251,7 +255,7 @@ class TFSequenceClassificationLoss:
     """

     def hf_compute_loss(self, labels, logits):
-        if len(shape_list(logits)) == 1 or shape_list(logits)[1] == 1:
+        if logits.shape.rank == 1 or logits.shape[1] == 1:
@ydshieh (Collaborator):

Is this change really necessary?

  • I have experienced a few times that .shape doesn't work in graph mode, and when that occurs, shape_list makes those tests pass.

  • I heard @gante mention some issues with shape_list + XLA, but I didn't check (and have completely forgotten what was wrong).

@gante (Member):

I also don't remember the exact issue, other than it often causes XLA compilation to fail :)

@Rocketknight1 (Member, Author):

Leaving tf.shape() here did cause XLA compilation to fail. I could use shape_list, but I think .shape and .shape.rank are fine!

@ydshieh (Collaborator):

I didn't see tf.shape used in the previous version.

Regarding .shape vs. shape_list, I am OK as long as things work. I am just confused that most of the time (in other places) I see .shape fail while shape_list works (with graph mode / symbolic tensors, etc.).

@Rocketknight1 (Member, Author):

Ah, sorry, I should explain! shape_list uses a combination of .shape and tf.shape() to build the list. I think using tf.shape() here confuses XLA, because it looks like a conditional that depends on the specific data you input, and those are forbidden.

I'm not 100% sure of the exact rules it uses, but all I can tell you is that it failed before and it works like this!
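A rough sketch of the pattern described here (close to, but not necessarily identical to, the transformers helper): use the static dimension wherever it is known at trace time, and only fall back to the dynamic tf.shape() for the rest.

```python
import tensorflow as tf

def shape_list_sketch(tensor):
    # Static dims are plain Python ints; dims unknown in graph mode come back
    # as None, and only those are replaced by dynamic tf.shape() values.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]
```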

# Masked positions will have a loss of NaN because -100 and -1 are not valid labels
masked_loss = tf.math.multiply_no_nan(unmasked_loss, loss_mask)
reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
return reduced_masked_loss
@ydshieh (Collaborator):

Same question as for the averaging along dim 1 in TFCausalLanguageModelingLoss above.

if isinstance(x, dict):
    x = x.copy()
if isinstance(y, dict):
    y = y.copy()
@ydshieh (Collaborator):

nice~

@gante (Member) left a comment:

XLA <3


# Just zero out samples where label is -100, no reduction
masked_ns_loss = tf.math.multiply_no_nan(unmasked_ns_loss, ns_loss_mask)

return masked_ns_loss
@gante (Member):

Should this one be reduced as well?

@Rocketknight1 (Member, Author):

ns_loss is calculated per sample rather than per position, so masked_ns_loss is a vector of shape (num_samples,). We could reduce that to a scalar if we want, though!

@sgugger (Collaborator) left a comment:

If I understand correctly, this changes the loss returned by TensorFlow models from a matrix to a vector, which is obviously breaking. While I don't know how many TensorFlow users rely on the current structure of the loss, we at least need to have a flag (probably use_xla=False) to enable the previous behavior for users who relied on it.

Could you confirm first that my understanding is correct?

@ydshieh (Collaborator) commented Jul 1, 2022:

> If I understand correctly, this changes the loss returned by TensorFlow models from a matrix to a vector, which is obviously breaking. While I don't know how many TensorFlow users rely on the current structure of the loss, we at least need to have a flag (probably use_xla=False) to enable the previous behavior for users who relied on it.
>
> Could you confirm first that my understanding is correct?

I believe both the previous and current versions return a vector. The difference is the size:

  • previous: number of active tokens (non-padding tokens)
  • now: batch size

@@ -236,6 +236,9 @@ class PretrainedConfig(PushToHubMixin):

     use_bfloat16 (`bool`, *optional*, defaults to `False`):
         Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
+    tf_legacy_loss (`bool`, *optional*, defaults to `False`):
+        Whether or not the model should use legacy TensorFlow losses. Legacy losses have variable output
+        shapes and may not be XLA-compatible.
@sgugger (Collaborator):

Maybe add that this parameter is here for backward compatibility but will be removed in v5?

@Rocketknight1 (Member, Author):

Done!

@@ -195,11 +195,21 @@ def hf_compute_loss(self, labels, logits):
     loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
         from_logits=True, reduction=tf.keras.losses.Reduction.NONE
     )
+    if self.config.get("tf_legacy_loss", False):
@sgugger (Collaborator):

Configs are not dicts; this will fail.

Suggested change:

-    if self.config.get("tf_legacy_loss", False):
+    if self.config.tf_legacy_loss:

(Same below.)

@Rocketknight1 (Member, Author):

Ah, sorry, I meant to use getattr, but I guess the key will always be present anyway?

@sgugger (Collaborator):

Since you set it in the base class, yes :-)

@sgugger (Collaborator) left a comment:

Thanks for working on this!

@Rocketknight1 (Member, Author):

@ydshieh I was completely wrong earlier - SparseCategoricalCrossentropy only returns nan for invalid labels when running on GPU! On CPU, inputs are validated and TensorFlow throws an error.

I'll rewrite my loss functions to not depend on that behaviour, and change the loss computation tests to mask some positions to ensure that gets tested, so I don't miss anything like this in future.
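A minimal sketch of the safer pattern (inferred from the fix described here, not copied from the PR): clamp invalid labels to a valid value before calling the loss function, and rely purely on the mask.

```python
import tensorflow as tf

def masked_loss(labels, logits, loss_fn):
    # Replace -100 with a valid label (0) so the loss function never sees an
    # invalid input - CPU kernels validate labels and would raise an error.
    active = labels != -100
    safe_labels = tf.where(active, labels, tf.zeros_like(labels))
    per_token_loss = loss_fn(safe_labels, logits)
    # The mask alone removes the contribution of the clamped positions.
    return per_token_loss * tf.cast(active, per_token_loss.dtype)
```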

@Rocketknight1 Rocketknight1 merged commit d6cec45 into main Jul 1, 2022
@Rocketknight1 Rocketknight1 deleted the xla_train_step_fixes branch July 1, 2022 18:11
@ydshieh (Collaborator) commented Jul 4, 2022:

@Rocketknight1

masked_loss = unmasked_loss * loss_mask
reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
return reduced_masked_loss

As in an earlier comment, I think this loss value is incorrect. Imagine we have 2 sequences of length 100:

  • 1st sentence: 1 active token + 99 pad tokens (somewhat nonsensical 😄)
  • 2nd sentence: 20 active tokens + 80 pad tokens

In this latest version, the single token in sentence 1 gets a weight (when computing the loss) 20 times larger than each token in the 2nd sentence, because we first average the loss along the sequence dimension.

Furthermore, this doesn't correspond to PyTorch's computation, which leads to test failures (I didn't check in detail if this is the cause, but I believe it is).

Q: Is there any reason we don't want to sum each token's loss value?

cc @gante @patrickvonplaten @sgugger
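To make the weighting point concrete, a toy version of this example (shorter sequences and made-up per-token loss values):

```python
import tensorflow as tf

# Two sequences of length 5; 0.0 marks padding positions.
# Sequence 1 has 1 active token, sequence 2 has 4.
masked_loss = tf.constant([[2.0, 0.0, 0.0, 0.0, 0.0],
                           [1.0, 1.0, 1.0, 1.0, 0.0]])
active = tf.constant([[1.0, 0.0, 0.0, 0.0, 0.0],
                      [1.0, 1.0, 1.0, 1.0, 0.0]])

# This PR: average each sequence first - the single token in sequence 1
# ends up weighted as heavily as all four tokens of sequence 2 combined.
per_seq = tf.reduce_sum(masked_loss, axis=1) / tf.reduce_sum(active, axis=1)
print(per_seq.numpy())                  # [2. 1.]
print(tf.reduce_mean(per_seq).numpy())  # 1.5

# PyTorch-style: average over all active tokens - every token weighs equally.
print((tf.reduce_sum(masked_loss) / tf.reduce_sum(active)).numpy())  # 1.2
```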

@Rocketknight1 (Member, Author):

Hi @ydshieh, I'm sorry, I think you're right there! Let me investigate and see if I can make a PR to weight tokens properly, which should hopefully resolve the issue.

@patrickvonplaten (Contributor) left a comment:

Big +1 on @ydshieh's comment here. I think the average doesn't correctly consider different padding lengths

@Rocketknight1 (Member, Author):

@patrickvonplaten Agreed! I fixed that in #18013

@@ -236,6 +236,10 @@ class PretrainedConfig(PushToHubMixin):

     use_bfloat16 (`bool`, *optional*, defaults to `False`):
         Whether or not the model should use BFloat16 scalars (only used by some TensorFlow models).
+    tf_legacy_loss (`bool`, *optional*, defaults to `False`):
@patrickvonplaten (Contributor):

Are we sure we want to set the default to False? This is breaking, no? Also, it's a somewhat hard-to-discover silent error in this case, no?

@Rocketknight1 (Member, Author) commented Jul 4, 2022:

It's only very slightly breaking - anyone using Keras or a custom model will not notice any change. The existing losses return very strange shapes like vectors of shape (num_unmasked_tokens,) that vary in length each iteration, with no mapping from there back to the original tokens. I doubt anyone is using them directly, without computing tf.reduce_mean() on them.

@ydshieh mentioned this pull request on Jul 5, 2022.
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
* Copy inputs to train and test step before modifying them, as this breaks things

* Add XLA tests, fix our loss functions to be XLA-compatible

* make fixup

* Update loss computation test to expect vector of per-sample losses

* Patch loss for TFLED

* Patch loss for TFAlbert

* Add a tf_legacy_loss config flag that enables old loss functions

* Stop using config.get() because it's not a dict

* Skip loss computation test for RAG because its loss is very strange and I'm afraid to rewrite it

* make fixup

* Add XLA-compatible RAG loss

* Fix dtype of loss mask for TFAlbert

* Fix test for XLNet too because it overrides the default one

* make fixup

* Fix config test

* No more depending on GPU NaN behaviour

* Add test, avoid potential zero division

* Fix test item assignment

* Fix loss computation masking test

* make fixup

* Fix dtype bugs