New design of the transformer API #1022

Merged — 14 commits, Mar 21, 2023
Conversation

@sararb (Contributor) commented Mar 16, 2023

Fixes #918
Fixes #1024
Fixes #1025
Fixes #1026

  • This work was needed for the 2023 GTC tutorial; it fixes issues related to causal language modeling (CLM) and, at the same time, simplifies the high-level transformer API.

  • The new API for building a transformer-based model is defined as follows:

    import merlin.models.tf as mm
    from merlin.models.tf import BertBlock

    # seq_schema and target are derived from the dataset schema (not shown)
    transformer_input_dim = 48
    transformer_block = BertBlock(d_model=transformer_input_dim, n_head=8, n_layer=2)
    model = mm.Model(
        mm.InputBlockV2(...),
        transformer_block,
        mm.CategoricalOutput(...),
    )
    # Masked language modeling: mask items at random and use them as targets;
    # passing the transformer block lets the transform configure its masking.
    seq_mask_random = mm.SequenceMaskRandom(
        schema=seq_schema, target=target, masking_prob=0.3, transformer=transformer_block
    )
    model.compile(...)
    model.fit(..., pre=seq_mask_random)
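
The same pre= mechanism drives evaluation. For a model trained with SequencePredictNext (CLM), evaluating on the last item only would look roughly like the sketch below; the exact SequencePredictLast signature is assumed here, not taken from this description:

    # Hedged sketch: evaluate a CLM-trained model on the last item of each
    # sequence. The SequencePredictLast arguments are assumed.
    seq_predict_last = mm.SequencePredictLast(schema=seq_schema, target=target)
    model.evaluate(..., pre=seq_predict_last)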

Goals ⚽

This PR aims to address the following limitations in the current Transformer API:

  • 1. Inference for CausalLM is not supported: the model returns scores for all positions in the input sequence, but at inference time we are only interested in the last position (a minimal sketch of that selection follows this list).
  • 2. CLM only supports the SequencePredictNext transform: we cannot train or evaluate a CLM model on the last item of the sequence only (i.e., using SequencePredictLast).
  • 3. The API is specialized and complex: masking and inference are supported via specialized Merlin blocks (ReplaceMaskedEmbeddings, SequenceMaskLastInference), so the user must be familiar with all of these custom blocks to correctly define and train a transformer-based model with the CLM or MLM approach.
  • 4. The output of the model is a padded dense tensor of scores: because the HuggingFace transformer layer requires dense input, we convert the ragged inputs to dense tensors (0-padding to the maximum length seen in the given batch) right before calling the HF layer. The ops that follow this block are then applied to the padded dense tensor, so we compute logit scores for all positions, even the padded ones, which can be costly (e.g., in the weight-tying multiplication between the hidden representations and all item embeddings).
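
To make limitation 1 concrete: selecting the prediction of the last non-padded position amounts to a batched gather. The snippet below is an illustrative TensorFlow sketch with made-up shapes, not the PR's actual code:

    import tensorflow as tf

    # Dense scores for every position of each sequence: [batch, max_len, n_items]
    scores = tf.random.uniform((2, 4, 1000))
    lengths = tf.constant([3, 2])  # number of non-padded positions per sequence
    # Keep only the score of the last non-padded position: [batch, n_items]
    last_scores = tf.gather(scores, lengths - 1, batch_dims=1)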

Implementation Details 🚧

  • Add inference support for a transformer-based model trained with CLM, i.e., select the prediction score of the last non-padded position.
  • Support evaluating a transformer-based model trained with CLM (i.e., trained with SequencePredictNext) on the last item of the sequence only (i.e., using SequencePredictLast).
  • Simplify the transformer API by abstracting away the definition of the masking_pre and masking_post blocks, using the pre transform set in the fit() method instead.
  • Optimize the prediction generation of a transformer-based model: convert the dense tensor returned by the transformer layer to a tf.RaggedTensor, so that all the logic after the transformer block (MLP projections, softmax layer, ...) is applied only to actual (non-padded) positions.
  • Add masking support for SequencePredictNext and SequencePredictLast.
  • Add checks to ensure the SequenceTransform used in evaluate() is aligned with the masking_pre used to train the transformer model.
  • Implement the following specialized blocks (needed for CLM and MLM support):
    • SequenceCausalLastInference, which generates a mask indicating the last non-padded position of each input sequence at inference time.
    • ExtractMaskFromTargets, which takes over the logic previously defined in ReplaceMaskedEmbeddings that infers the mask information from the targets.
    • TransformerOutputToRagged, which converts the output of the transformer layer to a ragged tensor based on the mask information (see the sketch after this list).
  • Update the example notebooks.
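
As a rough illustration of the kind of conversion TransformerOutputToRagged performs (a hand-written sketch, not the block's actual implementation):

    import tensorflow as tf

    # Padded dense output of the transformer layer: [batch, max_len, d_model]
    hidden = tf.random.uniform((2, 4, 8))
    # Boolean mask of non-padded positions (sequence lengths 3 and 2)
    mask = tf.sequence_mask([3, 2], maxlen=4)
    # Drop the padded positions so downstream ops (MLP projections, softmax)
    # only run on real positions; result shape: [2, None, 8]
    ragged = tf.ragged.boolean_mask(hidden, mask)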

Testing Details 🔍

  • Update test_transformer_with_masked_language_modeling and test_transformer_with_causal_language_modeling to account for the changes.
  • Update the

@sararb sararb added this to the Merlin 23.03 milestone Mar 16, 2023
@sararb sararb self-assigned this Mar 16, 2023
@github-actions commented

Documentation preview: https://nvidia-merlin.github.io/models/review/pr-1022
@rnyak rnyak requested review from gabrielspmoreira and rnyak March 20, 2023 15:34
@gabrielspmoreira (Member) left a comment

Looks great to me @sararb.