Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes support of sequential continuous features for sequential and non-sequential models #969

Merged
merged 6 commits into from
Feb 8, 2023

Conversation

gabrielspmoreira
Copy link
Member

@gabrielspmoreira gabrielspmoreira commented Feb 1, 2023

Fixes #953 , Closes #960

Goals ⚽

Add support to sequential continuous features on sequential models (e.g. Transformers, RNNs) and non-sequential models (e.g. averaging the sequential features and projecting with MLP).

Implementation Details 🚧

  • Fixed bugs when using concatenating continuous sequential columns with categorical sequential columns in sequential models (e.g. Transformers, RNNs)
  • Added support to average continuous sequential features by using SequenceAggregator, which nows support dict inputs and applied the combiner only on tensors whose shape is higher than the reduction axis (e.g. 3D tensors). This allows keeping 2D tensors (batch_size, 1) of continuous features as they are and applying the reduction only on sequential continuous features (batch_size, None, 1).
  • Removed SequenceAggregation enum, as it created difficulties for serialization/deserialization of custom objects. Now SequenceAgreggator just tasks str as input.

Here is an example on how to average both sequential categorical and continuous features for a non-sequential model (e.g. simple MLP):
P.s. The aggregation (average) will only apply for 3D features, keeping the other 2D features unchanged (e.g. user or session level features), so that they can be concatenated.

ml.InputBlockV2(
            schema,
            categorical=ml.Embeddings(
                schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner="mean"
            ),
            continuous=ml.Continuous(post=SequenceAggregator("mean")),
        )

Other related fixes:

  • Included MLPBlock between InputBlockV2 and Transformer blocks in the units tests to serve as example for the users that usually the concatenated features dim does not match the input dim of Transformers (d_model).
  • Fixed compute_output_shape() of SequenceTransform derived classes
  • Fixed ProcessList to return dense tensors when in column schema the value_count.max == value_count.min
  • Changed maybe_deserialize_keras_objects() to expose the custom_objects argument
  • Updated compute_output_shape() of EmbeddingTable to deal correctly with 3D input tensors (batch_size, seq_length, 1)
  • Changed Continuous class to filter by default based on Tags.CONTINUOUS, if the filter is not provided

Testing Details 🔍

  • Added test (test_mlp_model_with_sequential_features_and_combiner) to demonstrate and test how to use sequential continuous and categorical features in non-sequential model by averaging.
  • Changed Transformers tests to include not only categorical sequential features but also continuous ones.

@gabrielspmoreira gabrielspmoreira self-assigned this Feb 1, 2023
@gabrielspmoreira gabrielspmoreira added bug Something isn't working enhancement New feature or request labels Feb 1, 2023
@gabrielspmoreira gabrielspmoreira added this to the Merlin 23.02 milestone Feb 1, 2023
@github-actions
Copy link

github-actions bot commented Feb 1, 2023

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-969

@gabrielspmoreira gabrielspmoreira marked this pull request as draft February 1, 2023 02:26
@rnyak rnyak added the P0 label Feb 1, 2023
@gabrielspmoreira gabrielspmoreira force-pushed the tf/continuous_seq_feats_fix branch from 2d13ceb to 2e1c2a7 Compare February 2, 2023 17:07
@gabrielspmoreira gabrielspmoreira marked this pull request as ready for review February 2, 2023 17:08
merlin/models/tf/core/aggregation.py Outdated Show resolved Hide resolved
merlin/models/tf/inputs/continuous.py Show resolved Hide resolved
Copy link
Contributor

@sararb sararb left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the PR @gabrielspmoreira! All the changes sound good to me. I just left some minor comments.

@@ -371,7 +398,6 @@ class SequenceTargetAsInput(SequenceTransform):
so that the tensors sequences can be processed
"""

@tf.function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why we needed to decorate this call with @tf.function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point @sararb . I end up removing it during debugging and did not returned it. Using Git Blame,
I found that @edknv in this PR annoted with @tf.function some functions to fix some dataloader race conditions with list columns with TF >= 2.10. I gonna return the @tf.function annotation.

merlin/models/tf/transforms/sequence.py Outdated Show resolved Hide resolved
predict_last = ml.SequencePredictLast(schema=schema.select_by_tag(Tags.SEQUENCE), target=target)

testing_utils.model_test(
model, loader, run_eagerly=run_eagerly, reload_model=False, fit_kwargs={"pre": predict_last}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test with reload_model=True to ensure the model can be re-loaded for new training/eval iterations and to check that the input signatures are correctly saved for serving?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I changed model_test() to reload_model=True and the test still passes

@gabrielspmoreira gabrielspmoreira force-pushed the tf/continuous_seq_feats_fix branch from 83e200f to 95cc7a5 Compare February 6, 2023 15:41
@gabrielspmoreira
Copy link
Member Author

rerun tests

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working enhancement New feature or request P0
Projects
None yet
3 participants