From 2cf4aff6c2d831a07a887be8bee945224b5a9a89 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 18 Jul 2022 09:29:19 +0200 Subject: [PATCH 01/27] First commit --- merlin/models/tf/predictions/__init__.py | 0 merlin/models/tf/predictions/base.py | 45 ++++++++++++++++++++++ merlin/models/tf/predictions/binary.py | 29 ++++++++++++++ merlin/models/tf/predictions/regression.py | 26 +++++++++++++ 4 files changed, 100 insertions(+) create mode 100644 merlin/models/tf/predictions/__init__.py create mode 100644 merlin/models/tf/predictions/base.py create mode 100644 merlin/models/tf/predictions/binary.py create mode 100644 merlin/models/tf/predictions/regression.py diff --git a/merlin/models/tf/predictions/__init__.py b/merlin/models/tf/predictions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py new file mode 100644 index 0000000000..0b1c1492cb --- /dev/null +++ b/merlin/models/tf/predictions/base.py @@ -0,0 +1,45 @@ +from tensorflow.keras.layers import Layer + +from merlin.models.tf.core.transformations import LogitsTemperatureScaler + + +class PredictionBlock(Layer): + def __init__( + self, + prediction, + default_loss, + default_metrics, + target=None, + pre=None, + post=None, + logits_temperature=1.0 + ): + self.prediction = prediction + self.default_loss = default_loss + self.default_metrics = default_metrics + self.target = target + self.pre = pre + self.post = post + self.logits_temperature = logits_temperature + if logits_temperature != 1.0: + self.logits_scaler = LogitsTemperatureScaler(logits_temperature) + + def call(self, inputs, context): + return self.prediction(inputs, context) + + def __call__(self, inputs, *args, **kwargs): + # call pre + if self.pre: + inputs = self.pre(inputs, *args, **kwargs) + + # super call + outputs = super().__call__(inputs, *args, **kwargs) + + if self.post: + outputs = self.post(outputs, *args, **kwargs) + + if getattr(self, 
"logits_scaler", None): + outputs = self.logits_scaler(outputs) + + return outputs + diff --git a/merlin/models/tf/predictions/binary.py b/merlin/models/tf/predictions/binary.py new file mode 100644 index 0000000000..d6a66b4a56 --- /dev/null +++ b/merlin/models/tf/predictions/binary.py @@ -0,0 +1,29 @@ +import tensorflow as tf + +from merlin.models.tf.predictions.base import PredictionBlock + + +class BinaryPrediction(PredictionBlock): + def __init__( + self, + default_loss="binary_crossentropy", + default_metrics=( + tf.keras.metrics.Precision, + tf.keras.metrics.Recall, + tf.keras.metrics.BinaryAccuracy, + tf.keras.metrics.AUC, + ), + target=None, + pre=None, + post=None, + logits_temperature=1.0 + ): + super().__init__( + prediction=tf.keras.layers.Dense(1, activation="sigmoid"), + default_loss=default_loss, + default_metrics=default_metrics, + target=target, + pre=pre, + post=post, + logits_temperature=logits_temperature + ) diff --git a/merlin/models/tf/predictions/regression.py b/merlin/models/tf/predictions/regression.py new file mode 100644 index 0000000000..dd47284b92 --- /dev/null +++ b/merlin/models/tf/predictions/regression.py @@ -0,0 +1,26 @@ +import tensorflow as tf + +from merlin.models.tf.predictions.base import PredictionBlock + + +class RegressionPrediction(PredictionBlock): + def __init__( + self, + default_loss="mse", + default_metrics=( + tf.keras.metrics.RootMeanSquaredError, + ), + target=None, + pre=None, + post=None, + logits_temperature=1.0 + ): + super().__init__( + prediction=tf.keras.layers.Dense(1, activation="linear"), + default_loss=default_loss, + default_metrics=default_metrics, + target=target, + pre=pre, + post=post, + logits_temperature=logits_temperature + ) From e0aed3f83e77e6afa42fa20162a90b2841b8d6cc Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Mon, 18 Jul 2022 17:45:06 +0200 Subject: [PATCH 02/27] Introducing PredictionBlock --- merlin/models/tf/__init__.py | 2 + merlin/models/tf/core/combinators.py | 12 +- 
merlin/models/tf/core/prediction.py | 4 + merlin/models/tf/models/base.py | 142 ++++++++++---- merlin/models/tf/predictions/base.py | 173 ++++++++++++++---- tests/unit/tf/predictions/__init__.py | 0 tests/unit/tf/predictions/test_base.py | 61 ++++++ .../tf/predictions/test_classification.py | 0 tests/unit/tf/predictions/test_regression.py | 0 9 files changed, 318 insertions(+), 76 deletions(-) create mode 100644 tests/unit/tf/predictions/__init__.py create mode 100644 tests/unit/tf/predictions/test_base.py create mode 100644 tests/unit/tf/predictions/test_classification.py create mode 100644 tests/unit/tf/predictions/test_regression.py diff --git a/merlin/models/tf/__init__.py b/merlin/models/tf/__init__.py index c0c04eb1dc..03382301b8 100644 --- a/merlin/models/tf/__init__.py +++ b/merlin/models/tf/__init__.py @@ -109,6 +109,7 @@ from merlin.models.tf.prediction_tasks.next_item import NextItemPredictionTask from merlin.models.tf.prediction_tasks.regression import RegressionTask from merlin.models.tf.prediction_tasks.retrieval import ItemRetrievalTask +from merlin.models.tf.predictions.base import PredictionBlock from merlin.models.tf.utils import repr_utils from merlin.models.tf.utils.tf_utils import TensorInitializer @@ -168,6 +169,7 @@ "DotProductInteraction", "FMPairwiseInteraction", "LabelToOneHot", + "PredictionBlock", "PredictionTask", "BinaryClassificationTask", "MultiClassClassificationTask", diff --git a/merlin/models/tf/core/combinators.py b/merlin/models/tf/core/combinators.py index 2e9b5f7acc..e973b4d275 100644 --- a/merlin/models/tf/core/combinators.py +++ b/merlin/models/tf/core/combinators.py @@ -351,6 +351,7 @@ def __init__( name: Optional[str] = None, strict: bool = False, automatic_pruning: bool = True, + use_layer_name: bool = True, **kwargs, ): super().__init__( @@ -370,10 +371,13 @@ def __init__( parsed_to_merge[key] = val self.parallel_layers = parsed_to_merge elif all(isinstance(x, tf.keras.layers.Layer) for x in inputs): - parsed: 
List[TabularBlock] = [] - for i, inp in enumerate(inputs): - parsed.append(inp) # type: ignore - self.parallel_layers = parsed + if use_layer_name: + self.parallel_layers = {layer.name: layer for layer in inputs} + else: + parsed: List[TabularBlock] = [] + for i, inp in enumerate(inputs): + parsed.append(inp) # type: ignore + self.parallel_layers = parsed else: raise ValueError( "Please provide one or multiple layer's to merge or " diff --git a/merlin/models/tf/core/prediction.py b/merlin/models/tf/core/prediction.py index c69f0d9670..79cd6bafad 100644 --- a/merlin/models/tf/core/prediction.py +++ b/merlin/models/tf/core/prediction.py @@ -56,3 +56,7 @@ class Prediction(NamedTuple): outputs: Dict[str, TensorLike] targets: Optional[Union[tf.Tensor, Dict[str, tf.Tensor]]] = None features: Optional[Dict[str, TensorLike]] = None + + @property + def predictions(self): + return self.outputs diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index bc525bfcd4..55cea5cd8f 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -23,8 +23,9 @@ from merlin.models.tf.metrics.topk import filter_topk_metrics from merlin.models.tf.models.utils import parse_prediction_tasks from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask +from merlin.models.tf.predictions.base import PredictionBlock from merlin.models.tf.utils.search_utils import find_all_instances_in_layers -from merlin.models.tf.utils.tf_utils import call_layer, maybe_deserialize_keras_objects +from merlin.models.tf.utils.tf_utils import call_layer, maybe_serialize_keras_objects from merlin.models.utils.dataset import unique_rows_by_features from merlin.schema import Schema, Tags @@ -288,37 +289,24 @@ def compile( initial_value=lambda: False, ) - self.output_names = [task.task_name for task in self.prediction_tasks] + num_v1_blocks = len(self.prediction_tasks) + num_v2_blocks = len(self.prediction_blocks) - _metrics = {} - if 
isinstance(metrics, (list, tuple)) and len(self.prediction_tasks) == 1: - _metrics = {task.task_name: metrics for task in self.prediction_tasks} - - # If metrics are not provided, use the defaults from the prediction-tasks. - # TODO: Do the same for weight_metrics. - if not metrics: - for task_name, task in self.prediction_tasks_by_name().items(): - _metrics[task_name] = [ - m() if inspect.isclass(m) else m for m in task.DEFAULT_METRICS - ] - - _loss = {} - if isinstance(loss, (tf.keras.losses.Loss, str)) and len(self.prediction_tasks) == 1: - _loss = {task.task_name: loss for task in self.prediction_tasks} - - # If loss is not provided, use the defaults from the prediction-tasks. - if not loss: - for task_name, task in self.prediction_tasks_by_name().items(): - _loss[task_name] = task.DEFAULT_LOSS + if num_v1_blocks > 1 and num_v2_blocks > 1: + raise ValueError( + "You cannot use both `prediction_tasks` and `prediction_blocks` at the same time.", + "`prediction_tasks` is deprecated and will be removed in a future version.", + ) - for key in _loss: - if isinstance(_loss[key], str) and _loss[key] in loss_registry: - _loss[key] = loss_registry.parse(_loss[key]) + if num_v1_blocks > 0: + self.output_names = [task.task_name for task in self.prediction_tasks] + else: + self.output_names = [block.full_name for block in self.prediction_blocks] super(BaseModel, self).compile( optimizer=optimizer, - loss=_loss, - metrics=_metrics, + loss=self._create_loss(loss), + metrics=self._create_metrics(metrics), weighted_metrics=weighted_metrics, run_eagerly=run_eagerly, loss_weights=loss_weights, @@ -327,6 +315,50 @@ def compile( **kwargs, ) + def _create_metrics(self, metrics=None): + out = {} + + num_v1_blocks = len(self.prediction_tasks) + num_v2_blocks = len(self.prediction_blocks) + + if isinstance(metrics, (list, tuple)): + if num_v1_blocks == 1: + out = {task.task_name: metrics for task in self.prediction_tasks} + elif num_v2_blocks == 1: + out = {task.task_name: metrics for 
task in self.prediction_blocks} + + if not metrics: + for task_name, task in self.prediction_tasks_by_name().items(): + out[task_name] = [m() if inspect.isclass(m) else m for m in task.DEFAULT_METRICS] + + for task_name, task in self.predictions_by_name().items(): + out[task_name] = [m() if inspect.isclass(m) else m for m in task.default_metrics] + + return out + + def _create_loss(self, loss=None): + out = {} + + if isinstance(loss, (tf.keras.losses.Loss, str)): + if len(self.prediction_tasks) == 1: + out = {task.task_name: loss for task in self.prediction_tasks} + elif len(self.prediction_blocks) == 1: + out = {task.name: loss for task in self.prediction_blocks} + + # If loss is not provided, use the defaults from the prediction-tasks. + if not loss: + for task_name, task in self.prediction_tasks_by_name().items(): + out[task_name] = task.DEFAULT_LOSS + + for task_name, task in self.predictions_by_name().items(): + out[task_name] = task.default_loss + + for key in out: + if isinstance(out[key], str) and out[key] in loss_registry: + out[key] = loss_registry.parse(out[key]) + + return out + @property def prediction_tasks(self) -> List[PredictionTask]: from merlin.models.tf.prediction_tasks.base import PredictionTask @@ -350,9 +382,30 @@ def prediction_tasks_by_target(self) -> Dict[str, List[PredictionTask]]: return outputs + @property + def prediction_blocks(self) -> List[PredictionBlock]: + results = find_all_instances_in_layers(self, PredictionBlock) + + return results + + def predictions_by_name(self) -> Dict[str, PredictionBlock]: + return {task.full_name: task for task in self.prediction_blocks} + + def predictions_by_target(self) -> Dict[str, List[PredictionBlock]]: + outputs: Dict[str, List[PredictionBlock]] = {} + for task in self.prediction_blocks: + if task.target in outputs: + if isinstance(outputs[task.target], list): + outputs[task.target].append(task) + else: + outputs[task.target] = [outputs[task.target], task] + outputs[task.target] = task + + 
return outputs + def call_train_test( self, x, y=None, training=False, testing=False, **kwargs - ) -> PredictionOutput: + ) -> Union[Prediction, PredictionOutput]: forward = self( x, targets=y, @@ -360,10 +413,11 @@ def call_train_test( testing=testing, **kwargs, ) - if not self.prediction_tasks: + if not (self.prediction_tasks or self.prediction_blocks): return PredictionOutput(forward, y) predictions, targets, output = {}, {}, None + # V1 for task in self.prediction_tasks: task_x = forward if isinstance(forward, dict) and task.task_name in forward: @@ -378,14 +432,34 @@ def call_train_test( targets[task.task_name] = task_y predictions[task.task_name] = task_x + if len(predictions) == 1 and len(targets) == 1: + predictions = predictions[list(predictions.keys())[0]] + targets = targets[list(targets.keys())[0]] + + if output: + return output.copy_with_updates(predictions, targets) + + return PredictionOutput(predictions, targets) + + # V2 + for task in self.prediction_blocks: + task_x = forward + if isinstance(forward, dict) and task.full_name in forward: + task_x = forward[task.full_name] + if isinstance(task_x, Prediction): + task_y = task_x.targets + task_x = task_x.outputs + else: + task_y = y[task.target] if isinstance(y, dict) and y else y + + targets[task.full_name] = task_y + predictions[task.full_name] = task_x + if len(predictions) == 1 and len(targets) == 1: predictions = predictions[list(predictions.keys())[0]] targets = targets[list(targets.keys())[0]] - if output: - return output.copy_with_updates(predictions, targets) - - return PredictionOutput(predictions, targets) + return Prediction(predictions, targets) def train_step(self, data): """Custom train step using the `compute_loss` method.""" @@ -832,7 +906,7 @@ def from_config(cls, config, custom_objects=None): return cls(*layers, pre=pre, post=post) def get_config(self): - config = maybe_deserialize_keras_objects({}, ["pre", "post"]) + config = maybe_serialize_keras_objects(self, {}, ["pre", "post"]) 
for i, layer in enumerate(self.blocks): config[i] = tf.keras.utils.serialize_keras_object(layer) diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py index 0b1c1492cb..2552158590 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -1,45 +1,142 @@ +from typing import Optional, Sequence, Union + +import tensorflow as tf +from keras.utils.generic_utils import to_snake_case from tensorflow.keras.layers import Layer +from merlin.models.tf.core.base import name_fn +from merlin.models.tf.core.prediction import Prediction from merlin.models.tf.core.transformations import LogitsTemperatureScaler +from merlin.models.tf.utils import tf_utils +@tf.keras.utils.register_keras_serializable(package="merlin.models") class PredictionBlock(Layer): - def __init__( - self, - prediction, - default_loss, - default_metrics, - target=None, - pre=None, - post=None, - logits_temperature=1.0 - ): - self.prediction = prediction - self.default_loss = default_loss - self.default_metrics = default_metrics - self.target = target - self.pre = pre - self.post = post - self.logits_temperature = logits_temperature - if logits_temperature != 1.0: - self.logits_scaler = LogitsTemperatureScaler(logits_temperature) - - def call(self, inputs, context): - return self.prediction(inputs, context) - - def __call__(self, inputs, *args, **kwargs): - # call pre - if self.pre: - inputs = self.pre(inputs, *args, **kwargs) - - # super call - outputs = super().__call__(inputs, *args, **kwargs) - - if self.post: - outputs = self.post(outputs, *args, **kwargs) - - if getattr(self, "logits_scaler", None): - outputs = self.logits_scaler(outputs) - - return outputs + def __init__( + self, + prediction: Layer, + default_loss: Union[str, tf.keras.losses.Loss], + default_metrics: Sequence[tf.keras.metrics.Metric], + name: Optional[str] = None, + target: Optional[str] = None, + pre: Optional[Layer] = None, + post: Optional[Layer] = None, + 
logits_temperature=1.0, + **kwargs, + ): + logits_scaler = kwargs.pop("logits_scaler", None) + self.target = target + base_name = to_snake_case(self.__class__.__name__) + self.full_name = name_fn(self.target, base_name) if self.target else base_name + + super().__init__(name=name or self.full_name, **kwargs) + self.prediction = prediction + self.default_loss = default_loss + self.default_metrics = default_metrics + self.pre = pre + self.post = post + if logits_scaler is not None: + self.logits_scaler = logits_scaler + self.logits_temperature = logits_scaler.temperature + else: + self.logits_temperature = logits_temperature + if logits_temperature != 1.0: + self.logits_scaler = LogitsTemperatureScaler(logits_temperature) + + def build(self, input_shape=None): + """Builds the PredictionBlock. + + Parameters + ---------- + input_shape : tf.TensorShape, optional + The input shape, by default None + """ + if self.pre is not None: + self.pre.build(input_shape) + input_shape = self.pre.compute_output_shape(input_shape) + + input_shape = self.prediction.compute_output_shape(input_shape) + + if self.post is not None: + self.post.build(input_shape) + + self.built = True + + def call(self, inputs, **kwargs): + return tf_utils.call_layer(self.prediction, inputs, **kwargs) + + def compute_output_shape(self, input_shape): + output_shape = input_shape + if self.pre is not None: + output_shape = self.pre.compute_output_shape(output_shape) + + output_shape = self.prediction.compute_output_shape(output_shape) + + if self.post is not None: + output_shape = self.post.compute_output_shape(output_shape) + + return output_shape + + def __call__(self, inputs, *args, **kwargs): + # call pre + if self.pre: + inputs = tf_utils.call_layer(self.pre, inputs, **kwargs) + + # super call + outputs = super(PredictionBlock, self).__call__(inputs, *args, **kwargs) + + if self.post: + outputs = tf_utils.call_layer(self.post, inputs, **kwargs) + + if getattr(self, "logits_scaler", None): + outputs = 
self.logits_scaler(outputs) + + if kwargs.get("training", False) or kwargs.get("testing", False): + targets = kwargs.get("targets", {}) + if isinstance(targets, dict) and self.target: + targets = targets.get(self.target, targets) + + return Prediction(outputs, targets) + + return outputs + + def get_config(self): + config = super(PredictionBlock, self).get_config() + config.update( + { + "target": self.target, + } + ) + + objects = [ + "default_metrics", + "prediction", + "pre", + "post", + "logits_scaler", + ] + + if isinstance(self.default_loss, str): + config["default_loss"] = self.default_loss + else: + objects.append("default_loss") + + config = tf_utils.maybe_serialize_keras_objects(self, config, objects) + + return config + + @classmethod + def from_config(cls, config): + config = tf_utils.maybe_deserialize_keras_objects( + config, + { + "default_metrics": tf.keras.metrics.deserialize, + "default_loss": tf.keras.losses.deserialize, + "prediction": tf.keras.layers.deserialize, + "pre": tf.keras.layers.deserialize, + "post": tf.keras.layers.deserialize, + "logits_scaler": tf.keras.layers.deserialize, + }, + ) + return super().from_config(config) diff --git a/tests/unit/tf/predictions/__init__.py b/tests/unit/tf/predictions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/tf/predictions/test_base.py b/tests/unit/tf/predictions/test_base.py new file mode 100644 index 0000000000..99241f841b --- /dev/null +++ b/tests/unit/tf/predictions/test_base.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest +import tensorflow as tf + +import merlin.models.tf as mm +from merlin.io import Dataset +from merlin.models.tf.utils import testing_utils + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_prediction_block(ecommerce_data: Dataset, run_eagerly): + model = mm.Model( + mm.InputBlock(ecommerce_data.schema), + mm.MLPBlock([64]), + _BinaryPrediction("click"), + ) + + _, history = testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly) + + assert set(history.history.keys()) == {"loss", "precision", "regularization_loss"} + + +@pytest.mark.parametrize("run_eagerly", [True]) +def test_parallel_prediction_blocks(ecommerce_data: Dataset, run_eagerly): + predictions = mm.ParallelBlock( + _BinaryPrediction("click"), + _BinaryPrediction("conversion"), + ) + + model = mm.Model( + mm.InputBlock(ecommerce_data.schema), + mm.MLPBlock([64]), + predictions, + ) + + _, history = testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly) + + assert len(history.history.keys()) == 6 + + +def _BinaryPrediction(name): + return mm.PredictionBlock( + tf.keras.layers.Dense(1, activation="sigmoid"), + default_loss="binary_crossentropy", + default_metrics=(tf.keras.metrics.Precision(),), + target=name, + ) diff --git a/tests/unit/tf/predictions/test_classification.py b/tests/unit/tf/predictions/test_classification.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/tf/predictions/test_regression.py b/tests/unit/tf/predictions/test_regression.py new file mode 100644 index 0000000000..e69de29bb2 
From 75135dec5c234e7f347c36bd9ef062af21542461 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 11:28:42 +0200 Subject: [PATCH 03/27] Running black --- merlin/models/tf/predictions/binary.py | 28 ++++++++++++---------- merlin/models/tf/predictions/regression.py | 20 ++++++++-------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/merlin/models/tf/predictions/binary.py b/merlin/models/tf/predictions/binary.py index d6a66b4a56..0f4ded44fe 100644 --- a/merlin/models/tf/predictions/binary.py +++ b/merlin/models/tf/predictions/binary.py @@ -4,19 +4,21 @@ class BinaryPrediction(PredictionBlock): + """Binary-classification prediction block""" + def __init__( - self, - default_loss="binary_crossentropy", - default_metrics=( - tf.keras.metrics.Precision, - tf.keras.metrics.Recall, - tf.keras.metrics.BinaryAccuracy, - tf.keras.metrics.AUC, - ), - target=None, - pre=None, - post=None, - logits_temperature=1.0 + self, + default_loss="binary_crossentropy", + default_metrics=( + tf.keras.metrics.Precision, + tf.keras.metrics.Recall, + tf.keras.metrics.BinaryAccuracy, + tf.keras.metrics.AUC, + ), + target=None, + pre=None, + post=None, + logits_temperature=1.0, ): super().__init__( prediction=tf.keras.layers.Dense(1, activation="sigmoid"), @@ -25,5 +27,5 @@ def __init__( target=target, pre=pre, post=post, - logits_temperature=logits_temperature + logits_temperature=logits_temperature, ) diff --git a/merlin/models/tf/predictions/regression.py b/merlin/models/tf/predictions/regression.py index dd47284b92..895c14c9b2 100644 --- a/merlin/models/tf/predictions/regression.py +++ b/merlin/models/tf/predictions/regression.py @@ -4,16 +4,16 @@ class RegressionPrediction(PredictionBlock): + """Regression prediction block""" + def __init__( - self, - default_loss="mse", - default_metrics=( - tf.keras.metrics.RootMeanSquaredError, - ), - target=None, - pre=None, - post=None, - logits_temperature=1.0 + self, + default_loss="mse", + 
default_metrics=(tf.keras.metrics.RootMeanSquaredError,), + target=None, + pre=None, + post=None, + logits_temperature=1.0, ): super().__init__( prediction=tf.keras.layers.Dense(1, activation="linear"), @@ -22,5 +22,5 @@ def __init__( target=target, pre=pre, post=post, - logits_temperature=logits_temperature + logits_temperature=logits_temperature, ) From 1508f38ac244a049dcbede7037118fa4b123d830 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 13:42:00 +0200 Subject: [PATCH 04/27] Some fixes --- merlin/models/tf/models/base.py | 33 ++++++++++--------- merlin/models/tf/prediction_tasks/base.py | 21 +++++++++--- .../tf/prediction_tasks/classification.py | 2 +- tests/unit/tf/predictions/test_base.py | 13 ++++---- 4 files changed, 41 insertions(+), 28 deletions(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 55cea5cd8f..2ff3d0cb68 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -329,7 +329,7 @@ def _create_metrics(self, metrics=None): if not metrics: for task_name, task in self.prediction_tasks_by_name().items(): - out[task_name] = [m() if inspect.isclass(m) else m for m in task.DEFAULT_METRICS] + out[task_name] = task.create_default_metrics() for task_name, task in self.predictions_by_name().items(): out[task_name] = [m() if inspect.isclass(m) else m for m in task.default_metrics] @@ -373,11 +373,11 @@ def prediction_tasks_by_name(self) -> Dict[str, PredictionTask]: def prediction_tasks_by_target(self) -> Dict[str, List[PredictionTask]]: outputs: Dict[str, Union[PredictionTask, List[PredictionTask]]] = {} for task in self.prediction_tasks: - if task.target in outputs: - if isinstance(outputs[task.target], list): + if task.target_name in outputs: + if isinstance(outputs[task.target_name], list): outputs[task.target].append(task) else: - outputs[task.target] = [outputs[task.target], task] + outputs[task.target_name] = [outputs[task.target_name], task] outputs[task.target] = 
task return outputs @@ -418,19 +418,20 @@ def call_train_test( predictions, targets, output = {}, {}, None # V1 - for task in self.prediction_tasks: - task_x = forward - if isinstance(forward, dict) and task.task_name in forward: - task_x = forward[task.task_name] - if isinstance(task_x, PredictionOutput): - output = task_x - task_y = task_x.targets - task_x = task_x.predictions - else: - task_y = y[task.target_name] if isinstance(y, dict) and y else y + if self.prediction_tasks: + for task in self.prediction_tasks: + task_x = forward + if isinstance(forward, dict) and task.task_name in forward: + task_x = forward[task.task_name] + if isinstance(task_x, PredictionOutput): + output = task_x + task_y = task_x.targets + task_x = task_x.predictions + else: + task_y = y[task.target_name] if isinstance(y, dict) and y else y - targets[task.task_name] = task_y - predictions[task.task_name] = task_x + targets[task.task_name] = task_y + predictions[task.task_name] = task_x if len(predictions) == 1 and len(targets) == 1: predictions = predictions[list(predictions.keys())[0]] diff --git a/merlin/models/tf/prediction_tasks/base.py b/merlin/models/tf/prediction_tasks/base.py index bfa2c80eb4..196f3b5c84 100644 --- a/merlin/models/tf/prediction_tasks/base.py +++ b/merlin/models/tf/prediction_tasks/base.py @@ -65,10 +65,11 @@ def __init__( name: Optional[Text] = None, **kwargs, ) -> None: - super().__init__(name=name, **kwargs) self.target_name = target_name - self.task_block = task_block self._task_name = task_name + name = name or self.task_name + super().__init__(name=name, **kwargs) + self.task_block = task_block self.pre = pre self._pre_eval_topk = pre_eval_topk @@ -208,6 +209,13 @@ def get_config(self): return config + def create_default_metrics(self): + metrics = [] + for metric in self.DEFAULT_METRICS: + metrics.append(metric(name=self.child_name(to_snake_case(metric.__name__)))) + + return metrics + @tf.keras.utils.register_keras_serializable(package="merlin.models") 
class ParallelPredictionBlock(ParallelBlock): @@ -230,20 +238,24 @@ def __init__( *prediction_tasks: PredictionTask, task_blocks: Optional[Union[Layer, Dict[str, Layer]]] = None, bias_block: Optional[Layer] = None, + task_weights=None, pre: Optional[BlockType] = None, post: Optional[BlockType] = None, **kwargs, ): self.prediction_tasks = prediction_tasks self.bias_block = bias_block - self.bias_logit = tf.keras.layers.Dense(1) + if bias_block: + self.bias_logit = tf.keras.layers.Dense(1) self.prediction_task_dict = {} if prediction_tasks: for task in prediction_tasks: self.prediction_task_dict[task.task_name] = task - super(ParallelPredictionBlock, self).__init__(self.prediction_task_dict, pre=pre, post=post) + super(ParallelPredictionBlock, self).__init__( + self.prediction_task_dict, pre=pre, post=post, use_layer_name=False, **kwargs + ) if task_blocks: self._set_task_blocks(task_blocks) @@ -418,6 +430,7 @@ def from_config(cls, config, **kwargs): config["schema"] = tensorflow_metadata_json_to_schema(config["schema"]) prediction_tasks = config.pop("prediction_tasks", []) + config.pop("parallel_layers", None) return cls(*prediction_tasks, **config) diff --git a/merlin/models/tf/prediction_tasks/classification.py b/merlin/models/tf/prediction_tasks/classification.py index 1a14f83fe3..34f2071876 100644 --- a/merlin/models/tf/prediction_tasks/classification.py +++ b/merlin/models/tf/prediction_tasks/classification.py @@ -89,7 +89,7 @@ def __init__( ) self.output_layer = output_layer or tf.keras.layers.Dense( - 1, activation="linear", name=self.child_name("output_layer") + 1, activation="linear", name="output_layer" ) # To ensure that the output is always fp32, avoiding numerical # instabilities with mixed_float16 (fp16) policy diff --git a/tests/unit/tf/predictions/test_base.py b/tests/unit/tf/predictions/test_base.py index 99241f841b..c98a13e98b 100644 --- a/tests/unit/tf/predictions/test_base.py +++ b/tests/unit/tf/predictions/test_base.py @@ -36,15 +36,13 @@ def 
test_prediction_block(ecommerce_data: Dataset, run_eagerly): @pytest.mark.parametrize("run_eagerly", [True]) def test_parallel_prediction_blocks(ecommerce_data: Dataset, run_eagerly): - predictions = mm.ParallelBlock( - _BinaryPrediction("click"), - _BinaryPrediction("conversion"), - ) - model = mm.Model( mm.InputBlock(ecommerce_data.schema), mm.MLPBlock([64]), - predictions, + mm.ParallelBlock( + _BinaryPrediction("click", pre=mm.MLPBlock([16])), + _BinaryPrediction("conversion", pre=mm.MLPBlock([16])), + ), ) _, history = testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly) @@ -52,10 +50,11 @@ def test_parallel_prediction_blocks(ecommerce_data: Dataset, run_eagerly): assert len(history.history.keys()) == 6 -def _BinaryPrediction(name): +def _BinaryPrediction(name, **kwargs): return mm.PredictionBlock( tf.keras.layers.Dense(1, activation="sigmoid"), default_loss="binary_crossentropy", default_metrics=(tf.keras.metrics.Precision(),), target=name, + **kwargs ) From bff9e316c95c66ef100f5461e906212a8cb1e2db Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 15:23:09 +0200 Subject: [PATCH 05/27] Making metric-names better --- merlin/models/tf/models/base.py | 15 +++++++++------ merlin/models/tf/predictions/base.py | 20 +++++++++++++++++--- tests/unit/tf/predictions/test_base.py | 19 +++++++++++++------ 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 2ff3d0cb68..ce5a54a1fd 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -303,6 +303,8 @@ def compile( else: self.output_names = [block.full_name for block in self.prediction_blocks] + from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) + super(BaseModel, self).compile( optimizer=optimizer, loss=self._create_loss(loss), @@ -312,6 +314,7 @@ def compile( loss_weights=loss_weights, steps_per_execution=steps_per_execution, jit_compile=jit_compile, + 
from_serialized=from_serialized, **kwargs, ) @@ -319,20 +322,20 @@ def _create_metrics(self, metrics=None): out = {} num_v1_blocks = len(self.prediction_tasks) - num_v2_blocks = len(self.prediction_blocks) if isinstance(metrics, (list, tuple)): - if num_v1_blocks == 1: + if num_v1_blocks > 0: out = {task.task_name: metrics for task in self.prediction_tasks} - elif num_v2_blocks == 1: - out = {task.task_name: metrics for task in self.prediction_blocks} + else: + for i, block in enumerate(self.prediction_blocks): + out[block.full_name] = metrics[i] if not metrics: for task_name, task in self.prediction_tasks_by_name().items(): - out[task_name] = task.create_default_metrics() + out[task_name] = [m() if inspect.isclass(m) else m for m in task.DEFAULT_METRICS] for task_name, task in self.predictions_by_name().items(): - out[task_name] = [m() if inspect.isclass(m) else m for m in task.default_metrics] + out[task_name] = task.create_default_metrics() return out diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py index 2552158590..ef9178fba4 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Union import tensorflow as tf from keras.utils.generic_utils import to_snake_case @@ -32,7 +32,12 @@ def __init__( super().__init__(name=name or self.full_name, **kwargs) self.prediction = prediction self.default_loss = default_loss - self.default_metrics = default_metrics + self._default_metrics = [ + tf.keras.metrics.serialize(metric) + if isinstance(metric, tf.keras.metrics.Metric) + else metric + for metric in default_metrics + ] self.pre = pre self.post = post if logits_scaler is not None: @@ -100,16 +105,25 @@ def __call__(self, inputs, *args, **kwargs): return outputs + def create_default_metrics(self) -> List[tf.keras.metrics.Metric]: + metrics = [] + for metric in self._default_metrics: + name 
= self.full_name + "/" + to_snake_case(metric["class_name"]) + metric["config"]["name"] = name + metrics.append(tf.keras.metrics.deserialize(metric)) + + return metrics + def get_config(self): config = super(PredictionBlock, self).get_config() config.update( { "target": self.target, + "default_metrics": self._default_metrics, } ) objects = [ - "default_metrics", "prediction", "pre", "post", diff --git a/tests/unit/tf/predictions/test_base.py b/tests/unit/tf/predictions/test_base.py index c98a13e98b..89a48f830f 100644 --- a/tests/unit/tf/predictions/test_base.py +++ b/tests/unit/tf/predictions/test_base.py @@ -25,7 +25,7 @@ def test_prediction_block(ecommerce_data: Dataset, run_eagerly): model = mm.Model( mm.InputBlock(ecommerce_data.schema), - mm.MLPBlock([64]), + mm.MLPBlock([8]), _BinaryPrediction("click"), ) @@ -34,20 +34,27 @@ def test_prediction_block(ecommerce_data: Dataset, run_eagerly): assert set(history.history.keys()) == {"loss", "precision", "regularization_loss"} -@pytest.mark.parametrize("run_eagerly", [True]) +@pytest.mark.parametrize("run_eagerly", [True, False]) def test_parallel_prediction_blocks(ecommerce_data: Dataset, run_eagerly): model = mm.Model( mm.InputBlock(ecommerce_data.schema), - mm.MLPBlock([64]), + mm.MLPBlock([8]), mm.ParallelBlock( - _BinaryPrediction("click", pre=mm.MLPBlock([16])), - _BinaryPrediction("conversion", pre=mm.MLPBlock([16])), + _BinaryPrediction("click", pre=mm.MLPBlock([4])), + _BinaryPrediction("conversion", pre=mm.MLPBlock([4])), ), ) _, history = testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly) - assert len(history.history.keys()) == 6 + assert list(history.history.keys()) == [ + "loss", + "click/prediction_block_loss", + "conversion/prediction_block_loss", + "click/prediction_block/precision", + "conversion/prediction_block/precision", + "regularization_loss", + ] def _BinaryPrediction(name, **kwargs): From c72cdefbf3cb982f86eb9b411ba9d371c034f192 Mon Sep 17 00:00:00 2001 From: Marc 
Romeyn Date: Tue, 19 Jul 2022 15:27:47 +0200 Subject: [PATCH 06/27] Move arguments in BinaryPrediction + RegressionPrediction --- merlin/models/tf/predictions/binary.py | 8 ++++---- merlin/models/tf/predictions/regression.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/merlin/models/tf/predictions/binary.py b/merlin/models/tf/predictions/binary.py index 0f4ded44fe..4f83f17473 100644 --- a/merlin/models/tf/predictions/binary.py +++ b/merlin/models/tf/predictions/binary.py @@ -8,6 +8,10 @@ class BinaryPrediction(PredictionBlock): def __init__( self, + target=None, + pre=None, + post=None, + logits_temperature=1.0, default_loss="binary_crossentropy", default_metrics=( tf.keras.metrics.Precision, @@ -15,10 +19,6 @@ def __init__( tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.AUC, ), - target=None, - pre=None, - post=None, - logits_temperature=1.0, ): super().__init__( prediction=tf.keras.layers.Dense(1, activation="sigmoid"), diff --git a/merlin/models/tf/predictions/regression.py b/merlin/models/tf/predictions/regression.py index 895c14c9b2..fac5054317 100644 --- a/merlin/models/tf/predictions/regression.py +++ b/merlin/models/tf/predictions/regression.py @@ -8,12 +8,12 @@ class RegressionPrediction(PredictionBlock): def __init__( self, - default_loss="mse", - default_metrics=(tf.keras.metrics.RootMeanSquaredError,), target=None, pre=None, post=None, logits_temperature=1.0, + default_loss="mse", + default_metrics=(tf.keras.metrics.RootMeanSquaredError,), ): super().__init__( prediction=tf.keras.layers.Dense(1, activation="linear"), From e8f1664b721e04aea9c33f1f1b2b3d303d9c571d Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 15:57:43 +0200 Subject: [PATCH 07/27] Updating type-hints --- merlin/models/tf/models/base.py | 1 + merlin/models/tf/predictions/base.py | 2 +- merlin/models/tf/predictions/binary.py | 15 +++++++++++---- merlin/models/tf/predictions/regression.py | 15 +++++++++++---- 
tests/unit/tf/predictions/test_base.py | 6 +++++- 5 files changed, 29 insertions(+), 10 deletions(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index ce5a54a1fd..5afede79b0 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -303,6 +303,7 @@ def compile( else: self.output_names = [block.full_name for block in self.prediction_blocks] + # This flag will make Keras change the metric-names which is not needed in v2 from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) super(BaseModel, self).compile( diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py index ef9178fba4..e621244ce6 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -21,7 +21,7 @@ def __init__( target: Optional[str] = None, pre: Optional[Layer] = None, post: Optional[Layer] = None, - logits_temperature=1.0, + logits_temperature: float = 1.0, **kwargs, ): logits_scaler = kwargs.pop("logits_scaler", None) diff --git a/merlin/models/tf/predictions/binary.py b/merlin/models/tf/predictions/binary.py index 4f83f17473..9cb607dd3c 100644 --- a/merlin/models/tf/predictions/binary.py +++ b/merlin/models/tf/predictions/binary.py @@ -1,4 +1,7 @@ +from typing import Optional + import tensorflow as tf +from tensorflow.keras.layers import Layer from merlin.models.tf.predictions.base import PredictionBlock @@ -8,10 +11,11 @@ class BinaryPrediction(PredictionBlock): def __init__( self, - target=None, - pre=None, - post=None, - logits_temperature=1.0, + target: Optional[str] = None, + pre: Optional[Layer] = None, + post: Optional[Layer] = None, + logits_temperature: float = 1.0, + name: Optional[str] = None, default_loss="binary_crossentropy", default_metrics=( tf.keras.metrics.Precision, @@ -19,6 +23,7 @@ def __init__( tf.keras.metrics.BinaryAccuracy, tf.keras.metrics.AUC, ), + **kwargs, ): super().__init__( prediction=tf.keras.layers.Dense(1, 
activation="sigmoid"), @@ -28,4 +33,6 @@ def __init__( pre=pre, post=post, logits_temperature=logits_temperature, + name=name, + **kwargs, ) diff --git a/merlin/models/tf/predictions/regression.py b/merlin/models/tf/predictions/regression.py index fac5054317..5da393b9b5 100644 --- a/merlin/models/tf/predictions/regression.py +++ b/merlin/models/tf/predictions/regression.py @@ -1,4 +1,7 @@ +from typing import Optional + import tensorflow as tf +from tensorflow.keras.layers import Layer from merlin.models.tf.predictions.base import PredictionBlock @@ -8,12 +11,14 @@ class RegressionPrediction(PredictionBlock): def __init__( self, - target=None, - pre=None, - post=None, - logits_temperature=1.0, + target: Optional[str] = None, + pre: Optional[Layer] = None, + post: Optional[Layer] = None, + logits_temperature: float = 1.0, + name: Optional[str] = None, default_loss="mse", default_metrics=(tf.keras.metrics.RootMeanSquaredError,), + **kwargs, ): super().__init__( prediction=tf.keras.layers.Dense(1, activation="linear"), @@ -23,4 +28,6 @@ def __init__( pre=pre, post=post, logits_temperature=logits_temperature, + name=name, + **kwargs, ) diff --git a/tests/unit/tf/predictions/test_base.py b/tests/unit/tf/predictions/test_base.py index 89a48f830f..2bde533afc 100644 --- a/tests/unit/tf/predictions/test_base.py +++ b/tests/unit/tf/predictions/test_base.py @@ -31,7 +31,11 @@ def test_prediction_block(ecommerce_data: Dataset, run_eagerly): _, history = testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly) - assert set(history.history.keys()) == {"loss", "precision", "regularization_loss"} + assert set(history.history.keys()) == { + "loss", + "click/prediction_block/precision", + "regularization_loss", + } @pytest.mark.parametrize("run_eagerly", [True, False]) From 7edd18a5b2f4d193d35e68e8b9e3a2563eb7df00 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 19:09:47 +0200 Subject: [PATCH 08/27] Trying to fix failing tests --- 
merlin/models/tf/models/base.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 5afede79b0..d090baae95 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -304,7 +304,7 @@ def compile( self.output_names = [block.full_name for block in self.prediction_blocks] # This flag will make Keras change the metric-names which is not needed in v2 - from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) + # from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) super(BaseModel, self).compile( optimizer=optimizer, @@ -315,7 +315,7 @@ def compile( loss_weights=loss_weights, steps_per_execution=steps_per_execution, jit_compile=jit_compile, - from_serialized=from_serialized, + # from_serialized=from_serialized, **kwargs, ) @@ -326,7 +326,8 @@ def _create_metrics(self, metrics=None): if isinstance(metrics, (list, tuple)): if num_v1_blocks > 0: - out = {task.task_name: metrics for task in self.prediction_tasks} + for i, task in enumerate(self.prediction_tasks): + out[task.task_name] = metrics[i] else: for i, block in enumerate(self.prediction_blocks): out[block.full_name] = metrics[i] From 6cf887d08eeca140c215e66ecf1c8ac40526d78c Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 20:24:18 +0200 Subject: [PATCH 09/27] Trying to fix failing tests --- merlin/models/tf/models/base.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index d090baae95..089465de2d 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -326,11 +326,17 @@ def _create_metrics(self, metrics=None): if isinstance(metrics, (list, tuple)): if num_v1_blocks > 0: - for i, task in enumerate(self.prediction_tasks): - out[task.task_name] = metrics[i] + if num_v1_blocks == 1: + out[self.prediction_tasks[0].task_name] = metrics 
+ else: + for i, task in enumerate(self.prediction_tasks): + out[task.task_name] = metrics[i] else: - for i, block in enumerate(self.prediction_blocks): - out[block.full_name] = metrics[i] + if len(self.prediction_blocks) == 1: + out[self.prediction_blocks[0].full_name] = metrics + else: + for i, block in enumerate(self.prediction_blocks): + out[block.full_name] = metrics[i] if not metrics: for task_name, task in self.prediction_tasks_by_name().items(): From d0e47287880714ace76c253e80c25147391f1069 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 21:02:28 +0200 Subject: [PATCH 10/27] Trying to fix failing tests --- merlin/models/tf/models/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 089465de2d..4743be75ac 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -304,7 +304,7 @@ def compile( self.output_names = [block.full_name for block in self.prediction_blocks] # This flag will make Keras change the metric-names which is not needed in v2 - # from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) + from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) super(BaseModel, self).compile( optimizer=optimizer, @@ -315,7 +315,7 @@ def compile( loss_weights=loss_weights, steps_per_execution=steps_per_execution, jit_compile=jit_compile, - # from_serialized=from_serialized, + from_serialized=from_serialized, **kwargs, ) From 0c251978bff8c7d1f5ae823d17e264e30da9f93c Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Thu, 21 Jul 2022 14:19:15 +0200 Subject: [PATCH 11/27] Adding BinaryPrediction test --- merlin/models/tf/__init__.py | 4 ++ .../{binary.py => classification.py} | 0 .../tf/predictions/test_classification.py | 40 +++++++++++++++++++ 3 files changed, 44 insertions(+) rename merlin/models/tf/predictions/{binary.py => classification.py} (100%) diff --git a/merlin/models/tf/__init__.py 
b/merlin/models/tf/__init__.py index 03382301b8..81063d674a 100644 --- a/merlin/models/tf/__init__.py +++ b/merlin/models/tf/__init__.py @@ -110,6 +110,8 @@ from merlin.models.tf.prediction_tasks.regression import RegressionTask from merlin.models.tf.prediction_tasks.retrieval import ItemRetrievalTask from merlin.models.tf.predictions.base import PredictionBlock +from merlin.models.tf.predictions.classification import BinaryPrediction +from merlin.models.tf.predictions.regression import RegressionPrediction from merlin.models.tf.utils import repr_utils from merlin.models.tf.utils.tf_utils import TensorInitializer @@ -170,6 +172,8 @@ "FMPairwiseInteraction", "LabelToOneHot", "PredictionBlock", + "BinaryPrediction", + "RegressionPrediction", "PredictionTask", "BinaryClassificationTask", "MultiClassClassificationTask", diff --git a/merlin/models/tf/predictions/binary.py b/merlin/models/tf/predictions/classification.py similarity index 100% rename from merlin/models/tf/predictions/binary.py rename to merlin/models/tf/predictions/classification.py diff --git a/tests/unit/tf/predictions/test_classification.py b/tests/unit/tf/predictions/test_classification.py index e69de29bb2..5fe268b33b 100644 --- a/tests/unit/tf/predictions/test_classification.py +++ b/tests/unit/tf/predictions/test_classification.py @@ -0,0 +1,40 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import pytest + +import merlin.models.tf as mm +from merlin.io import Dataset +from merlin.models.tf.utils import testing_utils + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_binary_prediction_block(ecommerce_data: Dataset, run_eagerly): + model = mm.Model( + mm.InputBlock(ecommerce_data.schema), + mm.MLPBlock([8]), + mm.BinaryPrediction("click"), + ) + + _, history = testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly) + + assert set(history.history.keys()) == { + "loss", + "click/binary_prediction/precision", + "click/binary_prediction/recall", + "click/binary_prediction/binary_accuracy", + "click/binary_prediction/auc", + "regularization_loss", + } From d3f2664d6c1a437c3d5248cc4859535b9af0584b Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Thu, 21 Jul 2022 14:22:28 +0200 Subject: [PATCH 12/27] Adding RegressionPrediction test --- tests/unit/tf/predictions/test_regression.py | 37 ++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/unit/tf/predictions/test_regression.py b/tests/unit/tf/predictions/test_regression.py index e69de29bb2..368264fd71 100644 --- a/tests/unit/tf/predictions/test_regression.py +++ b/tests/unit/tf/predictions/test_regression.py @@ -0,0 +1,37 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import pytest + +import merlin.models.tf as mm +from merlin.io import Dataset +from merlin.models.tf.utils import testing_utils + + +@pytest.mark.parametrize("run_eagerly", [True, False]) +def test_binary_prediction_block(ecommerce_data: Dataset, run_eagerly): + model = mm.Model( + mm.InputBlock(ecommerce_data.schema), + mm.MLPBlock([8]), + mm.RegressionPrediction("click"), + ) + + _, history = testing_utils.model_test(model, ecommerce_data, run_eagerly=run_eagerly) + + assert set(history.history.keys()) == { + "loss", + "click/binary_prediction/mse", + "regularization_loss", + } From 2f4d5ed8be8236c311ac6a03a6b5116704614c66 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 15:49:36 +0200 Subject: [PATCH 13/27] First commit --- merlin/models/tf/predictions/base.py | 60 +++++++++++++++++++++ merlin/models/tf/predictions/dot_product.py | 60 +++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 merlin/models/tf/predictions/dot_product.py diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py index e621244ce6..6fa1de8cb0 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -154,3 +154,63 @@ def from_config(cls, config): ) return super().from_config(config) + + +class ContrastivePredictionBlock(PredictionBlock): + def __init__( + self, + prediction: Layer, + prediction_with_negatives: Layer, + default_loss: Union[str, tf.keras.losses.Loss], + default_metrics: Sequence[tf.keras.metrics.Metric], + default_contrastive_metrics: Sequence[tf.keras.metrics.Metric], + name: Optional[str] = None, + target: Optional[str] = None, + pre: Optional[Layer] = None, + post: Optional[Layer] = None, + logits_temperature=1.0, + negative_sampling=None, + downscore_false_negatives=False, + **kwargs, + ): + super(ContrastivePredictionBlock, self).__init__( + prediction, + default_loss=default_loss, + default_metrics=default_metrics, + target=target, + pre=pre, + post=post, + 
logits_temperature=logits_temperature, + name=name, + **kwargs, + ) + self.prediction_with_negatives = prediction_with_negatives + self.negative_sampling = negative_sampling + self.downscore_false_negatives = downscore_false_negatives + self._default_contrastive_metrics = [ + tf.keras.metrics.serialize(metric) + if isinstance(metric, tf.keras.metrics.Metric) + else metric + for metric in default_contrastive_metrics + ] + + @property + def has_negative_samplers(self) -> bool: + return self.negative_sampling is not None and len(self.negative_sampling) > 0 + + def compile( + self, + negative_sampling=None, + downscore_false_negatives=False + ): + self.negative_sampling = negative_sampling + self.downscore_false_negatives = downscore_false_negatives + + def call(self, inputs, training=False, testing=False, **kwargs): + to_call = self.prediction + + if self.has_negative_samplers and (training or testing): + to_call = self.prediction_with_negatives + + return to_call(inputs, training=training, testing=testing, **kwargs) + diff --git a/merlin/models/tf/predictions/dot_product.py b/merlin/models/tf/predictions/dot_product.py new file mode 100644 index 0000000000..8ae610049e --- /dev/null +++ b/merlin/models/tf/predictions/dot_product.py @@ -0,0 +1,60 @@ +from typing import Union, Sequence, Optional + +import tensorflow as tf +from tensorflow.keras.layers import Layer + +from merlin.models.tf.predictions.base import ContrastivePredictionBlock + + +class DotProductCategoricalPrediction(ContrastivePredictionBlock): + def __init__( + self, + negative_sampling=None, + downscore_false_negatives=False, + target=None, + pre=None, + post=None, + logits_temperature=1.0, + name=None, + default_loss="categorical-cross-entropy", + default_metrics=(), + default_contrastive_metrics=(), + **kwargs, + ): + super().__init__( + prediction=DotProduct(), + prediction_with_negatives=ContrastiveDotProduct(), + default_loss=default_loss, + default_metrics=default_metrics, + 
default_contrastive_metrics=default_contrastive_metrics, + name=name, + target=target, + pre=pre, + post=post, + logits_temperature=logits_temperature, + negative_sampling=negative_sampling, + downscore_false_negatives=downscore_false_negatives, + **kwargs + ) + + +class DotProduct(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, inputs): + return tf.reduce_sum(inputs, axis=1) + + def compute_output_shape(self, input_shape): + return input_shape[0], 1 + + +class ContrastiveDotProduct(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, inputs): + return tf.reduce_sum(inputs, axis=1) + + def compute_output_shape(self, input_shape): + return input_shape[0], 1 \ No newline at end of file From 8142ab05ca53f06261d58f7f854436a7a97459ed Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 16:03:40 +0200 Subject: [PATCH 14/27] Make in-batch the default --- merlin/models/tf/predictions/dot_product.py | 31 +++++++++++---------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/merlin/models/tf/predictions/dot_product.py b/merlin/models/tf/predictions/dot_product.py index 8ae610049e..8b16f67986 100644 --- a/merlin/models/tf/predictions/dot_product.py +++ b/merlin/models/tf/predictions/dot_product.py @@ -1,4 +1,4 @@ -from typing import Union, Sequence, Optional +from typing import Optional, Sequence, Union import tensorflow as tf from tensorflow.keras.layers import Layer @@ -6,20 +6,21 @@ from merlin.models.tf.predictions.base import ContrastivePredictionBlock +# Or: RetrievalCategoricalPrediction class DotProductCategoricalPrediction(ContrastivePredictionBlock): def __init__( - self, - negative_sampling=None, - downscore_false_negatives=False, - target=None, - pre=None, - post=None, - logits_temperature=1.0, - name=None, - default_loss="categorical-cross-entropy", - default_metrics=(), - default_contrastive_metrics=(), - **kwargs, + self, + negative_sampling="in-batch", + 
downscore_false_negatives=False, + target=None, + pre=None, + post=None, + logits_temperature=1.0, + name=None, + default_loss="categorical-cross-entropy", + default_metrics=(), + default_contrastive_metrics=(), + **kwargs, ): super().__init__( prediction=DotProduct(), @@ -34,7 +35,7 @@ def __init__( logits_temperature=logits_temperature, negative_sampling=negative_sampling, downscore_false_negatives=downscore_false_negatives, - **kwargs + **kwargs, ) @@ -57,4 +58,4 @@ def call(self, inputs): return tf.reduce_sum(inputs, axis=1) def compute_output_shape(self, input_shape): - return input_shape[0], 1 \ No newline at end of file + return input_shape[0], 1 From 7c1a23463bfe0d5b602a9862e60c0b47f6644810 Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Tue, 19 Jul 2022 20:11:46 +0200 Subject: [PATCH 15/27] First pass over DotProduct --- merlin/models/tf/predictions/base.py | 70 +++++++------- merlin/models/tf/predictions/dot_product.py | 91 ++++++++++++++----- tests/unit/tf/predictions/test_dot_product.py | 0 3 files changed, 105 insertions(+), 56 deletions(-) create mode 100644 tests/unit/tf/predictions/test_dot_product.py diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py index 6fa1de8cb0..7fbe2e618e 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -158,20 +158,19 @@ def from_config(cls, config): class ContrastivePredictionBlock(PredictionBlock): def __init__( - self, - prediction: Layer, - prediction_with_negatives: Layer, - default_loss: Union[str, tf.keras.losses.Loss], - default_metrics: Sequence[tf.keras.metrics.Metric], - default_contrastive_metrics: Sequence[tf.keras.metrics.Metric], - name: Optional[str] = None, - target: Optional[str] = None, - pre: Optional[Layer] = None, - post: Optional[Layer] = None, - logits_temperature=1.0, - negative_sampling=None, - downscore_false_negatives=False, - **kwargs, + self, + prediction: Layer, + prediction_with_negatives: Layer, + 
default_loss: Union[str, tf.keras.losses.Loss], + default_metrics: Sequence[tf.keras.metrics.Metric], + name: Optional[str] = None, + target: Optional[str] = None, + pre: Optional[Layer] = None, + post: Optional[Layer] = None, + logits_temperature: float = 1.0, + negative_sampling=None, + downscore_false_negatives=False, + **kwargs, ): super(ContrastivePredictionBlock, self).__init__( prediction, @@ -185,26 +184,6 @@ def __init__( **kwargs, ) self.prediction_with_negatives = prediction_with_negatives - self.negative_sampling = negative_sampling - self.downscore_false_negatives = downscore_false_negatives - self._default_contrastive_metrics = [ - tf.keras.metrics.serialize(metric) - if isinstance(metric, tf.keras.metrics.Metric) - else metric - for metric in default_contrastive_metrics - ] - - @property - def has_negative_samplers(self) -> bool: - return self.negative_sampling is not None and len(self.negative_sampling) > 0 - - def compile( - self, - negative_sampling=None, - downscore_false_negatives=False - ): - self.negative_sampling = negative_sampling - self.downscore_false_negatives = downscore_false_negatives def call(self, inputs, training=False, testing=False, **kwargs): to_call = self.prediction @@ -214,3 +193,26 @@ def call(self, inputs, training=False, testing=False, **kwargs): return to_call(inputs, training=training, testing=testing, **kwargs) + @property + def has_negative_samplers(self) -> bool: + return self.negative_sampling is not None and len(self.negative_sampling) > 0 + + def compile(self, negative_sampling=None, downscore_false_negatives=False): + self.prediction_with_negatives.negative_sampling = negative_sampling + self.prediction_with_negatives.downscore_false_negatives = downscore_false_negatives + + @property + def negative_sampling(self): + return self.prediction_with_negatives.negative_sampling + + @negative_sampling.setter + def negative_sampling(self, value): + self.prediction_with_negatives.negative_sampling = value + + @property + 
def downscore_false_negatives(self): + return self.prediction_with_negatives.downscore_false_negatives + + @downscore_false_negatives.setter + def downscore_false_negatives(self, value): + self.prediction_with_negatives.downscore_false_negatives = value diff --git a/merlin/models/tf/predictions/dot_product.py b/merlin/models/tf/predictions/dot_product.py index 8b16f67986..341171b9a6 100644 --- a/merlin/models/tf/predictions/dot_product.py +++ b/merlin/models/tf/predictions/dot_product.py @@ -3,31 +3,45 @@ import tensorflow as tf from tensorflow.keras.layers import Layer +from merlin.models.tf.core.prediction import Prediction +from merlin.models.tf.metrics.topk import AvgPrecisionAt, MRRAt, NDCGAt, PrecisionAt, RecallAt from merlin.models.tf.predictions.base import ContrastivePredictionBlock - # Or: RetrievalCategoricalPrediction +from merlin.models.utils.constants import MIN_FLOAT +from merlin.schema import Tags + + +@tf.keras.utils.register_keras_serializable(package="merlin_models") class DotProductCategoricalPrediction(ContrastivePredictionBlock): + DEFAULT_K = 10 + def __init__( self, negative_sampling="in-batch", downscore_false_negatives=False, - target=None, - pre=None, - post=None, - logits_temperature=1.0, - name=None, - default_loss="categorical-cross-entropy", - default_metrics=(), - default_contrastive_metrics=(), + target: Optional[str] = None, + pre: Optional[Layer] = None, + post: Optional[Layer] = None, + logits_temperature: float = 1.0, + name: Optional[str] = None, + default_loss: Union[str, tf.keras.losses.Loss] = "categorical-cross-entropy", + default_metrics: Sequence[tf.keras.metrics.Metric] = ( + RecallAt(DEFAULT_K), + MRRAt(DEFAULT_K), + NDCGAt(DEFAULT_K), + AvgPrecisionAt(DEFAULT_K), + PrecisionAt(DEFAULT_K), + ), + query_name: str = "query", + item_name: str = "item", **kwargs, ): super().__init__( - prediction=DotProduct(), - prediction_with_negatives=ContrastiveDotProduct(), + prediction=DotProduct(query_name, item_name), + 
prediction_with_negatives=ContrastiveDotProduct(query_name, item_name), default_loss=default_loss, default_metrics=default_metrics, - default_contrastive_metrics=default_contrastive_metrics, name=name, target=target, pre=pre, @@ -39,23 +53,56 @@ def __init__( ) +@tf.keras.utils.register_keras_serializable(package="merlin_models") class DotProduct(Layer): - def __init__(self, **kwargs): + def __init__(self, query_name: str = "query", item_name: str = "item", **kwargs): super().__init__(**kwargs) + self.query_name = query_name + self.item_name = item_name - def call(self, inputs): - return tf.reduce_sum(inputs, axis=1) + def call(self, inputs, **kwargs): + return tf.reduce_sum( + tf.multiply(inputs[self.query_name], inputs[self.item_name]), keepdims=True, axis=-1 + ) def compute_output_shape(self, input_shape): return input_shape[0], 1 + def get_config(self): + return { + **super(DotProduct, self).get_config(), + "query_name": self.query_name, + "item_name": self.item_name, + } -class ContrastiveDotProduct(Layer): - def __init__(self, **kwargs): - super().__init__(**kwargs) - def call(self, inputs): - return tf.reduce_sum(inputs, axis=1) +@tf.keras.utils.register_keras_serializable(package="merlin_models") +class ContrastiveDotProduct(DotProduct): + def __init__( + self, + negative_sampling="in-batch", + downscore_false_negatives=False, + sampling_downscore_false_negatives_value: float = MIN_FLOAT, + query_name: str = "query", + item_name: str = "item", + **kwargs, + ): + super().__init__(query_name, item_name, **kwargs) + self.negative_sampling = negative_sampling + self.downscore_false_negatives = downscore_false_negatives + self.sampling_downscore_false_negatives_value = sampling_downscore_false_negatives_value + + def build(self, input_shape): + super(DotProduct, self).build(input_shape) + self.item_id_feature_name = self.schema.select_by_tag(Tags.ITEM_ID).first.name - def compute_output_shape(self, input_shape): - return input_shape[0], 1 + def call(self, 
inputs, features, targets, **kwargs): + outputs = inputs + + positive_scores = super(ContrastiveDotProduct, self).call(outputs, **kwargs) + positive_item_ids = features[self.item_id_feature_name] + + if isinstance(targets, tf.Tensor) and len(targets.shape) == len(outputs.shape) - 1: + outputs = tf.squeeze(outputs) + + return Prediction(outputs, targets) diff --git a/tests/unit/tf/predictions/test_dot_product.py b/tests/unit/tf/predictions/test_dot_product.py new file mode 100644 index 0000000000..e69de29bb2 From d180893e95bd63d938d174ef69e136390f782cee Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Wed, 20 Jul 2022 09:35:32 +0200 Subject: [PATCH 16/27] Making test_dot_product_prediction pass --- merlin/models/tf/metrics/topk.py | 5 + merlin/models/tf/models/base.py | 5 +- merlin/models/tf/predictions/base.py | 49 ++-- merlin/models/tf/predictions/dot_product.py | 210 ++++++++++++++++-- .../tf/predictions/sampling/__init__.py | 0 merlin/models/tf/predictions/sampling/base.py | 114 ++++++++++ .../tf/predictions/sampling/in_batch.py | 89 ++++++++ tests/unit/tf/predictions/test_dot_product.py | 14 ++ 8 files changed, 439 insertions(+), 47 deletions(-) create mode 100644 merlin/models/tf/predictions/sampling/__init__.py create mode 100644 merlin/models/tf/predictions/sampling/base.py create mode 100644 merlin/models/tf/predictions/sampling/in_batch.py diff --git a/merlin/models/tf/metrics/topk.py b/merlin/models/tf/metrics/topk.py index 6cca5af339..dbd3525425 100644 --- a/merlin/models/tf/metrics/topk.py +++ b/merlin/models/tf/metrics/topk.py @@ -301,30 +301,35 @@ def from_config(cls, config): return super(TopkMetric, cls).from_config(config) +@tf.keras.utils.register_keras_serializable(package="merlin.models") @metrics_registry.register_with_multiple_names("recall_at", "recall") class RecallAt(TopkMetric): def __init__(self, k=10, pre_sorted=False, name="recall_at"): super().__init__(recall_at, k=k, pre_sorted=pre_sorted, name=name) 
+@tf.keras.utils.register_keras_serializable(package="merlin.models") @metrics_registry.register_with_multiple_names("precision_at", "precision") class PrecisionAt(TopkMetric): def __init__(self, k=10, pre_sorted=False, name="precision_at"): super().__init__(precision_at, k=k, pre_sorted=pre_sorted, name=name) +@tf.keras.utils.register_keras_serializable(package="merlin.models") @metrics_registry.register_with_multiple_names("map_at", "map") class AvgPrecisionAt(TopkMetric): def __init__(self, k=10, pre_sorted=False, name="map_at"): super().__init__(average_precision_at, k=k, pre_sorted=pre_sorted, name=name) +@tf.keras.utils.register_keras_serializable(package="merlin.models") @metrics_registry.register_with_multiple_names("mrr_at", "mrr") class MRRAt(TopkMetric): def __init__(self, k=10, pre_sorted=False, name="mrr_at"): super().__init__(mrr_at, k=k, pre_sorted=pre_sorted, name=name) +@tf.keras.utils.register_keras_serializable(package="merlin.models") @metrics_registry.register_with_multiple_names("ndcg_at", "ndcg") class NDCGAt(TopkMetric): def __init__(self, k=10, pre_sorted=False, name="ndcg_at"): diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 4743be75ac..cf5355aac8 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -550,8 +550,9 @@ def compute_metrics( # Providing label_relevant_counts for TopkMetrics, as metric.update_state() # should have standard signature for better compatibility with Keras methods # like self.compiled_metrics.update_state() - for topk_metric in filter_topk_metrics(self.compiled_metrics.metrics): - topk_metric.label_relevant_counts = prediction_outputs.label_relevant_counts + if hasattr(prediction_outputs, "label_relevant_counts"): + for topk_metric in filter_topk_metrics(self.compiled_metrics.metrics): + topk_metric.label_relevant_counts = prediction_outputs.label_relevant_counts self.compiled_metrics.update_state( prediction_outputs.targets, 
prediction_outputs.predictions, sample_weight diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py index 7fbe2e618e..ed2e982614 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -8,6 +8,7 @@ from merlin.models.tf.core.prediction import Prediction from merlin.models.tf.core.transformations import LogitsTemperatureScaler from merlin.models.tf.utils import tf_utils +from merlin.models.tf.utils.tf_utils import call_layer @tf.keras.utils.register_keras_serializable(package="merlin.models") @@ -156,7 +157,10 @@ def from_config(cls, config): return super().from_config(config) +@tf.keras.utils.register_keras_serializable(package="merlin.models") class ContrastivePredictionBlock(PredictionBlock): + """A prediction block that uses contrastive loss.""" + def __init__( self, prediction: Layer, @@ -168,8 +172,6 @@ def __init__( pre: Optional[Layer] = None, post: Optional[Layer] = None, logits_temperature: float = 1.0, - negative_sampling=None, - downscore_false_negatives=False, **kwargs, ): super(ContrastivePredictionBlock, self).__init__( @@ -188,31 +190,30 @@ def __init__( def call(self, inputs, training=False, testing=False, **kwargs): to_call = self.prediction - if self.has_negative_samplers and (training or testing): + if self.prediction_with_negatives.has_negative_samplers and (training or testing): to_call = self.prediction_with_negatives - return to_call(inputs, training=training, testing=testing, **kwargs) - - @property - def has_negative_samplers(self) -> bool: - return self.negative_sampling is not None and len(self.negative_sampling) > 0 + return call_layer(to_call, inputs, training=training, testing=testing, **kwargs) - def compile(self, negative_sampling=None, downscore_false_negatives=False): - self.prediction_with_negatives.negative_sampling = negative_sampling - self.prediction_with_negatives.downscore_false_negatives = downscore_false_negatives - - @property - def 
negative_sampling(self): - return self.prediction_with_negatives.negative_sampling + def get_config(self): + config = super(ContrastivePredictionBlock, self).get_config() + config.update( + { + "prediction_with_negatives": tf.keras.utils.serialize_keras_object( + self.prediction_with_negatives + ), + } + ) - @negative_sampling.setter - def negative_sampling(self, value): - self.prediction_with_negatives.negative_sampling = value + return config - @property - def downscore_false_negatives(self): - return self.prediction_with_negatives.downscore_false_negatives + @classmethod + def from_config(cls, config): + config = tf_utils.maybe_deserialize_keras_objects( + config, + { + "prediction_with_negatives": tf.keras.layers.deserialize, + }, + ) - @downscore_false_negatives.setter - def downscore_false_negatives(self, value): - self.prediction_with_negatives.downscore_false_negatives = value + return super().from_config(config) diff --git a/merlin/models/tf/predictions/dot_product.py b/merlin/models/tf/predictions/dot_product.py index 341171b9a6..ff1912c3b0 100644 --- a/merlin/models/tf/predictions/dot_product.py +++ b/merlin/models/tf/predictions/dot_product.py @@ -1,4 +1,5 @@ -from typing import Optional, Sequence, Union +import logging +from typing import List, Optional, Sequence, Union import tensorflow as tf from tensorflow.keras.layers import Layer @@ -6,26 +7,35 @@ from merlin.models.tf.core.prediction import Prediction from merlin.models.tf.metrics.topk import AvgPrecisionAt, MRRAt, NDCGAt, PrecisionAt, RecallAt from merlin.models.tf.predictions.base import ContrastivePredictionBlock - -# Or: RetrievalCategoricalPrediction +from merlin.models.tf.predictions.sampling.base import Items, ItemSampler, ItemSamplersType +from merlin.models.tf.typing import TabularData +from merlin.models.tf.utils import tf_utils +from merlin.models.tf.utils.tf_utils import call_layer, rescore_false_negatives +from merlin.models.utils import schema_utils from 
merlin.models.utils.constants import MIN_FLOAT -from merlin.schema import Tags +from merlin.schema import Schema, Tags + +LOG = logging.getLogger("merlin_models") @tf.keras.utils.register_keras_serializable(package="merlin_models") +# Or: RetrievalCategoricalPrediction class DotProductCategoricalPrediction(ContrastivePredictionBlock): + """Contrastive prediction using negative-sampling, used in retrieval models.""" + DEFAULT_K = 10 def __init__( self, - negative_sampling="in-batch", + schema: Schema, + negative_samplers: ItemSamplersType = "in-batch", downscore_false_negatives=False, target: Optional[str] = None, pre: Optional[Layer] = None, post: Optional[Layer] = None, logits_temperature: float = 1.0, name: Optional[str] = None, - default_loss: Union[str, tf.keras.losses.Loss] = "categorical-cross-entropy", + default_loss: Union[str, tf.keras.losses.Loss] = "categorical_crossentropy", default_metrics: Sequence[tf.keras.metrics.Metric] = ( RecallAt(DEFAULT_K), MRRAt(DEFAULT_K), @@ -37,9 +47,21 @@ def __init__( item_name: str = "item", **kwargs, ): + prediction = kwargs.pop("prediction", DotProduct(query_name, item_name)) + prediction_with_negatives = kwargs.pop( + "prediction_with_negatives", + ContrastiveDotProduct( + schema, + negative_samplers, + downscore_false_negatives, + query_name=query_name, + item_name=item_name, + ), + ) + super().__init__( - prediction=DotProduct(query_name, item_name), - prediction_with_negatives=ContrastiveDotProduct(query_name, item_name), + prediction=prediction, + prediction_with_negatives=prediction_with_negatives, default_loss=default_loss, default_metrics=default_metrics, name=name, @@ -47,14 +69,45 @@ def __init__( pre=pre, post=post, logits_temperature=logits_temperature, - negative_sampling=negative_sampling, - downscore_false_negatives=downscore_false_negatives, **kwargs, ) + def compile(self, negative_sampling=None, downscore_false_negatives=False): + self.prediction_with_negatives.negative_sampling = negative_sampling + 
self.prediction_with_negatives.downscore_false_negatives = downscore_false_negatives + + # TODO + def add_sampler(self, sampler): + self.prediction_with_negatives.negative_samplers.append(sampler) + + return self + + @property + def negative_samplers(self): + return self.prediction_with_negatives.negative_samplers + + @negative_samplers.setter + def negative_samplers(self, value): + self.prediction_with_negatives.negative_samplers = value + + @property + def downscore_false_negatives(self): + return self.prediction_with_negatives.downscore_false_negatives + + @downscore_false_negatives.setter + def downscore_false_negatives(self, value): + self.prediction_with_negatives.downscore_false_negatives = value + + def get_config(self): + config = super().get_config() + config["schema"] = config["prediction_with_negatives"]["config"]["schema"] + return config + @tf.keras.utils.register_keras_serializable(package="merlin_models") class DotProduct(Layer): + """Dot-product between queries & items.""" + def __init__(self, query_name: str = "query", item_name: str = "item", **kwargs): super().__init__(**kwargs) self.query_name = query_name @@ -66,7 +119,9 @@ def call(self, inputs, **kwargs): ) def compute_output_shape(self, input_shape): - return input_shape[0], 1 + batch_size = tf_utils.calculate_batch_size_from_input_shapes(input_shape) + + return batch_size, 1 def get_config(self): return { @@ -78,31 +133,144 @@ def get_config(self): @tf.keras.utils.register_keras_serializable(package="merlin_models") class ContrastiveDotProduct(DotProduct): + """Contrastive dot-product between queries & items.""" + def __init__( self, - negative_sampling="in-batch", - downscore_false_negatives=False, - sampling_downscore_false_negatives_value: float = MIN_FLOAT, + schema: Schema, + negative_samplers: ItemSamplersType = "in-batch", + downscore_false_negatives=True, + false_negative_score: float = MIN_FLOAT, query_name: str = "query", item_name: str = "item", + item_id_tag: Tags = 
Tags.ITEM_ID, + query_id_tag: Tags = Tags.USER_ID, **kwargs, ): super().__init__(query_name, item_name, **kwargs) - self.negative_sampling = negative_sampling + if not isinstance(negative_samplers, (list, tuple)): + negative_samplers = [negative_samplers] + self.negative_samplers = [ItemSampler.parse(s) for s in list(negative_samplers)] self.downscore_false_negatives = downscore_false_negatives - self.sampling_downscore_false_negatives_value = sampling_downscore_false_negatives_value + self.false_negative_score = false_negative_score + self.item_id_tag = item_id_tag + self.query_id_tag = query_id_tag + self.schema = schema def build(self, input_shape): super(DotProduct, self).build(input_shape) - self.item_id_feature_name = self.schema.select_by_tag(Tags.ITEM_ID).first.name + self.item_id_name = self.schema.select_by_tag(self.item_id_tag).first.name + self.query_id_name = self.schema.select_by_tag(self.query_id_tag).first.name + + def call(self, inputs, features, targets=None, training=False, testing=False): + query_id, query_emb = self.get_id_and_embedding( + self.query_name, self.query_id_name, inputs, features + ) + pos_item_id, pos_item_emb = self.get_id_and_embedding( + self.item_name, self.item_id_name, inputs, features + ) + neg_items = self.sample_negatives( + Items(pos_item_id, {}).with_embedding(pos_item_emb), + features, + training=training, + testing=testing, + ) + + # Apply dot-product to positive item and negative items + positive_scores = super(ContrastiveDotProduct, self).call(inputs) + negative_scores = tf.linalg.matmul(query_emb, neg_items.embedding(), transpose_b=True) + + if self.downscore_false_negatives: + negative_scores, _ = rescore_false_negatives( + pos_item_id, neg_items.id, negative_scores, self.false_negative_score + ) - def call(self, inputs, features, targets, **kwargs): - outputs = inputs + outputs = tf.concat([positive_scores, negative_scores], axis=-1) - positive_scores = super(ContrastiveDotProduct, self).call(outputs, **kwargs) - 
positive_item_ids = features[self.item_id_feature_name] + # To ensure that the output is always fp32, avoiding numerical + # instabilities with mixed_float16 policy + outputs = tf.cast(outputs, tf.float32) + + targets = tf.concat( + [ + tf.ones([tf.shape(outputs)[0], 1], dtype=outputs.dtype), + tf.zeros( + [tf.shape(outputs)[0], tf.shape(outputs)[1] - 1], + dtype=outputs.dtype, + ), + ], + axis=1, + ) if isinstance(targets, tf.Tensor) and len(targets.shape) == len(outputs.shape) - 1: outputs = tf.squeeze(outputs) return Prediction(outputs, targets) + + def get_id_and_embedding( + self, + key: str, + feature_name: str, + inputs: TabularData, + features: TabularData, + ): + embedding = inputs[key] + if f"{key}_id" in inputs: + ids = inputs[f"{key}_id"] + else: + ids = features[feature_name] + + return ids, embedding + + def sample_negatives( + self, + positive_items: Items, + features: TabularData, + training=False, + testing=False, + ) -> Items: + negative_items: List[Items] = [] + sampling_kwargs = {"training": training, "testing": testing, "features": features} + + # Adds items from the current batch into samplers and sample a number of negatives + for sampler in self.negative_samplers: + sampler_items: Items = call_layer(sampler, positive_items, **sampling_kwargs) + + if tf.shape(sampler_items.id)[0] > 0: + negative_items.append(sampler_items) + else: + LOG.warn( + f"The sampler {type(sampler).__name__} returned no samples for this batch." 
+ ) + + if len(negative_items) == 0: + raise Exception(f"No negative items where sampled from samplers {self.samplers}") + + negatives = sum(negative_items) if len(negative_items) > 1 else negative_items[0] + + return negatives + + @property + def has_negative_samplers(self) -> bool: + return self.negative_samplers is not None and len(self.negative_samplers) > 0 + + def get_config(self): + config = tf_utils.maybe_serialize_keras_objects( + self, + { + **super().get_config(), + "downscore_false_negatives": self.downscore_false_negatives, + "false_negative_score": self.false_negative_score, + }, + ["negative_samplers"], + ) + config["schema"] = schema_utils.schema_to_tensorflow_metadata_json(self.schema) + + return config + + @classmethod + def from_config(cls, config): + config = tf_utils.maybe_deserialize_keras_objects(config, ["negative_samplers"]) + config["schema"] = schema_utils.tensorflow_metadata_json_to_schema(config["schema"]) + + return super().from_config(config) diff --git a/merlin/models/tf/predictions/sampling/__init__.py b/merlin/models/tf/predictions/sampling/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/merlin/models/tf/predictions/sampling/base.py b/merlin/models/tf/predictions/sampling/base.py new file mode 100644 index 0000000000..38bb7179a7 --- /dev/null +++ b/merlin/models/tf/predictions/sampling/base.py @@ -0,0 +1,114 @@ +import abc +from typing import Dict, List, NamedTuple, Optional, Sequence, Union + +import tensorflow as tf + +from merlin.models.utils.registry import Registry, RegistryMixin + +ITEM_EMBEDDING_KEY = "__item_embedding__" + + +class Items(NamedTuple): + id: tf.Tensor + metadata: Dict[str, tf.Tensor] + + def embedding(self) -> tf.Tensor: + return self.metadata[ITEM_EMBEDDING_KEY] + + @property + def has_embedding(self) -> bool: + return ITEM_EMBEDDING_KEY in self.metadata + + def with_embedding(self, embedding: tf.Tensor) -> "Items": + self.metadata[ITEM_EMBEDDING_KEY] = embedding + + return self + 
+ def __add__(self, other): + return Items( + id=_list_to_tensor([self.id, other.ids]), + metadata={ + key: _list_to_tensor([self.metadata[key], other.metadata[key]]) + for key, val in self.metadata.items() + }, + ) + + @property + def shape(self) -> "Items": + return Items(self.id.shape, {key: val.shape for key, val in self.metadata.items()}) + + def __repr__(self): + metadata = {key: str(val) for key, val in self.metadata.items()} + + return f"Items({self.id}, {metadata})" + + def __str__(self): + metadata = {key: str(val) for key, val in self.metadata.items()} + + return f"Items({self.id}, {metadata})" + + def get_config(self): + return { + "ids": self.id.as_list() if self.id else None, + "metadata": {key: val.as_list() for key, val in self.metadata.items()}, + } + + @classmethod + def from_config(cls, config): + ids = tf.TensorShape(config["config"]["id"]) + metadata = {key: tf.TensorShape(val) for key, val in config["config"]["metadata"].items()} + + return cls(ids, metadata) + + +negative_sampling_registry: Registry = Registry.class_registry("tf.negative_sampling") + + +class ItemSampler(tf.keras.layers.Layer, RegistryMixin["ItemSampler"], abc.ABC): + registry = negative_sampling_registry + + def __init__( + self, + max_num_samples: Optional[int] = None, + **kwargs, + ): + super(ItemSampler, self).__init__(**kwargs) + self.set_max_num_samples(max_num_samples) + + def call( + self, items: Items, features=None, targets=None, training=False, testing=False + ) -> Items: + if training: + self.add(items) + items = self.sample() + + return items + + @abc.abstractmethod + def add(self, items: Items): + raise NotImplementedError() + + @abc.abstractmethod + def sample(self) -> Items: + raise NotImplementedError() + + @property + def max_num_samples(self) -> int: + return self._max_num_samples + + def set_max_num_samples(self, value) -> None: + self._max_num_samples = value + + +def _list_to_tensor(input_list: List[tf.Tensor]) -> tf.Tensor: + output: tf.Tensor + + if 
len(input_list) == 1: + output = input_list[0] + else: + output = tf.concat(input_list, axis=0) + + return output + + +ItemSamplersType = Union[ItemSampler, Sequence[Union[ItemSampler, str]], str] diff --git a/merlin/models/tf/predictions/sampling/in_batch.py b/merlin/models/tf/predictions/sampling/in_batch.py new file mode 100644 index 0000000000..d0da08b29a --- /dev/null +++ b/merlin/models/tf/predictions/sampling/in_batch.py @@ -0,0 +1,89 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional + +import tensorflow as tf + +from merlin.models.tf.predictions.sampling.base import Items, ItemSampler + + +@ItemSampler.registry.register("in-batch") +@tf.keras.utils.register_keras_serializable(package="merlin.models") +class InBatchSampler(ItemSampler): + """Provides in-batch sampling [1]_ for two-tower item retrieval + models. The implementation is very simple, as it + just returns the current item embeddings and metadata, but it is necessary to have + `InBatchSampler` under the same interface of other more advanced samplers + (e.g. `CachedCrossBatchSampler`). + In a nutshell, for a given (user,item) embeddings pair, the other in-batch item + embeddings are used as negative items, rather than computing different embeddings + exclusively for negative items. + This is a popularity-biased sampling as popular items are observed more often + in training batches. + P.s. 
Ignoring the false negatives (negative items equal to the positive ones) is + managed by `ItemRetrievalScorer(..., sampling_downscore_false_negatives=True)` + References + ---------- + .. [1] Yi, Xinyang, et al. "Sampling-bias-corrected neural modeling for large corpus item + recommendations." Proceedings of the 13th ACM Conference on Recommender Systems. 2019. + Parameters + ---------- + batch_size : int, optional + The batch size. If not set it is inferred when the layer is built (first call()) + """ + + def __init__(self, batch_size: Optional[int] = None, **kwargs): + super().__init__(max_num_samples=batch_size, **kwargs) + self._last_batch: Optional[Items] = None # type: ignore + self.set_batch_size(batch_size) + + @property + def batch_size(self) -> int: + return self._batch_size + + def set_batch_size(self, value): + self._batch_size = value + if value is not None: + self.set_max_num_samples(value) + + def build(self, items: Items) -> None: + if isinstance(items, dict): + items = Items.from_config(items) + if self._batch_size is None: + self.set_batch_size(items.id[0]) + + def add(self, items: Items): + self._last_batch = items + + def call( + self, items: Items, features=None, targets=None, training=False, testing=False + ) -> Items: + self.add(items) + items = self.sample() + + return items + + def sample(self) -> Items: + return self._last_batch + + def get_config(self): + config = super().get_config() + config["batch_size"] = self._batch_size + + # TODO: This is a side-effect, could this lead to problems? 
+ self._last_batch = None + + return config diff --git a/tests/unit/tf/predictions/test_dot_product.py b/tests/unit/tf/predictions/test_dot_product.py index e69de29bb2..efdd97cda6 100644 --- a/tests/unit/tf/predictions/test_dot_product.py +++ b/tests/unit/tf/predictions/test_dot_product.py @@ -0,0 +1,14 @@ +import merlin.models.tf as mm +from merlin.io import Dataset +from merlin.models.tf.predictions.dot_product import DotProductCategoricalPrediction +from merlin.models.tf.predictions.sampling.in_batch import InBatchSampler +from merlin.models.tf.utils import testing_utils + + +def test_dot_product_prediction(ecommerce_data: Dataset): + model = mm.Model( + mm.TwoTowerBlock(ecommerce_data.schema, query_tower=mm.MLPBlock([8])), + DotProductCategoricalPrediction(ecommerce_data.schema, negative_samplers=InBatchSampler()), + ) + + _, history = testing_utils.model_test(model, ecommerce_data) From 8758b33dc2c8f7a96137c8c015b5661748e42e0a Mon Sep 17 00:00:00 2001 From: Marc Romeyn Date: Fri, 22 Jul 2022 10:45:17 +0200 Subject: [PATCH 17/27] Adding TODO --- merlin/models/tf/models/base.py | 2 ++ tests/unit/tf/predictions/test_dot_product.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index cf5355aac8..4da638f252 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -345,6 +345,8 @@ def _create_metrics(self, metrics=None): for task_name, task in self.predictions_by_name().items(): out[task_name] = task.create_default_metrics() + # TODO: Check for top-k metrics & wrap them in TopKMetricsAggregator + return out def _create_loss(self, loss=None): diff --git a/tests/unit/tf/predictions/test_dot_product.py b/tests/unit/tf/predictions/test_dot_product.py index efdd97cda6..b5c82f81b9 100644 --- a/tests/unit/tf/predictions/test_dot_product.py +++ b/tests/unit/tf/predictions/test_dot_product.py @@ -6,7 +6,7 @@ def test_dot_product_prediction(ecommerce_data: 
Dataset): - model = mm.Model( + model = mm.RetrievalModel( mm.TwoTowerBlock(ecommerce_data.schema, query_tower=mm.MLPBlock([8])), DotProductCategoricalPrediction(ecommerce_data.schema, negative_samplers=InBatchSampler()), ) From b31ab0c2f60de2143080d09ce064e2b0a2bafb46 Mon Sep 17 00:00:00 2001 From: sararb Date: Fri, 29 Jul 2022 14:37:30 -0400 Subject: [PATCH 18/27] unify top-k metrics in one topkaggregator instanc --- merlin/models/tf/metrics/topk.py | 41 ++++++++++++++++++++++++++++++++ merlin/models/tf/models/base.py | 10 +++++--- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/merlin/models/tf/metrics/topk.py b/merlin/models/tf/metrics/topk.py index dbd3525425..4db12807fb 100644 --- a/merlin/models/tf/metrics/topk.py +++ b/merlin/models/tf/metrics/topk.py @@ -414,6 +414,20 @@ def default_metrics(cls, top_ks: Sequence[int], **kwargs) -> Sequence[TopkMetric aggregator = cls(*metrics) return [aggregator] + def get_config(self): + config = {} + for i, metric in enumerate(self.topk_metrics): + config[i] = tf.keras.utils.serialize_keras_object(metric) + return config + + @classmethod + def from_config(cls, config, custom_objects=None): + metrics = [ + tf.keras.layers.deserialize(conf, custom_objects=custom_objects) + for conf in config.values() + ] + return TopKMetricsAggregator(*metrics) + def filter_topk_metrics( metrics: Sequence[Metric], @@ -438,3 +452,30 @@ def filter_topk_metrics( ] ) return topk_metrics + + +def split_metrics( + metrics: Sequence[Metric], + return_other_metrics: bool = False, +) -> List[TopkMetric, TopKMetricsAggregator, Metric]: + """Split the list of metrics into top-k metrics, top-k aggregators and others + + Parameters + ---------- + metrics : List[Metric] + List of metrics + + Returns + ------- + List[TopkMetric, TopKMetricsAggregator, Metric] + List with the top-k metrics in the list of input metrics + """ + topk_metrics, topk_aggregators, other_metrics = [], [], [] + for metric in metrics: + if isinstance(metric, 
TopkMetric): + topk_metrics.append(metric) + elif isinstance(metric, TopKMetricsAggregator): + topk_aggregators.append(metric) + else: + other_metrics.append(metric) + return topk_metrics, topk_aggregators, other_metrics diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 4da638f252..122c3502cf 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -20,7 +20,7 @@ from merlin.models.tf.dataset import BatchedDataset from merlin.models.tf.inputs.base import InputBlock from merlin.models.tf.losses.base import loss_registry -from merlin.models.tf.metrics.topk import filter_topk_metrics +from merlin.models.tf.metrics.topk import TopKMetricsAggregator, filter_topk_metrics, split_metrics from merlin.models.tf.models.utils import parse_prediction_tasks from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask from merlin.models.tf.predictions.base import PredictionBlock @@ -325,6 +325,12 @@ def _create_metrics(self, metrics=None): num_v1_blocks = len(self.prediction_tasks) if isinstance(metrics, (list, tuple)): + # Retrieve top-k metrics & wrap them in TopKMetricsAggregator + topk_metrics, topk_aggregators, other_metrics = split_metrics(metrics) + if len(topk_metrics) > 0: + topk_aggregators.append(TopKMetricsAggregator(*topk_metrics)) + metrics = other_metrics + topk_aggregators + if num_v1_blocks > 0: if num_v1_blocks == 1: out[self.prediction_tasks[0].task_name] = metrics @@ -345,8 +351,6 @@ def _create_metrics(self, metrics=None): for task_name, task in self.predictions_by_name().items(): out[task_name] = task.create_default_metrics() - # TODO: Check for top-k metrics & wrap them in TopKMetricsAggregator - return out def _create_loss(self, loss=None): From d86081fcb23ea7f0a0cf520f1550d98f83dc3173 Mon Sep 17 00:00:00 2001 From: sararb Date: Tue, 2 Aug 2022 17:14:19 -0400 Subject: [PATCH 19/27] fix split_metrics --- merlin/models/tf/metrics/topk.py | 5 +++-- 1 file changed, 3 
insertions(+), 2 deletions(-) diff --git a/merlin/models/tf/metrics/topk.py b/merlin/models/tf/metrics/topk.py index 4db12807fb..acf7a5ad26 100644 --- a/merlin/models/tf/metrics/topk.py +++ b/merlin/models/tf/metrics/topk.py @@ -15,7 +15,7 @@ # # Adapted from source code: https://github.com/karlhigley/ranking-metrics-torch -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union import tensorflow as tf from keras.utils import losses_utils, metrics_utils @@ -336,6 +336,7 @@ def __init__(self, k=10, pre_sorted=False, name="ndcg_at"): super().__init__(ndcg_at, k=k, pre_sorted=pre_sorted, name=name) +@tf.keras.utils.register_keras_serializable(package="merlin.models") class TopKMetricsAggregator(Metric, TopkMetricWithLabelRelevantCountsMixin): """Aggregator for top-k metrics (TopkMetric) that is optimized to sort top-k predictions only once for all metrics. @@ -457,7 +458,7 @@ def filter_topk_metrics( def split_metrics( metrics: Sequence[Metric], return_other_metrics: bool = False, -) -> List[TopkMetric, TopKMetricsAggregator, Metric]: +) -> Tuple[TopkMetric, TopKMetricsAggregator, Metric]: """Split the list of metrics into top-k metrics, top-k aggregators and others Parameters From 162fa4890511e6799dca54dc939b3108594b19be Mon Sep 17 00:00:00 2001 From: sararb Date: Tue, 2 Aug 2022 17:16:21 -0400 Subject: [PATCH 20/27] add negative_sampling to compile method --- merlin/models/tf/models/base.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index 122c3502cf..07644aa6d9 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -23,7 +23,10 @@ from merlin.models.tf.metrics.topk import TopKMetricsAggregator, filter_topk_metrics, split_metrics from merlin.models.tf.models.utils import parse_prediction_tasks from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask 
-from merlin.models.tf.predictions.base import PredictionBlock +from merlin.models.tf.predictions.base import ContrastivePredictionBlock, PredictionBlock + +# from merlin.models.tf.predictions.dot_product import DotProductCategoricalPrediction +from merlin.models.tf.typing import TabularData from merlin.models.tf.utils.search_utils import find_all_instances_in_layers from merlin.models.tf.utils.tf_utils import call_layer, maybe_serialize_keras_objects from merlin.models.utils.dataset import unique_rows_by_features @@ -302,6 +305,14 @@ def compile( self.output_names = [task.task_name for task in self.prediction_tasks] else: self.output_names = [block.full_name for block in self.prediction_blocks] + negative_sampling = kwargs.pop("negative_sampling", None) + if negative_sampling: + if not isinstance(self.prediction_blocks[0], ContrastivePredictionBlock): + raise ValueError( + "Negative sampling strategy can be used only with a" + " `ContrastivePredictionBlock` prediction block" + ) + self.prediction_blocks[0].compile(negative_sampling=negative_sampling) # This flag will make Keras change the metric-names which is not needed in v2 from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0) From b064fd3c2164003f1c36e8984a6b0327d39d9561 Mon Sep 17 00:00:00 2001 From: sararb Date: Tue, 2 Aug 2022 18:26:59 -0400 Subject: [PATCH 21/27] update contrastive prediction block --- merlin/models/tf/predictions/dot_product.py | 101 ++++++++++++++- tests/unit/tf/predictions/test_dot_product.py | 120 +++++++++++++++++- 2 files changed, 217 insertions(+), 4 deletions(-) diff --git a/merlin/models/tf/predictions/dot_product.py b/merlin/models/tf/predictions/dot_product.py index ff1912c3b0..dd00033131 100644 --- a/merlin/models/tf/predictions/dot_product.py +++ b/merlin/models/tf/predictions/dot_product.py @@ -21,7 +21,46 @@ @tf.keras.utils.register_keras_serializable(package="merlin_models") # Or: RetrievalCategoricalPrediction class 
DotProductCategoricalPrediction(ContrastivePredictionBlock): - """Contrastive prediction using negative-sampling, used in retrieval models.""" + """Contrastive prediction using negative-sampling, + used in retrieval models. + + Parameters + ---------- + schema : Schema + The schema object including features to use and their properties. + This Schema object will be automatically generated using + [NVTabular](https://nvidia-merlin.github.io/NVTabular/main/Introduction.html). + Next to this, it's also possible to construct it manually + negative_samplers : ItemSamplersType, optional + List of samplers for negative sampling, + by default "in-batch" + downscore_false_negatives : bool, optional + Identify false negatives (sampled item ids equal to the positive item) and downscore them + to the `sampling_downscore_false_negatives_value`, + by default False + target : Optional[str], optional + If specified, name of the target tensor to retrieve from dataloader, + by default None + pre: Optional[Block], optional + Optional block to transform predictions before applying the prediction layer, + by default None + post: Optional[Block], optional + Optional block to transform predictions after applying the prediction layer, + by default None + logits_temperature: float, optional + Parameter used to reduce model overconfidence, so that logits / T. + by default 1.
+ name: Optional[Text], optional + Task name, by default None + default_loss: Union[str, tf.keras.losses.Loss] + Default loss to set if the user does not specify one + default_metrics: Sequence[tf.keras.metrics.Metric] + Default metrics to set if the user does not specify any + query_name : str, optional + Identify query tower for query/user embeddings, by default 'query' + item_name : str, optional + Identify item tower for item embeddings, by default 'item' + """ DEFAULT_K = 10 @@ -106,7 +145,14 @@ def get_config(self): @tf.keras.utils.register_keras_serializable(package="merlin_models") class DotProduct(Layer): - """Dot-product between queries & items.""" + """Dot-product between queries & items. + Parameters + ---------- + query_name : str, optional + Identify query tower for query/user embeddings, by default 'query' + item_name : str, optional + Identify item tower for item embeddings, by default 'item' + """ def __init__(self, query_name: str = "query", item_name: str = "item", **kwargs): super().__init__(**kwargs) @@ -133,7 +179,34 @@ def get_config(self): @tf.keras.utils.register_keras_serializable(package="merlin_models") class ContrastiveDotProduct(DotProduct): - """Contrastive dot-product between queries & items.""" + """Contrastive dot-product between queries & items. + Parameters + ---------- + schema : Schema + The schema object including features to use and their properties. + This Schema object will be automatically generated using + [NVTabular](https://nvidia-merlin.github.io/NVTabular/main/Introduction.html).
+ Next to this, it's also possible to construct it manually + negative_samplers : ItemSamplersType, optional + List of samplers for negative sampling, + by default "in-batch" + downscore_false_negatives : bool, optional + Identify false negatives (sampled item ids equal to the positive item) and downscore them + to `false_negative_score`, + by default False + false_negative_score : float, optional + Value to be used to downscore false negatives when + `downscore_false_negatives=True`, + by default `np.finfo(np.float32).min / 100.0` + query_name : str, optional + Identify query tower for query/user embeddings, by default 'query' + item_name : str, optional + Identify item tower for item embeddings, by default 'item' + item_id_tag : Tags, optional + The tag to select the item id feature, by default `Tags.ITEM_ID` + query_id_tag : Tags, optional + The tag to select the user id feature, by default `Tags.USER_ID` + """ def __init__( self, @@ -151,6 +224,10 @@ def __init__( if not isinstance(negative_samplers, (list, tuple)): negative_samplers = [negative_samplers] self.negative_samplers = [ItemSampler.parse(s) for s in list(negative_samplers)] + assert ( + len(self.negative_samplers) > 0 + ), "At least one sampler is required by ContrastiveDotProduct for negative sampling" + self.downscore_false_negatives = downscore_false_negatives self.false_negative_score = false_negative_score self.item_id_tag = item_id_tag @@ -229,6 +306,24 @@ def sample_negatives( training=False, testing=False, ) -> Items: + """Method to sample negatives from `self.negative_samplers` + + Parameters + ---------- + positive_items : Items + Class containing embeddings and metadata about positive items + features : TabularData + Dictionary of input raw tensors + training : bool, optional + Flag for train mode, by default False + testing : bool, optional + Flag for test mode, by default False + + Returns + ------- + Items + Class containing embeddings and
metadata about sampled negative items + """ negative_items: List[Items] = [] sampling_kwargs = {"training": training, "testing": testing, "features": features} diff --git a/tests/unit/tf/predictions/test_dot_product.py b/tests/unit/tf/predictions/test_dot_product.py index b5c82f81b9..14c6ff49e2 100644 --- a/tests/unit/tf/predictions/test_dot_product.py +++ b/tests/unit/tf/predictions/test_dot_product.py @@ -1,6 +1,12 @@ +import pytest +import tensorflow as tf + import merlin.models.tf as mm from merlin.io import Dataset -from merlin.models.tf.predictions.dot_product import DotProductCategoricalPrediction +from merlin.models.tf.predictions.dot_product import ( + ContrastiveDotProduct, + DotProductCategoricalPrediction, +) from merlin.models.tf.predictions.sampling.in_batch import InBatchSampler from merlin.models.tf.utils import testing_utils @@ -12,3 +18,115 @@ def test_dot_product_prediction(ecommerce_data: Dataset): ) _, history = testing_utils.model_test(model, ecommerce_data) + + +@testing_utils.mark_run_eagerly_modes +def test_setting_negative_sampling_strategy(ecommerce_data: Dataset, run_eagerly: bool): + model = mm.RetrievalModel( + mm.TwoTowerBlock(ecommerce_data.schema, query_tower=mm.MLPBlock([8])), + DotProductCategoricalPrediction(ecommerce_data.schema), + ) + model.compile(run_eagerly=run_eagerly, optimizer="adam", negative_sampling="in-batch") + _ = model.fit(ecommerce_data, batch_size=50, epochs=5, steps_per_epoch=1) + + +def test_add_sampler(ecommerce_data: Dataset): + model = mm.RetrievalModel( + mm.TwoTowerBlock(ecommerce_data.schema, query_tower=mm.MLPBlock([8])), + DotProductCategoricalPrediction(ecommerce_data.schema, negative_samplers=InBatchSampler()), + ) + assert len(model.prediction_blocks[0].prediction_with_negatives.negative_samplers) == 1 + model.prediction_blocks[0].add_sampler(mm.InBatchSampler()) + assert len(model.prediction_blocks[0].prediction_with_negatives.negative_samplers) == 2 + + +def 
test_contrastive_dot_product(ecommerce_data: Dataset): + batch_size = 10 + inbatch_sampler = InBatchSampler() + + retrieval_scorer = ContrastiveDotProduct( + schema=ecommerce_data.schema, + negative_samplers=[inbatch_sampler], + downscore_false_negatives=False, + ) + inputs, features = _retrieval_inputs_(batch_size=batch_size) + output = retrieval_scorer(inputs, features=features) + + expected_num_samples_inbatch = batch_size + 1 + tf.assert_equal(tf.shape(output.predictions)[0], batch_size) + # Number of negatives plus one positive + tf.assert_equal(tf.shape(output.predictions)[1], expected_num_samples_inbatch) + + +def test_item_retrieval_scorer_no_sampler(ecommerce_data: Dataset): + with pytest.raises(Exception) as excinfo: + inputs, features = _retrieval_inputs_(batch_size=10) + retrieval_scorer = ContrastiveDotProduct( + schema=ecommerce_data.schema, negative_samplers=[], downscore_false_negatives=False + ) + _ = retrieval_scorer(inputs, features=features, training=True) + assert "At least one sampler is required by ContrastiveDotProduct for negative sampling" in str( + excinfo.value + ) + + +def test_item_retrieval_scorer_downscore_false_negatives(ecommerce_data: Dataset): + batch_size = 10 + + inbatch_sampler = InBatchSampler() + inputs, features = _retrieval_inputs_(batch_size=batch_size) + + FALSE_NEGATIVE_SCORE = -100_000_000.0 + item_retrieval_scorer = ContrastiveDotProduct( + schema=ecommerce_data.schema, + negative_samplers=[inbatch_sampler], + downscore_false_negatives=True, + false_negative_score=FALSE_NEGATIVE_SCORE, + ) + + outputs = item_retrieval_scorer( + inputs, + training=True, + features=features, + ) + output_scores = outputs.predictions + + output_neg_scores = output_scores[:, 1:] + + diag_mask = tf.eye(tf.shape(output_neg_scores)[0], dtype=tf.bool) + tf.assert_equal(output_neg_scores[diag_mask], FALSE_NEGATIVE_SCORE) + tf.assert_equal( + tf.reduce_all( + tf.not_equal( + output_neg_scores[tf.math.logical_not(diag_mask)], + 
tf.constant(FALSE_NEGATIVE_SCORE, dtype=output_neg_scores.dtype), + ) + ), + True, + ) + + +def test_retrieval_prediction_only_positive_when_not_training(ecommerce_data: Dataset): + batch_size = 10 + + inbatch_sampler = InBatchSampler() + item_retrieval_prediction = DotProductCategoricalPrediction( + schema=ecommerce_data.schema, + negative_samplers=[inbatch_sampler], + downscore_false_negatives=False, + ) + + inputs, features = _retrieval_inputs_(batch_size=batch_size) + output_scores = item_retrieval_prediction(inputs) + tf.assert_equal( + (int(tf.shape(output_scores)[0]), int(tf.shape(output_scores)[1])), (batch_size, 1) + ) + + +def _retrieval_inputs_(batch_size): + users_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32) + items_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32) + positive_items = tf.random.uniform(shape=(10,), minval=1, maxval=100, dtype=tf.int32) + inputs = {"query": users_embeddings, "item": items_embeddings, "item_id": positive_items} + features = {"product_id": positive_items, "user_id": None} + return inputs, features From e350b1811a8c93594d25a0348cd928121a61117f Mon Sep 17 00:00:00 2001 From: sararb Date: Wed, 3 Aug 2022 11:40:00 -0400 Subject: [PATCH 22/27] fix missing imports --- merlin/models/tf/predictions/base.py | 63 ++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py index 369a68686b..79f0f60caf 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -8,6 +8,7 @@ from merlin.models.tf.core.prediction import Prediction from merlin.models.tf.core.transformations import LogitsTemperatureScaler from merlin.models.tf.utils import tf_utils +from merlin.models.tf.utils.tf_utils import call_layer @tf.keras.utils.register_keras_serializable(package="merlin.models") @@ -179,3 +180,65 @@ def from_config(cls, config): ) return super().from_config(config) + + 
+@tf.keras.utils.register_keras_serializable(package="merlin.models") +class ContrastivePredictionBlock(PredictionBlock): + """A prediction block that uses contrastive loss.""" + + def __init__( + self, + prediction: Layer, + prediction_with_negatives: Layer, + default_loss: Union[str, tf.keras.losses.Loss], + default_metrics: Sequence[tf.keras.metrics.Metric], + name: Optional[str] = None, + target: Optional[str] = None, + pre: Optional[Layer] = None, + post: Optional[Layer] = None, + logits_temperature: float = 1.0, + **kwargs, + ): + super(ContrastivePredictionBlock, self).__init__( + prediction, + default_loss=default_loss, + default_metrics=default_metrics, + target=target, + pre=pre, + post=post, + logits_temperature=logits_temperature, + name=name, + **kwargs, + ) + self.prediction_with_negatives = prediction_with_negatives + + def call(self, inputs, training=False, testing=False, **kwargs): + to_call = self.prediction + + if self.prediction_with_negatives.has_negative_samplers and (training or testing): + to_call = self.prediction_with_negatives + + return call_layer(to_call, inputs, training=training, testing=testing, **kwargs) + + def get_config(self): + config = super(ContrastivePredictionBlock, self).get_config() + config.update( + { + "prediction_with_negatives": tf.keras.utils.serialize_keras_object( + self.prediction_with_negatives + ), + } + ) + + return config + + @classmethod + def from_config(cls, config): + config = tf_utils.maybe_deserialize_keras_objects( + config, + { + "prediction_with_negatives": tf.keras.layers.deserialize, + }, + ) + + return super().from_config(config) From e5c58bcb2826a00d86c04bfa97c8f976c182e534 Mon Sep 17 00:00:00 2001 From: sararb Date: Wed, 3 Aug 2022 13:57:09 -0400 Subject: [PATCH 23/27] fix failing test --- tests/unit/tf/predictions/test_dot_product.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/tf/predictions/test_dot_product.py b/tests/unit/tf/predictions/test_dot_product.py 
index 14c6ff49e2..33220da592 100644 --- a/tests/unit/tf/predictions/test_dot_product.py +++ b/tests/unit/tf/predictions/test_dot_product.py @@ -126,7 +126,9 @@ def test_retrieval_prediction_only_positive_when_not_training(ecommerce_data: Da def _retrieval_inputs_(batch_size): users_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32) items_embeddings = tf.random.uniform(shape=(batch_size, 5), dtype=tf.float32) - positive_items = tf.random.uniform(shape=(10,), minval=1, maxval=100, dtype=tf.int32) + positive_items = tf.random.uniform( + shape=(batch_size,), minval=1, maxval=1000000, dtype=tf.int32 + ) inputs = {"query": users_embeddings, "item": items_embeddings, "item_id": positive_items} features = {"product_id": positive_items, "user_id": None} return inputs, features From 2ce437d3a669dc3cc0916cd021ee8f5c170b1b6b Mon Sep 17 00:00:00 2001 From: sararb Date: Thu, 4 Aug 2022 05:42:33 -0400 Subject: [PATCH 24/27] add docstrings --- merlin/models/tf/predictions/base.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/merlin/models/tf/predictions/base.py b/merlin/models/tf/predictions/base.py index 79f0f60caf..0eebb78764 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -184,7 +184,32 @@ def from_config(cls, config): @tf.keras.utils.register_keras_serializable(package="merlin.models") class ContrastivePredictionBlock(PredictionBlock): - """A prediction block that uses contrastive loss.""" + """Base-class for prediction blocks that uses contrastive loss. 
+ + Parameters + ---------- + prediction : Layer + The prediction layer + prediction_with_negatives : Layer + The prediction layer that includes negative sampling + default_loss: Union[str, tf.keras.losses.Loss] + Default loss to set if the user does not specify one + default_metrics: Sequence[tf.keras.metrics.Metric] + Default metrics to set if the user does not specify any + name: Optional[Text], optional + Task name, by default None + target: Optional[str], optional + Label name, by default None + pre: Optional[Block], optional + Optional block to transform predictions before applying the prediction layer, + by default None + post: Optional[Block], optional + Optional block to transform predictions after applying the prediction layer, + by default None + logits_temperature: float, optional + Parameter used to reduce model overconfidence, so that logits / T. + by default 1. + """ def __init__( self, @@ -199,6 +224,7 @@ def __init__( logits_temperature: float = 1.0, **kwargs, ): + super(ContrastivePredictionBlock, self).__init__( prediction, default_loss=default_loss, From 015ac6ecc057d8162f6e626c6627a9c919a8b4df Mon Sep 17 00:00:00 2001 From: sararb Date: Thu, 4 Aug 2022 12:18:33 -0400 Subject: [PATCH 25/27] update names of new sampling classes --- merlin/models/tf/__init__.py | 2 + merlin/models/tf/predictions/dot_product.py | 4 +- merlin/models/tf/predictions/sampling/base.py | 30 +++++++++++-- .../tf/predictions/sampling/in_batch.py | 27 +++++++++-- tests/unit/tf/predictions/test_dot_product.py | 18 +++++--- tests/unit/tf/predictions/test_sampling.py | 45 +++++++++++++++++++ 6 files changed, 111 insertions(+), 15 deletions(-) create mode 100644 tests/unit/tf/predictions/test_sampling.py diff --git a/merlin/models/tf/__init__.py b/merlin/models/tf/__init__.py index 9733615937..fb86d5afb5 100644 --- a/merlin/models/tf/__init__.py +++ b/merlin/models/tf/__init__.py @@ -120,6 +120,8 @@ from merlin.models.tf.predictions.base import PredictionBlock from 
merlin.models.tf.predictions.classification import BinaryPrediction from merlin.models.tf.predictions.regression import RegressionPrediction +from merlin.models.tf.predictions.sampling.base import Items, ItemSamplerV2 +from merlin.models.tf.predictions.sampling.in_batch import InBatchSamplerV2 from merlin.models.tf.utils import repr_utils from merlin.models.tf.utils.tf_utils import TensorInitializer diff --git a/merlin/models/tf/predictions/dot_product.py b/merlin/models/tf/predictions/dot_product.py index dd00033131..35ced9e2cd 100644 --- a/merlin/models/tf/predictions/dot_product.py +++ b/merlin/models/tf/predictions/dot_product.py @@ -7,7 +7,7 @@ from merlin.models.tf.core.prediction import Prediction from merlin.models.tf.metrics.topk import AvgPrecisionAt, MRRAt, NDCGAt, PrecisionAt, RecallAt from merlin.models.tf.predictions.base import ContrastivePredictionBlock -from merlin.models.tf.predictions.sampling.base import Items, ItemSampler, ItemSamplersType +from merlin.models.tf.predictions.sampling.base import Items, ItemSamplersType, ItemSamplerV2 from merlin.models.tf.typing import TabularData from merlin.models.tf.utils import tf_utils from merlin.models.tf.utils.tf_utils import call_layer, rescore_false_negatives @@ -223,7 +223,7 @@ def __init__( super().__init__(query_name, item_name, **kwargs) if not isinstance(negative_samplers, (list, tuple)): negative_samplers = [negative_samplers] - self.negative_samplers = [ItemSampler.parse(s) for s in list(negative_samplers)] + self.negative_samplers = [ItemSamplerV2.parse(s) for s in list(negative_samplers)] assert ( len(self.negative_samplers) > 0 ), "At least one sampler is required by ContrastiveDotProduct for negative sampling" diff --git a/merlin/models/tf/predictions/sampling/base.py b/merlin/models/tf/predictions/sampling/base.py index 38bb7179a7..97fc9594b2 100644 --- a/merlin/models/tf/predictions/sampling/base.py +++ b/merlin/models/tf/predictions/sampling/base.py @@ -9,6 +9,17 @@ class 
Items(NamedTuple): + """Stores item ids and their metadata + + Parameters + ---------- + id : tf.Tensor + The tensor of item ids + metadata: + dictionary of tensors containing meta information + about items such as item embeddings and item category + """ + id: tf.Tensor metadata: Dict[str, tf.Tensor] @@ -64,7 +75,20 @@ def from_config(cls, config): negative_sampling_registry: Registry = Registry.class_registry("tf.negative_sampling") -class ItemSampler(tf.keras.layers.Layer, RegistryMixin["ItemSampler"], abc.ABC): +class ItemSamplerV2(tf.keras.layers.Layer, RegistryMixin["ItemSampler"], abc.ABC): + """Base-class for negative sampling + + Parameters + ---------- + max_num_samples : int + The maximum number of samples to store + + Returns + ------- + Items + The sampled ids and their metadata + """ + registry = negative_sampling_registry def __init__( @@ -72,7 +96,7 @@ def __init__( max_num_samples: Optional[int] = None, **kwargs, ): - super(ItemSampler, self).__init__(**kwargs) + super(ItemSamplerV2, self).__init__(**kwargs) self.set_max_num_samples(max_num_samples) def call( @@ -111,4 +135,4 @@ def _list_to_tensor(input_list: List[tf.Tensor]) -> tf.Tensor: return output -ItemSamplersType = Union[ItemSampler, Sequence[Union[ItemSampler, str]], str] +ItemSamplersType = Union[ItemSamplerV2, Sequence[Union[ItemSamplerV2, str]], str] diff --git a/merlin/models/tf/predictions/sampling/in_batch.py b/merlin/models/tf/predictions/sampling/in_batch.py index d0da08b29a..90d48ca596 100644 --- a/merlin/models/tf/predictions/sampling/in_batch.py +++ b/merlin/models/tf/predictions/sampling/in_batch.py @@ -17,12 +17,12 @@ import tensorflow as tf -from merlin.models.tf.predictions.sampling.base import Items, ItemSampler +from merlin.models.tf.predictions.sampling.base import Items, ItemSamplerV2 -@ItemSampler.registry.register("in-batch") +@ItemSamplerV2.registry.register("in-batch") @tf.keras.utils.register_keras_serializable(package="merlin.models") -class
InBatchSampler(ItemSampler): +class InBatchSamplerV2(ItemSamplerV2): """Provides in-batch sampling [1]_ for two-tower item retrieval models. The implementation is very simple, as it just returns the current item embeddings and metadata, but it is necessary to have @@ -71,6 +71,27 @@ def add(self, items: Items): def call( self, items: Items, features=None, targets=None, training=False, testing=False ) -> Items: + """Returns the item embeddings and item ids from + the current batch. + + Parameters + ---------- + items : Items + The item ids and their embeddings from the current batch + features : optional + The metadata with raw input features, by default None + targets : tf.Tensor, optional + The tensor of targets, by default None + training : bool, optional + Flag indicating if on training mode, by default False + testing : bool, optional + Flag indicating if on evaluation mode, by default False + + Returns + ------- + Items + NamedTuple with the sampled item ids and item metadata + """ self.add(items) items = self.sample() diff --git a/tests/unit/tf/predictions/test_dot_product.py b/tests/unit/tf/predictions/test_dot_product.py index 33220da592..4916ca1a60 100644 --- a/tests/unit/tf/predictions/test_dot_product.py +++ b/tests/unit/tf/predictions/test_dot_product.py @@ -7,14 +7,16 @@ ContrastiveDotProduct, DotProductCategoricalPrediction, ) -from merlin.models.tf.predictions.sampling.in_batch import InBatchSampler +from merlin.models.tf.predictions.sampling.in_batch import InBatchSamplerV2 from merlin.models.tf.utils import testing_utils def test_dot_product_prediction(ecommerce_data: Dataset): model = mm.RetrievalModel( mm.TwoTowerBlock(ecommerce_data.schema, query_tower=mm.MLPBlock([8])), - DotProductCategoricalPrediction(ecommerce_data.schema, negative_samplers=InBatchSampler()), + DotProductCategoricalPrediction( + ecommerce_data.schema, negative_samplers=InBatchSamplerV2() + ), ) _, history = testing_utils.model_test(model, ecommerce_data) @@ -33,16 +35,18 @@
def test_setting_negative_sampling_strategy(ecommerce_data: Dataset, run_eagerly def test_add_sampler(ecommerce_data: Dataset): model = mm.RetrievalModel( mm.TwoTowerBlock(ecommerce_data.schema, query_tower=mm.MLPBlock([8])), - DotProductCategoricalPrediction(ecommerce_data.schema, negative_samplers=InBatchSampler()), + DotProductCategoricalPrediction( + ecommerce_data.schema, negative_samplers=InBatchSamplerV2() + ), ) assert len(model.prediction_blocks[0].prediction_with_negatives.negative_samplers) == 1 - model.prediction_blocks[0].add_sampler(mm.InBatchSampler()) + model.prediction_blocks[0].add_sampler(mm.InBatchSamplerV2()) assert len(model.prediction_blocks[0].prediction_with_negatives.negative_samplers) == 2 def test_contrastive_dot_product(ecommerce_data: Dataset): batch_size = 10 - inbatch_sampler = InBatchSampler() + inbatch_sampler = InBatchSamplerV2() retrieval_scorer = ContrastiveDotProduct( schema=ecommerce_data.schema, @@ -73,7 +77,7 @@ def test_item_retrieval_scorer_no_sampler(ecommerce_data: Dataset): def test_item_retrieval_scorer_downscore_false_negatives(ecommerce_data: Dataset): batch_size = 10 - inbatch_sampler = InBatchSampler() + inbatch_sampler = InBatchSamplerV2() inputs, features = _retrieval_inputs_(batch_size=batch_size) FALSE_NEGATIVE_SCORE = -100_000_000.0 @@ -109,7 +113,7 @@ def test_item_retrieval_scorer_downscore_false_negatives(ecommerce_data: Dataset def test_retrieval_prediction_only_positive_when_not_training(ecommerce_data: Dataset): batch_size = 10 - inbatch_sampler = InBatchSampler() + inbatch_sampler = InBatchSamplerV2() item_retrieval_prediction = DotProductCategoricalPrediction( schema=ecommerce_data.schema, negative_samplers=[inbatch_sampler], diff --git a/tests/unit/tf/predictions/test_sampling.py b/tests/unit/tf/predictions/test_sampling.py new file mode 100644 index 0000000000..3d8d7ea3e7 --- /dev/null +++ b/tests/unit/tf/predictions/test_sampling.py @@ -0,0 +1,45 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import tensorflow as tf + +import merlin.models.tf as ml + + +def test_inbatch_sampler(): + item_embeddings = tf.random.uniform(shape=(10, 5), dtype=tf.float32) + item_ids = tf.random.uniform(shape=(10,), minval=1, maxval=10000, dtype=tf.int32) + + inbatch_sampler = ml.InBatchSamplerV2() + + input_data = ml.Items(item_ids, {"item_ids": item_ids}).with_embedding(item_embeddings) + output_data = inbatch_sampler(input_data) + + tf.assert_equal(input_data.embedding(), output_data.embedding()) + for feat_name in output_data.metadata: + tf.assert_equal(input_data.metadata[feat_name], output_data.metadata[feat_name]) + + +def test_inbatch_sampler_no_metadata_features(): + item_ids = tf.random.uniform(shape=(10,), minval=1, maxval=10000, dtype=tf.int32) + + inbatch_sampler = ml.InBatchSamplerV2() + + input_data = ml.Items(item_ids, {}) + output_data = inbatch_sampler(input_data) + + tf.assert_equal(input_data.id, output_data.id) + assert output_data.metadata == {} From 1befb657c5bbaff3fb44746916ada15fa6c20af4 Mon Sep 17 00:00:00 2001 From: sararb Date: Thu, 4 Aug 2022 12:23:25 -0400 Subject: [PATCH 26/27] add missing license from PR review --- merlin/models/tf/predictions/base.py | 15 +++++++++++++++ merlin/models/tf/predictions/dot_product.py | 16 ++++++++++++++++ merlin/models/tf/predictions/sampling/base.py | 15 +++++++++++++++ 3 files changed, 46 insertions(+) diff --git a/merlin/models/tf/predictions/base.py 
b/merlin/models/tf/predictions/base.py index 0eebb78764..0495b89c97 100644 --- a/merlin/models/tf/predictions/base.py +++ b/merlin/models/tf/predictions/base.py @@ -1,3 +1,18 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# from typing import List, Optional, Sequence, Union import tensorflow as tf diff --git a/merlin/models/tf/predictions/dot_product.py b/merlin/models/tf/predictions/dot_product.py index 35ced9e2cd..791f9d4232 100644 --- a/merlin/models/tf/predictions/dot_product.py +++ b/merlin/models/tf/predictions/dot_product.py @@ -1,3 +1,19 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + import logging from typing import List, Optional, Sequence, Union diff --git a/merlin/models/tf/predictions/sampling/base.py b/merlin/models/tf/predictions/sampling/base.py index 97fc9594b2..11e0f60a2b 100644 --- a/merlin/models/tf/predictions/sampling/base.py +++ b/merlin/models/tf/predictions/sampling/base.py @@ -1,3 +1,18 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# import abc from typing import Dict, List, NamedTuple, Optional, Sequence, Union From a17e430b2b42ca9594d7d653a0326f853ea043d5 Mon Sep 17 00:00:00 2001 From: sararb Date: Thu, 4 Aug 2022 12:42:40 -0400 Subject: [PATCH 27/27] remove unnecessary TODO --- merlin/models/tf/predictions/dot_product.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/merlin/models/tf/predictions/dot_product.py b/merlin/models/tf/predictions/dot_product.py index 791f9d4232..300323c3d6 100644 --- a/merlin/models/tf/predictions/dot_product.py +++ b/merlin/models/tf/predictions/dot_product.py @@ -35,7 +35,6 @@ @tf.keras.utils.register_keras_serializable(package="merlin_models") -# Or: RetrievalCategoricalPrediction class DotProductCategoricalPrediction(ContrastivePredictionBlock): """Contrastive prediction using negative-sampling, used in retrieval models. 
@@ -131,7 +130,6 @@ def compile(self, negative_sampling=None, downscore_false_negatives=False): self.prediction_with_negatives.negative_sampling = negative_sampling self.prediction_with_negatives.downscore_false_negatives = downscore_false_negatives - # TODO def add_sampler(self, sampler): self.prediction_with_negatives.negative_samplers.append(sampler)