Fixes support of sequential continuous features for sequential and non-sequential models #969
Conversation
Force-pushed from 2d13ceb to 2e1c2a7.
Thank you for the PR @gabrielspmoreira! All the changes sound good to me. I just left some minor comments.
@@ -371,7 +398,6 @@ class SequenceTargetAsInput(SequenceTransform):
        so that the tensors sequences can be processed
        """

    @tf.function
Do you know why we needed to decorate this call with @tf.function?
tests/unit/tf/blocks/test_mlp.py (Outdated)

predict_last = ml.SequencePredictLast(schema=schema.select_by_tag(Tags.SEQUENCE), target=target)

testing_utils.model_test(
    model, loader, run_eagerly=run_eagerly, reload_model=False, fit_kwargs={"pre": predict_last}
)
Can we test with reload_model=True to ensure the model can be re-loaded for new training/eval iterations and to check that the input signatures are correctly saved for serving?
Sure. I changed model_test() to reload_model=True and the test still passes.
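For reference, the updated call would look like the following (reconstructed from the snippet above; the surrounding test code is assumed unchanged):

testing_utils.model_test(
    # reload_model=True also saves and reloads the model, so the serving input
    # signature is exercised in addition to a new fit/evaluate iteration
    model, loader, run_eagerly=run_eagerly, reload_model=True, fit_kwargs={"pre": predict_last}
)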
Force-pushed from 83e200f to 95cc7a5.
rerun tests
Fixes #953, Closes #960
Goals ⚽
Add support for sequential continuous features in sequential models (e.g. Transformers, RNNs) and in non-sequential models (e.g. by averaging the sequential features and projecting them with an MLP).
Implementation Details 🚧
- SequenceAggregator now supports dict inputs and applies the combiner only to tensors whose rank is higher than the reduction axis (e.g. 3D tensors). This allows keeping 2D tensors (batch_size, 1) of continuous features as they are and applying the reduction only to sequential continuous features (batch_size, None, 1).
- Removed the SequenceAggregation enum, as it created difficulties for serialization/deserialization of custom objects. Now SequenceAggregator just takes a str as input.
Here is an example of how to average both sequential categorical and continuous features for a non-sequential model (e.g. a simple MLP):
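The original example code block did not survive extraction, so the following is only a hypothetical sketch. It assumes the merlin.models.tf API (imported as ml, as in the tests above), an existing schema and loader, a binary target column named "click", and that SequenceAggregator is applied as the post transform of InputBlockV2:

import merlin.models.tf as ml

# Hypothetical sketch (not the PR's original snippet).
# SequenceAggregator("mean") is applied as a post-transform on the dict of
# input branches: 3D sequential features are averaged over the sequence axis,
# 2D features pass through unchanged, and everything is then concatenated.
input_block = ml.InputBlockV2(
    schema,
    post=ml.SequenceAggregator("mean"),
)
model = ml.Model(
    input_block,
    ml.MLPBlock([64, 32]),
    ml.BinaryClassificationTask("click"),  # assumed binary target column
)
model.compile(optimizer="adam", run_eagerly=False)
model.fit(loader, epochs=1)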
P.S.: The aggregation (average) is applied only to 3D features, keeping the other 2D features unchanged (e.g. user- or session-level features), so that they can all be concatenated.
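To make the shape behaviour concrete, here is a small standalone TensorFlow illustration (not code from this PR) of a mean combiner that reduces only rank-3 tensors:

import tensorflow as tf

# Illustration only: mimic a "mean" combiner that reduces 3D (sequential)
# tensors over the sequence axis and leaves 2D tensors untouched.
features = {
    "user_age": tf.random.uniform((128, 1)),              # 2D: kept as-is
    "purchase_prices": tf.random.uniform((128, 20, 1)),   # 3D: averaged over axis 1
}
aggregated = {
    name: tf.reduce_mean(t, axis=1) if t.shape.rank == 3 else t
    for name, t in features.items()
}
# Both tensors are now (128, 1) and can be concatenated along the last axis.
concat = tf.concat(list(aggregated.values()), axis=-1)  # shape (128, 2)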
Other related fixes:
- Added an MLPBlock between InputBlockV2 and the Transformer blocks in the unit tests, to serve as an example for users, since the dim of the concatenated features usually does not match the input dim expected by Transformers (d_model); a minimal sketch of this pattern is included at the end of this description.
- Fixed compute_output_shape() of SequenceTransform-derived classes.
- Fixed ProcessList to return dense tensors when the column schema has value_count.max == value_count.min.
- Fixed maybe_deserialize_keras_objects() to expose the custom_objects argument.
- Fixed compute_output_shape() of EmbeddingTable to deal correctly with 3D input tensors (batch_size, seq_length, 1).
- Fixed the Continuous class to filter by default based on Tags.CONTINUOUS if no filter is provided.
Testing Details 🔍
- Added a unit test (test_mlp_model_with_sequential_features_and_combiner) to demonstrate and test how to use sequential continuous and categorical features in a non-sequential model by averaging them.
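As referenced in the fixes list above, here is a hypothetical minimal sketch of projecting the concatenated inputs to the Transformer's d_model with an MLPBlock. It assumes merlin.models.tf imported as ml, an XLNetBlock as the sequential block, and item_id as the prediction target; it is not the actual test code:

import merlin.models.tf as ml
from merlin.schema import Tags

# Hypothetical sketch: the concatenated input features rarely match the
# Transformer's expected input dim (d_model), so an MLPBlock projects them first.
d_model = 64
model = ml.Model(
    ml.InputBlockV2(schema.select_by_tag(Tags.SEQUENCE)),
    ml.MLPBlock([d_model]),  # project concatenated features to d_model
    ml.XLNetBlock(d_model=d_model, n_head=4, n_layer=2),
    ml.CategoricalOutput(schema.select_by_name("item_id")),  # assumed prediction target
)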