diff --git a/merlin/models/tf/prediction_tasks/regression.py b/merlin/models/tf/prediction_tasks/regression.py index eb0e6ac50a..326056dc5a 100644 --- a/merlin/models/tf/prediction_tasks/regression.py +++ b/merlin/models/tf/prediction_tasks/regression.py @@ -70,31 +70,54 @@ def __init__( else: target_name = target if target else kwargs.pop("target_name", None) - logit = kwargs.pop("logit", None) + output_layer = kwargs.pop("output_layer", None) super().__init__( target_name=target_name, task_name=task_name, task_block=task_block, **kwargs, ) - self.logit = logit or tf.keras.layers.Dense(1, name=self.child_name("logit")) + self.output_layer = output_layer or tf.keras.layers.Dense( + 1, name=self.child_name("output_layer") + ) # To ensure that the output is always fp32, avoiding numerical # instabilities with mixed_float16 policy self.output_activation = tf.keras.layers.Activation( "linear", dtype="float32", name="prediction" ) - def call(self, inputs, training=False, **kwargs): - return self.output_activation(self.logit(inputs)) + def call(self, inputs: tf.Tensor, training=False, **kwargs) -> tf.Tensor: + """Projects the input with the output layer to a single logit + + Parameters + ---------- + inputs : tf.Tensor + Input tensor + training : bool, optional + Flag that indicates whether it is training or not, by default False + + Returns + ------- + tf.Tensor + Tensor with the regression logit + """ + return self.output_activation(self.output_layer(inputs)) + + def compute_output_shape(self, input_shape): + return self.output_layer.compute_output_shape(input_shape) def get_config(self): config = super().get_config() - config = maybe_serialize_keras_objects(self, config, {"logit": tf.keras.layers.serialize}) + config = maybe_serialize_keras_objects( + self, config, {"output_layer": tf.keras.layers.serialize} + ) return config @classmethod def from_config(cls, config): - config = maybe_deserialize_keras_objects(config, ["logit"], tf.keras.layers.deserialize) + config = maybe_deserialize_keras_objects( + config, ["output_layer"], tf.keras.layers.deserialize + ) return super().from_config(config) diff --git a/tests/unit/tf/prediction_tasks/test_regression.py b/tests/unit/tf/prediction_tasks/test_regression.py index 930bcfada1..6bd5a34a5d 100644 --- a/tests/unit/tf/prediction_tasks/test_regression.py +++ b/tests/unit/tf/prediction_tasks/test_regression.py @@ -15,14 +15,29 @@ # import pytest -import merlin.models.tf as ml +import merlin.models.tf as mm from merlin.io import Dataset from merlin.models.tf.utils import testing_utils @pytest.mark.parametrize("run_eagerly", [True, False]) -def test_regression_head(ecommerce_data: Dataset, run_eagerly): - body = ml.InputBlock(ecommerce_data.schema).connect(ml.MLPBlock([64])) - model = ml.Model(body, ml.RegressionTask("click")) +def test_regression_head(ecommerce_data: Dataset, run_eagerly: bool): + body = mm.InputBlock(ecommerce_data.schema).connect(mm.MLPBlock([64])) + model = mm.Model(body, mm.RegressionTask("click")) - testing_utils.model_test(model, ecommerce_data) + testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly) + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_regression_head_schema(music_streaming_data: Dataset, run_eagerly: bool): + body = mm.InputBlock(music_streaming_data.schema).connect(mm.MLPBlock([64])) + model = mm.Model(body, mm.RegressionTask(music_streaming_data.schema)) + + testing_utils.model_test(model, music_streaming_data, run_eagerly=run_eagerly) + + +def test_regression_head_serialization(music_streaming_data: Dataset): + regression_task = mm.RegressionTask("click") + assert isinstance( + regression_task.from_config(regression_task.get_config()), type(regression_task) + )