-
Notifications
You must be signed in to change notification settings - Fork 50
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New design of the transformer API #1022
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sararb
added
enhancement
New feature or request
area/tensorflow
P0
breaking
Breaking change
area/session-based
labels
Mar 16, 2023
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Documentation preview |
…hem inside configure_for_train() function.
gabrielspmoreira
approved these changes
Mar 21, 2023
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great to me @sararb .
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fixes #918
Fixes #1024
Fixes #1025
Fixes #1026
This work was necessary for the GTC tutorial 2023 and is fixing issues related to Causal LM and at the same time simplifying the high-level transformer API.
The new API for defining a transformer-based model is defined as follows:
Goals ⚽
This PR aims to address the following limitations in the current Transformer API:
ReplaceMaskedEmbeddings
,SequenceMaskLastInference
). So the user should be familiar with all these custom blocks to correctly define and train a transformer-based model with CLM or MLM approach.Implementation Details 🚧
masking_pre
andmasking_post
blocks using thepre
transform set in thefit()
method.tf.RaggedTensor
so that all the logic happening after the transformer block (MLP projections, softmax layer...) is applied only on actual positions.SequencePredictNext
andSequencePredictLast
masking_pre
used to train the transformer model.SequenceCausalLastInference
to generate a mask indicating the last non-padded position of each input sequence at inference time.ExtractMaskFromTargets
to move the logic defined inReplaceMaskedEmbeddings
and which consists of inferring the mask information based on targets.TransformerOutputToRagged
to convert the output of the transformer layer to Ragged based on the mask information.Testing Details 🔍
test_transformer_with_masked_language_modeling
andtest_transformer_with_causal_language_modeling
to account for the changes.