Skip to content

Commit

Permalink
Merge branch 'main' into tf-session-based-broadcast-features
Browse files Browse the repository at this point in the history
  • Loading branch information
marcromeyn authored Sep 13, 2022
2 parents 834ecb0 + 11617aa commit a3e1a0a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
35 changes: 29 additions & 6 deletions merlin/models/tf/prediction_tasks/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,31 +70,54 @@ def __init__(
else:
target_name = target if target else kwargs.pop("target_name", None)

logit = kwargs.pop("logit", None)
output_layer = kwargs.pop("output_layer", None)
super().__init__(
target_name=target_name,
task_name=task_name,
task_block=task_block,
**kwargs,
)
self.logit = logit or tf.keras.layers.Dense(1, name=self.child_name("logit"))
self.output_layer = output_layer or tf.keras.layers.Dense(
1, name=self.child_name("output_layer")
)
# To ensure that the output is always fp32, avoiding numerical
# instabilities with mixed_float16 policy
self.output_activation = tf.keras.layers.Activation(
"linear", dtype="float32", name="prediction"
)

def call(self, inputs, training=False, **kwargs):
return self.output_activation(self.logit(inputs))
def call(self, inputs: tf.Tensor, training=False, **kwargs) -> tf.Tensor:
"""Projects the input with the output layer to a single logit
Parameters
----------
inputs : tf.Tensor
Input tensor
training : bool, optional
Flag that indicates whether it is training or not, by default False
Returns
-------
tf.Tensor
Tensor with the regression logit
"""
return self.output_activation(self.output_layer(inputs))

def compute_output_shape(self, input_shape):
return self.output_layer.compute_output_shape(input_shape)

def get_config(self):
config = super().get_config()
config = maybe_serialize_keras_objects(self, config, {"logit": tf.keras.layers.serialize})
config = maybe_serialize_keras_objects(
self, config, {"output_layer": tf.keras.layers.serialize}
)

return config

@classmethod
def from_config(cls, config):
config = maybe_deserialize_keras_objects(config, ["logit"], tf.keras.layers.deserialize)
config = maybe_deserialize_keras_objects(
config, ["output_layer"], tf.keras.layers.deserialize
)

return super().from_config(config)
25 changes: 20 additions & 5 deletions tests/unit/tf/prediction_tasks/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,29 @@
#
import pytest

import merlin.models.tf as ml
import merlin.models.tf as mm
from merlin.io import Dataset
from merlin.models.tf.utils import testing_utils


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_regression_head(ecommerce_data: Dataset, run_eagerly):
body = ml.InputBlock(ecommerce_data.schema).connect(ml.MLPBlock([64]))
model = ml.Model(body, ml.RegressionTask("click"))
def test_regression_head(ecommerce_data: Dataset, run_eagerly: bool):
body = mm.InputBlock(ecommerce_data.schema).connect(mm.MLPBlock([64]))
model = mm.Model(body, mm.RegressionTask("click"))

testing_utils.model_test(model, ecommerce_data)
testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly)


@pytest.mark.parametrize("run_eagerly", [True, False])
def test_regression_head_schema(music_streaming_data: Dataset, run_eagerly: bool):
body = mm.InputBlock(music_streaming_data.schema).connect(mm.MLPBlock([64]))
model = mm.Model(body, mm.RegressionTask(music_streaming_data.schema))

testing_utils.model_test(model, music_streaming_data, run_eagerly=run_eagerly)


def test_regression_head_serialization(music_streaming_data: Dataset):
regression_task = mm.RegressionTask("click")
assert isinstance(
regression_task.from_config(regression_task.get_config()), type(regression_task)
)

0 comments on commit a3e1a0a

Please sign in to comment.