Making TF BART-like models XLA and AMP compliant #10191
Conversation
I succeeded in fixing Marian and Pegasus, and my first guess was the right one. I reworked how the embedding is created, and it now works with XLA_GPU. Of course, all the corresponding slow tests pass, and the weights are properly loaded.
(force-pushed from 0940def to dfb3d5e)
Great, LGTM!
```diff
- class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
+ class TFMarianSinusoidalPositionalEmbedding(tf.keras.layers.Layer):
```
Was this causing trouble for XLA?
Yes, more precisely XLA on GPU. With `tf.keras.layers.Embedding`, the weights are initialized on CPU while the model runs on GPU, and under XLA one device cannot access the other's memory. This happens because the embeddings are created in the `__init__` of the classes instead of in `build()`.
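For illustration, here is a minimal sketch of that pattern, assuming a standard sinusoidal table; the class and method names are hypothetical and simplified, not the PR's actual code. The key point is that the weight is created in `build()`, so its device placement is decided when the layer is first called rather than at construction time:

```python
import numpy as np
import tensorflow as tf


class SinusoidalPositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, num_positions: int, embedding_dim: int, **kwargs):
        super().__init__(**kwargs)
        self.num_positions = num_positions
        self.embedding_dim = embedding_dim  # assumed even for simplicity

    def build(self, input_shape):
        # Creating the weight here (instead of in __init__) means its device
        # placement is decided at first call, so under XLA_GPU the table lives
        # on the same device as the rest of the model.
        self.weight = self.add_weight(
            name="embeddings",
            shape=(self.num_positions, self.embedding_dim),
            trainable=False,
        )
        self.weight.assign(self._sinusoidal_table())
        super().build(input_shape)

    def _sinusoidal_table(self):
        # Standard transformer sin/cos position encoding table.
        position = np.arange(self.num_positions)[:, np.newaxis]
        div_term = np.exp(
            np.arange(0, self.embedding_dim, 2) * -(np.log(10000.0) / self.embedding_dim)
        )
        table = np.zeros((self.num_positions, self.embedding_dim), dtype=np.float32)
        table[:, 0::2] = np.sin(position * div_term)
        table[:, 1::2] = np.cos(position * div_term)
        return table

    def call(self, input_shape, past_key_values_length=0):
        # Return the position embeddings for the current sequence slice.
        seq_len = input_shape[1]
        positions = tf.range(past_key_values_length, past_key_values_length + seq_len)
        return tf.gather(self.weight, positions)
```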
```python
def test_saved_model_creation(self):
    # This test is too long (> 30s) and makes the CI fail
    pass
```
👍
Nice! Thanks a lot for fixing those!
Cool!
What does this PR do?
This PR makes the TF BART-like models compliant with AMP and XLA. The main issue for XLA was all the asserts: XLA does not support them (see the TF docs), so I had to disable them whenever the model is run in any mode other than eager.
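For illustration, a minimal sketch of that guard; the helper name is hypothetical, and only the `tf.executing_eagerly()` / `tf.debugging` pattern is the point. Since `tf.executing_eagerly()` returns False while a `tf.function` is being traced, the assert only fires in eager mode:

```python
import tensorflow as tf


def check_input_ids_within_bounds(input_ids, vocab_size):
    # Hypothetical helper: run the debugging assert only in eager mode,
    # since XLA-compiled graphs do not support these assert ops.
    if tf.executing_eagerly():
        tf.debugging.assert_less(
            input_ids,
            tf.cast(vocab_size, dtype=input_ids.dtype),
            message="input_ids must be strictly smaller than the vocabulary size",
        )
```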
TF Marian and Pegasus still have their XLA tests locked because they do not work with XLA_GPU. I need to investigate further to understand why. My first guess is that the TFXSinusoidalPositionalEmbedding class is the cause.
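For context, an XLA_GPU test of this kind essentially compiles the model's forward pass with XLA and runs it on GPU. A rough sketch follows; the checkpoint and invocation here are assumptions for illustration, and older TF versions spell `jit_compile` as `experimental_compile`:

```python
import tensorflow as tf
from transformers import MarianTokenizer, TFMarianMTModel

name = "Helsinki-NLP/opus-mt-en-de"  # example checkpoint, chosen for illustration
tokenizer = MarianTokenizer.from_pretrained(name)
model = TFMarianMTModel.from_pretrained(name)


@tf.function(jit_compile=True)  # force XLA compilation of the forward pass
def xla_forward(input_ids, decoder_input_ids):
    return model(input_ids, decoder_input_ids=decoder_input_ids).logits


inputs = tokenizer(["XLA test sentence."], return_tensors="tf")
# Feeding input_ids as decoder_input_ids is just to exercise the compiled graph.
logits = xla_forward(inputs["input_ids"], inputs["input_ids"])
```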