Tf/contrastive prediction #594

Merged
merged 29 commits into main from tf/contrastive-prediction on Aug 4, 2022
Changes from all commits
29 commits
2cf4aff
First commit
marcromeyn Jul 18, 2022
e0aed3f
Introducing PredictionBlock
marcromeyn Jul 18, 2022
75135de
Running black
marcromeyn Jul 19, 2022
1508f38
Some fixes
marcromeyn Jul 19, 2022
bff9e31
Making metric-names better
marcromeyn Jul 19, 2022
c72cdef
Move arguments in BinaryPrediction + RegressionPrediction
marcromeyn Jul 19, 2022
e8f1664
Updating type-hints
marcromeyn Jul 19, 2022
7edd18a
Trying to fix failing tests
marcromeyn Jul 19, 2022
6cf887d
Trying to fix failing tests
marcromeyn Jul 19, 2022
d0e4728
Trying to fix failing tests
marcromeyn Jul 19, 2022
0c25197
Adding BinaryPrediction test
marcromeyn Jul 21, 2022
d3f2664
Adding RegressionPrediction test
marcromeyn Jul 21, 2022
2f4d5ed
First commit
marcromeyn Jul 19, 2022
8142ab0
Make in-batch the default
marcromeyn Jul 19, 2022
7c1a234
First pass over DotProduct
marcromeyn Jul 19, 2022
d180893
Making test_dot_product_prediction pass
marcromeyn Jul 20, 2022
8758b33
Adding TODO
marcromeyn Jul 22, 2022
b31ab0c
unify top-k metrics in one TopKMetricsAggregator instance
sararb Jul 29, 2022
d86081f
fix split_metrics
sararb Aug 2, 2022
162fa48
add negative_sampling to compile method
sararb Aug 2, 2022
b064fd3
update contrastive prediction block
sararb Aug 2, 2022
7b8fb94
fix merge conflict
sararb Aug 3, 2022
e350b18
fix missing imports
sararb Aug 3, 2022
e5c58bc
fix failing test
sararb Aug 3, 2022
2ce437d
add docstrings
sararb Aug 4, 2022
015ac6e
update names of new sampling classes
sararb Aug 4, 2022
1befb65
add missing license from PR review
sararb Aug 4, 2022
a17e430
remove unnecessary TODO
sararb Aug 4, 2022
e74b63a
Merge branch 'main' into tf/contrastive-prediction
rnyak Aug 4, 2022
2 changes: 2 additions & 0 deletions merlin/models/tf/__init__.py
@@ -120,6 +120,8 @@
from merlin.models.tf.predictions.base import PredictionBlock
from merlin.models.tf.predictions.classification import BinaryPrediction
from merlin.models.tf.predictions.regression import RegressionPrediction
from merlin.models.tf.predictions.sampling.base import Items, ItemSamplerV2
from merlin.models.tf.predictions.sampling.in_batch import InBatchSamplerV2
from merlin.models.tf.utils import repr_utils
from merlin.models.tf.utils.tf_utils import TensorInitializer

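For orientation, a minimal sketch of what the new exports above make available from the top-level namespace; default-constructing the sampler and its subclassing of `ItemSamplerV2` are assumptions for illustration:

```python
import merlin.models.tf as mm

# New V2 sampler API exported above; constructing it with default
# arguments is an assumption for illustration.
sampler = mm.InBatchSamplerV2()

# InBatchSamplerV2 is assumed to implement the ItemSamplerV2 interface.
assert isinstance(sampler, mm.ItemSamplerV2)
```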
49 changes: 48 additions & 1 deletion merlin/models/tf/metrics/topk.py
@@ -15,7 +15,7 @@
#

# Adapted from source code: https://github.com/karlhigley/ranking-metrics-torch
-from typing import List, Optional, Sequence, Union
+from typing import List, Optional, Sequence, Tuple, Union

import tensorflow as tf
from keras.utils import losses_utils, metrics_utils
@@ -301,36 +301,42 @@ def from_config(cls, config):
return super(TopkMetric, cls).from_config(config)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
@metrics_registry.register_with_multiple_names("recall_at", "recall")
class RecallAt(TopkMetric):
def __init__(self, k=10, pre_sorted=False, name="recall_at"):
super().__init__(recall_at, k=k, pre_sorted=pre_sorted, name=name)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
@metrics_registry.register_with_multiple_names("precision_at", "precision")
class PrecisionAt(TopkMetric):
def __init__(self, k=10, pre_sorted=False, name="precision_at"):
super().__init__(precision_at, k=k, pre_sorted=pre_sorted, name=name)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
@metrics_registry.register_with_multiple_names("map_at", "map")
class AvgPrecisionAt(TopkMetric):
def __init__(self, k=10, pre_sorted=False, name="map_at"):
super().__init__(average_precision_at, k=k, pre_sorted=pre_sorted, name=name)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
@metrics_registry.register_with_multiple_names("mrr_at", "mrr")
class MRRAt(TopkMetric):
def __init__(self, k=10, pre_sorted=False, name="mrr_at"):
super().__init__(mrr_at, k=k, pre_sorted=pre_sorted, name=name)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
@metrics_registry.register_with_multiple_names("ndcg_at", "ndcg")
class NDCGAt(TopkMetric):
def __init__(self, k=10, pre_sorted=False, name="ndcg_at"):
super().__init__(ndcg_at, k=k, pre_sorted=pre_sorted, name=name)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class TopKMetricsAggregator(Metric, TopkMetricWithLabelRelevantCountsMixin):
"""Aggregator for top-k metrics (TopkMetric) that is optimized
to sort top-k predictions only once for all metrics.
@@ -409,6 +415,20 @@ def default_metrics(cls, top_ks: Sequence[int], **kwargs) -> Sequence[TopkMetric
aggregator = cls(*metrics)
return [aggregator]

def get_config(self):
config = {}
for i, metric in enumerate(self.topk_metrics):
config[i] = tf.keras.utils.serialize_keras_object(metric)
return config

@classmethod
def from_config(cls, config, custom_objects=None):
metrics = [
tf.keras.layers.deserialize(conf, custom_objects=custom_objects)
for conf in config.values()
]
return TopKMetricsAggregator(*metrics)


def filter_topk_metrics(
metrics: Sequence[Metric],
@@ -433,3 +453,30 @@ def filter_topk_metrics(
]
)
return topk_metrics


def split_metrics(
    metrics: Sequence[Metric],
) -> Tuple[List[TopkMetric], List[TopKMetricsAggregator], List[Metric]]:
    """Split the list of metrics into top-k metrics, top-k aggregators and others

    Parameters
    ----------
    metrics : Sequence[Metric]
        List of metrics

    Returns
    -------
    Tuple[List[TopkMetric], List[TopKMetricsAggregator], List[Metric]]
        Tuple with the plain top-k metrics, the top-k aggregators and the
        remaining metrics from the input list
    """
topk_metrics, topk_aggregators, other_metrics = [], [], []
for metric in metrics:
if isinstance(metric, TopkMetric):
topk_metrics.append(metric)
elif isinstance(metric, TopKMetricsAggregator):
topk_aggregators.append(metric)
else:
other_metrics.append(metric)
return topk_metrics, topk_aggregators, other_metrics
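A minimal sketch of how the two helpers above fit together, mirroring what `_create_metrics` does further down in this diff; the mixed metric list is hypothetical:

```python
import tensorflow as tf

from merlin.models.tf.metrics.topk import (
    NDCGAt,
    RecallAt,
    TopKMetricsAggregator,
    split_metrics,
)

# A hypothetical mixed metric list, as Model.compile(metrics=...) might receive.
metrics = [RecallAt(k=10), NDCGAt(k=10), tf.keras.metrics.AUC()]

# Separate plain top-k metrics, pre-built aggregators, and everything else.
topk_metrics, topk_aggregators, other_metrics = split_metrics(metrics)

# Wrap the plain top-k metrics in a single aggregator so the top-k
# predictions are sorted only once for all of them.
if topk_metrics:
    topk_aggregators.append(TopKMetricsAggregator(*topk_metrics))

metrics = other_metrics + topk_aggregators
```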
24 changes: 19 additions & 5 deletions merlin/models/tf/models/base.py
@@ -20,10 +20,10 @@
from merlin.models.tf.dataset import BatchedDataset
from merlin.models.tf.inputs.base import InputBlock
from merlin.models.tf.losses.base import loss_registry
-from merlin.models.tf.metrics.topk import filter_topk_metrics
+from merlin.models.tf.metrics.topk import TopKMetricsAggregator, filter_topk_metrics, split_metrics
from merlin.models.tf.models.utils import parse_prediction_tasks
from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
-from merlin.models.tf.predictions.base import PredictionBlock
+from merlin.models.tf.predictions.base import ContrastivePredictionBlock, PredictionBlock
from merlin.models.tf.typing import TabularData
from merlin.models.tf.utils.search_utils import find_all_instances_in_layers
from merlin.models.tf.utils.tf_utils import call_layer, maybe_serialize_keras_objects
@@ -303,6 +303,14 @@ def compile(
self.output_names = [task.task_name for task in self.prediction_tasks]
else:
self.output_names = [block.full_name for block in self.prediction_blocks]
negative_sampling = kwargs.pop("negative_sampling", None)
if negative_sampling:
if not isinstance(self.prediction_blocks[0], ContrastivePredictionBlock):
raise ValueError(
"Negative sampling strategy can be used only with a"
" `ContrastivePredictionBlock` prediction block"
)
self.prediction_blocks[0].compile(negative_sampling=negative_sampling)

# This flag will make Keras change the metric-names which is not needed in v2
from_serialized = kwargs.pop("from_serialized", num_v2_blocks > 0)
@@ -324,11 +332,16 @@ def _create_metrics(self, metrics=None):
out = {}

num_v1_blocks = len(self.prediction_tasks)

if isinstance(metrics, dict):
out = metrics

elif isinstance(metrics, (list, tuple)):
# Retrieve top-k metrics & wrap them in TopKMetricsAggregator
topk_metrics, topk_aggregators, other_metrics = split_metrics(metrics)
if len(topk_metrics) > 0:
topk_aggregators.append(TopKMetricsAggregator(*topk_metrics))
metrics = other_metrics + topk_aggregators

if num_v1_blocks > 0:
if num_v1_blocks == 1:
out[self.prediction_tasks[0].task_name] = metrics
@@ -625,8 +638,9 @@ def compute_metrics(
# Providing label_relevant_counts for TopkMetrics, as metric.update_state()
# should have standard signature for better compatibility with Keras methods
# like self.compiled_metrics.update_state()
-        for topk_metric in filter_topk_metrics(self.compiled_metrics.metrics):
-            topk_metric.label_relevant_counts = prediction_outputs.label_relevant_counts
+        if hasattr(prediction_outputs, "label_relevant_counts"):
+            for topk_metric in filter_topk_metrics(self.compiled_metrics.metrics):
+                topk_metric.label_relevant_counts = prediction_outputs.label_relevant_counts

self.compiled_metrics.update_state(
prediction_outputs.targets,
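Based on the `compile` change above, a hedged usage sketch: `compile_with_negatives` is a hypothetical helper, and passing an `InBatchSamplerV2` instance as the strategy is an assumption; the diff only shows the `negative_sampling` kwarg being forwarded to the block's own `compile`:

```python
import merlin.models.tf as mm


def compile_with_negatives(model: mm.Model) -> None:
    # Only valid when the model's prediction block is a
    # ContrastivePredictionBlock; otherwise compile() raises ValueError.
    model.compile(
        optimizer="adam",
        negative_sampling=mm.InBatchSamplerV2(),  # assumed accepted value
    )
```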
104 changes: 104 additions & 0 deletions merlin/models/tf/predictions/base.py
@@ -1,3 +1,18 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List, Optional, Sequence, Union

import tensorflow as tf
@@ -8,6 +23,7 @@
from merlin.models.tf.core.prediction import Prediction
from merlin.models.tf.core.transformations import LogitsTemperatureScaler
from merlin.models.tf.utils import tf_utils
from merlin.models.tf.utils.tf_utils import call_layer


@tf.keras.utils.register_keras_serializable(package="merlin.models")
@@ -179,3 +195,91 @@ def from_config(cls, config):
)

return super().from_config(config)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class ContrastivePredictionBlock(PredictionBlock):
"""Base-class for prediction blocks that uses contrastive loss.

Parameters
----------
prediction : Layer
The prediction layer
prediction_with_negatives : Layer
The prediction layer that includes negative sampling
default_loss: Union[str, tf.keras.losses.Loss]
Default loss to set if the user does not specify one
default_metrics: Sequence[tf.keras.metrics.Metric]
Default metrics to set if the user does not specify any
    name: Optional[str], optional
        Task name, by default None
    target: Optional[str], optional
        Label name, by default None
    pre: Optional[Layer], optional
        Optional layer to transform inputs before applying the prediction layer,
        by default None
    post: Optional[Layer], optional
        Optional layer to transform predictions after applying the prediction layer,
        by default None
    logits_temperature: float, optional
        Parameter used to reduce model overconfidence by scaling the logits as
        logits / temperature, by default 1.0
"""

def __init__(
self,
prediction: Layer,
prediction_with_negatives: Layer,
default_loss: Union[str, tf.keras.losses.Loss],
default_metrics: Sequence[tf.keras.metrics.Metric],
name: Optional[str] = None,
target: Optional[str] = None,
pre: Optional[Layer] = None,
post: Optional[Layer] = None,
logits_temperature: float = 1.0,
**kwargs,
):

super(ContrastivePredictionBlock, self).__init__(
prediction,
default_loss=default_loss,
default_metrics=default_metrics,
target=target,
pre=pre,
post=post,
logits_temperature=logits_temperature,
name=name,
**kwargs,
)
self.prediction_with_negatives = prediction_with_negatives

def call(self, inputs, training=False, testing=False, **kwargs):
to_call = self.prediction

if self.prediction_with_negatives.has_negative_samplers and (training or testing):
to_call = self.prediction_with_negatives

return call_layer(to_call, inputs, training=training, testing=testing, **kwargs)

def get_config(self):
config = super(ContrastivePredictionBlock, self).get_config()
config.update(
{
"prediction_with_negatives": tf.keras.utils.serialize_keras_object(
self.prediction_with_negatives
),
}
)

return config

@classmethod
def from_config(cls, config):
config = tf_utils.maybe_deserialize_keras_objects(
config,
{
"prediction_with_negatives": tf.keras.layers.deserialize,
},
)

return super().from_config(config)
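A standalone restatement of the routing rule in `ContrastivePredictionBlock.call`, for readers skimming the diff; this helper is illustrative only, not part of the PR:

```python
from tensorflow.keras.layers import Layer


def select_prediction_layer(
    prediction: Layer,
    prediction_with_negatives: Layer,
    training: bool = False,
    testing: bool = False,
) -> Layer:
    # The contrastive path is used only while fitting or evaluating, and
    # only when negative samplers are attached; inference always scores
    # with the plain prediction layer.
    if getattr(prediction_with_negatives, "has_negative_samplers", False) and (
        training or testing
    ):
        return prediction_with_negatives
    return prediction
```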