
Making TF BART-like models XLA and AMP compliant #10191

Merged: 16 commits into huggingface:master on Feb 17, 2021

Conversation

@jplu (Contributor) commented on Feb 15, 2021

What does this PR do?

This PR makes the TF BART-like models compliant with AMP and XLA. The main blocker for XLA was all the assert statements: XLA does not support them (see the TF documentation), so I had to disable them whenever the model runs in any mode other than eager.
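For illustration, a minimal sketch of that guard pattern (the helper name, arguments, and message below are hypothetical, not taken from the PR diff):

```python
import tensorflow as tf

# Hypothetical sketch: debug asserts are skipped outside eager mode,
# since XLA-compiled graphs do not support them.
def assert_embed_dim(inputs_embeds, expected_dim):
    if tf.executing_eagerly():
        tf.debugging.assert_equal(
            tf.shape(inputs_embeds)[-1],
            expected_dim,
            message="The last dimension of inputs_embeds must match the model dimension.",
        )
```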

The XLA tests for TF Marian and Pegasus are still disabled because those models do not work on XLA_GPU. I need to investigate further to understand why; my first guess is that the TFXSinusoidalPositionalEmbedding class is responsible.
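For context, a minimal usage sketch of how AMP and XLA can be exercised with these models (this assumes the current Keras mixed-precision API and a tf.function wrapper; none of it is from the PR itself):

```python
import tensorflow as tf
from transformers import TFBartForConditionalGeneration

# AMP: enable mixed float16 before instantiating the model.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

model = TFBartForConditionalGeneration.from_pretrained("facebook/bart-base")

# XLA: compile the forward pass with jit_compile=True.
@tf.function(jit_compile=True)
def forward(input_ids):
    return model(input_ids).logits

# "<s> Hello world </s>" in BART's vocabulary.
logits = forward(tf.constant([[0, 31414, 232, 2]]))
```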

@jplu (Contributor, Author) commented on Feb 15, 2021

I succeeded in fixing Marian and Pegasus, and my first guess was the right one. I reworked how the embedding is created, and it now works on XLA_GPU. Of course, all the corresponding slow tests pass, and the weights are loaded properly.

@LysandreJik (Member) left a comment:

Great, LGTM!

    - class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
    + class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
Member commented:

This was causing trouble to XLA?

@jplu (Contributor, Author) replied:

Yes, more precisely XLA on GPU. With tf.keras.layers.Embedding the weights are initialized on CPU while the model runs on GPU, and under XLA one device cannot access tensors on the other.

This happens because the embeddings were created in the __init__ of the classes instead of in build().
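A minimal sketch of the resulting pattern (an illustrative reimplementation assuming an even embedding dimension, not the exact code from the PR): creating the weight in build() lets Keras place it on whatever device the layer actually runs on, which is what XLA on GPU requires.

```python
import numpy as np
import tensorflow as tf

class SinusoidalPositionalEmbedding(tf.keras.layers.Layer):
    """Illustrative sketch: static sinusoidal position embeddings."""

    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.num_positions = num_positions
        self.embedding_dim = embedding_dim  # assumed even in this sketch

    def build(self, input_shape):
        # Creating the weight here instead of in __init__ lets Keras place it
        # on the device the layer runs on, avoiding the CPU/GPU split under XLA.
        self.weight = self.add_weight(
            name="embeddings",
            shape=[self.num_positions, self.embedding_dim],
            trainable=False,
        )
        self.weight.assign(self._init_weight(self.num_positions, self.embedding_dim))
        super().build(input_shape)

    @staticmethod
    def _init_weight(n_pos, dim):
        # Standard sinusoid table: sin in the first half, cos in the second.
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        table = np.zeros_like(position_enc, dtype=np.float32)
        table[:, : dim // 2] = np.sin(position_enc[:, 0::2])
        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        return table

    def call(self, position_ids):
        return tf.gather(self.weight, position_ids)
```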

Comment on lines +280 to +283:

    def test_saved_model_creation(self):
        # This test is too long (>30 sec) and makes the CI fail
        pass

Member commented:

👍

@sgugger (Collaborator) left a comment:

Nice! Thanks a lot for fixing those!

@patrickvonplaten (Contributor) left a comment:

Cool!

@jplu merged commit 83d803b into huggingface:master on Feb 17, 2021
@jplu deleted the tf-bart-xla-amp branch on Feb 17, 2021