Introducing RetrievalModelV2 to unify two-tower & session-based enhancement #761

Merged: 13 commits, Oct 5, 2022
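For orientation, below is a hedged sketch of how the blocks touched in this PR might be wired together for top-k retrieval. The names `train`, `item_features`, `query_tower`, and `candidate_tower` are placeholders and not part of the diff; `RetrievalModelV2` itself is defined outside the hunks shown here, so it is not constructed in this sketch.

```python
# Hypothetical usage of the Encoder / TopKEncoder API touched in this PR.
# `train` and `item_features` are assumed merlin.io.Datasets; the towers are
# assumed to have been trained as part of a larger retrieval model.
import merlin.models.tf as mm
from merlin.schema import Tags

schema = train.schema  # assumed dataset with user and item features

# Two towers wrapped as Encoder blocks (formerly EncoderBlock).
query_tower = mm.Encoder(mm.InputBlockV2(schema.select_by_tag(Tags.USER)), mm.MLPBlock([64]))
candidate_tower = mm.Encoder(mm.InputBlockV2(schema.select_by_tag(Tags.ITEM)), mm.MLPBlock([64]))

# Build a top-k serving block by pre-computing candidate embeddings from a raw
# candidate dataset, indexed by the item-id column.
topk_model = mm.TopKEncoder.from_candidate_dataset(
    query_encoder=query_tower,
    candidate_encoder=candidate_tower,
    dataset=item_features,
    top_k=10,
    index_column=Tags.ITEM_ID,
)
```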
8 changes: 5 additions & 3 deletions merlin/models/tf/__init__.py
@@ -75,7 +75,7 @@
ResidualBlock,
SequentialBlock,
)
from merlin.models.tf.core.encoder import EncoderBlock, TopKEncoder
from merlin.models.tf.core.encoder import EmbeddingEncoder, Encoder, TopKEncoder
from merlin.models.tf.inputs.base import InputBlock, InputBlockV2
from merlin.models.tf.inputs.continuous import Continuous, ContinuousFeatures, ContinuousProjection
from merlin.models.tf.inputs.embedding import (
@@ -100,7 +100,7 @@
TopKMetricsAggregator,
)
from merlin.models.tf.models import benchmark
from merlin.models.tf.models.base import BaseModel, Model, RetrievalModel
from merlin.models.tf.models.base import BaseModel, Model, RetrievalModel, RetrievalModelV2
from merlin.models.tf.models.ranking import DCNModel, DeepFMModel, DLRMModel, WideAndDeepModel
from merlin.models.tf.models.retrieval import (
MatrixFactorizationModel,
@@ -179,8 +179,9 @@
"SequentialBlock",
"ResidualBlock",
"DualEncoderBlock",
"EncoderBlock",
"TopKEncoder",
"Encoder",
"EmbeddingEncoder",
"CrossBlock",
"DLRMBlock",
"MLPBlock",
@@ -251,6 +252,7 @@
"TopKMetricsAggregator",
"Model",
"RetrievalModel",
"RetrievalModelV2",
"InputBlock",
"InputBlockV2",
"PredictionTasks",
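The net effect of these export changes, as a minimal sketch assuming the usual top-level import alias:

```python
# Names now importable from merlin.models.tf after this PR (no construction shown).
import merlin.models.tf as mm

mm.Encoder           # renamed from mm.EncoderBlock
mm.EmbeddingEncoder  # new: an Encoder wrapping a single EmbeddingTable
mm.TopKEncoder       # same name, now subclasses Encoder
mm.RetrievalModelV2  # new model class exported from merlin.models.tf.models.base
```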
175 changes: 157 additions & 18 deletions merlin/models/tf/core/encoder.py
@@ -1,20 +1,22 @@
from typing import Optional, Union

import numpy as np
import tensorflow as tf
from packaging import version

import merlin.io
from merlin.models.tf.core import combinators
from merlin.models.tf.core.prediction import TopKPrediction
from merlin.models.tf.inputs.base import InputBlockV2
from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable
from merlin.models.tf.models.base import BaseModel
from merlin.models.tf.outputs.topk import TopKOutput
from merlin.models.tf.utils import tf_utils
from merlin.schema import ColumnSchema, Schema, Tags


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class EncoderBlock(tf.keras.Model):
class Encoder(tf.keras.Model):
"""Block that can be used for prediction & evaluation but not for training

Parameters
@@ -53,11 +55,99 @@ def __init__(
self.pre = pre
self.post = post

def call(self, inputs, **kwargs):
if "features" not in kwargs:
kwargs["features"] = inputs
def encode(
self,
dataset: merlin.io.Dataset,
index: Union[str, ColumnSchema, Schema, Tags],
batch_size: int,
**kwargs,
) -> merlin.io.Dataset:
if isinstance(index, Schema):
output_schema = index
elif isinstance(index, ColumnSchema):
output_schema = Schema([index])
elif isinstance(index, str):
output_schema = Schema([self.schema[index]])
elif isinstance(index, Tags):
output_schema = self.schema.select_by_tag(index)
else:
raise ValueError(f"Invalid index: {index}")

return self.batch_predict(
dataset,
batch_size=batch_size,
output_schema=output_schema,
index=index,
output_concat_func=np.concatenate,
**kwargs,
)

def batch_predict(
self,
dataset: merlin.io.Dataset,
batch_size: int,
output_schema: Optional[Schema] = None,
index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
**kwargs,
) -> merlin.io.Dataset:
"""Batched prediction using Dask.
Parameters
----------
dataset: merlin.io.Dataset
Dataset to predict on.
batch_size: int
Batch size to use for prediction.
Returns
-------
merlin.io.Dataset
"""

if index:
if isinstance(index, ColumnSchema):
index = Schema([index])
elif isinstance(index, str):
index = Schema([self.schema[index]])
elif isinstance(index, Tags):
index = self.schema.select_by_tag(index)
elif not isinstance(index, Schema):
raise ValueError(f"Invalid index: {index}")

if len(index) != 1:
raise ValueError("Only one column can be used as index")
index = index.first.name

if hasattr(dataset, "schema"):
if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)):
raise ValueError(
f"Model schema {self.schema.column_names} does not match dataset schema"
+ f" {dataset.schema.column_names}"
)

# Check if merlin-dataset is passed
if hasattr(dataset, "to_ddf"):
dataset = dataset.to_ddf()

from merlin.models.tf.utils.batch_utils import TFModelEncode

model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs)
encode_kwargs = {}
if output_schema:
encode_kwargs["filter_input_columns"] = output_schema.column_names
predictions = dataset.map_partitions(model_encode, **encode_kwargs)
if index:
predictions = predictions.set_index(index)

return combinators.call_sequentially(list(self.to_call), inputs=inputs, **kwargs)
return merlin.io.Dataset(predictions)

def call(self, inputs, training=False, testing=False, targets=None):
return combinators.call_sequentially(
list(self.to_call),
inputs=inputs,
features=inputs,
targets=targets,
training=training,
testing=testing,
)

def build(self, input_shape):
combinators.build_sequentially(self, list(self.to_call), input_shape=input_shape)
@@ -141,7 +231,10 @@ def from_config(cls, config, custom_objects=None):
if post is not None:
post = tf.keras.layers.deserialize(post, custom_objects=custom_objects)

return cls(*layers, pre=pre, post=post)
output = Encoder(*layers, pre=pre, post=post)
output.__class__ = cls

return output

def get_config(self):
config = tf_utils.maybe_serialize_keras_objects(self, {}, ["pre", "post"])
@@ -152,13 +245,13 @@ def get_config(self):


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class TopKEncoder(EncoderBlock, BaseModel):
class TopKEncoder(Encoder, BaseModel):
"""Block that can be used for top-k prediction & evaluation, initialized
from a trained retrieval model

Parameters
----------
query_encoder: Union[EncoderBlock, tf.keras.layers.Layer],
query_encoder: Union[Encoder, tf.keras.layers.Layer],
The layer to use for encoding the query features
topk_layer: Union[str, tf.keras.layers.Layer, TopKOutput]
The layer to use for computing the top-k predictions.
@@ -172,7 +265,7 @@ class TopKEncoder(EncoderBlock, BaseModel):
the candidates ids.
This is required when `topk_layer` is a string
By default None
candidate_encoder: Union[EncoderBlock, tf.keras.layers.Layer],
candidate_encoder: Union[Encoder, tf.keras.layers.Layer],
The layer to use for encoding the item features
k: int, Optional
Number of candidates to return, by default 10
@@ -186,10 +279,10 @@

def __init__(
self,
query_encoder: Union[EncoderBlock, tf.keras.layers.Layer],
query_encoder: Union[Encoder, tf.keras.layers.Layer],
topk_layer: Union[str, tf.keras.layers.Layer, TopKOutput] = "brute-force-topk",
candidates: Union[tf.Tensor, merlin.io.Dataset] = None,
candidate_encoder: Union[EncoderBlock, tf.keras.layers.Layer] = None,
candidate_encoder: Union[Encoder, tf.keras.layers.Layer] = None,
k: int = 10,
pre: Optional[tf.keras.layers.Layer] = None,
post: Optional[tf.keras.layers.Layer] = None,
@@ -201,15 +294,15 @@ def __init__(
topk_output = TopKOutput(to_call=topk_layer, candidates=candidates, k=k, **kwargs)
self.k = k

EncoderBlock.__init__(self, query_encoder, topk_output, pre=pre, post=post, **kwargs)
Encoder.__init__(self, query_encoder, topk_output, pre=pre, post=post, **kwargs)
# The base model is required for the evaluation step:
BaseModel.__init__(self, **kwargs)

@classmethod
def from_candidate_dataset(
cls,
query_encoder: Union[EncoderBlock, tf.keras.layers.Layer],
candidate_encoder: Union[EncoderBlock, tf.keras.layers.Layer],
query_encoder: Union[Encoder, tf.keras.layers.Layer],
candidate_encoder: Union[Encoder, tf.keras.layers.Layer],
dataset: merlin.io.Dataset,
top_k: int = 10,
index_column: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
@@ -220,9 +313,9 @@ def from_candidate_dataset(

Parameters
----------
query_encoder : Union[EncoderBlock, tf.keras.layers.Layer]
query_encoder : Union[Encoder, tf.keras.layers.Layer]
The encoder layer to use for computing the query embeddings.
candidate_encoder : Union[EncoderBlock, tf.keras.layers.Layer]
candidate_encoder : Union[Encoder, tf.keras.layers.Layer]
The encoder layer to use for computing the candidates embeddings.
dataset : merlin.io.Dataset
Raw candidate features dataset
@@ -259,7 +352,7 @@ def encode_candidates(
self,
dataset: merlin.io.Dataset,
index_column: Optional[Union[str, ColumnSchema, Schema, Tags]] = None,
candidate_encoder: Optional[Union[EncoderBlock, tf.keras.layers.Layer]] = None,
candidate_encoder: Optional[Union[Encoder, tf.keras.layers.Layer]] = None,
**kwargs,
) -> merlin.io.Dataset:
"""Method to generate candidates embeddings
@@ -273,7 +366,7 @@
for returning the topk ids of candidates with the highest scores.
If not specified, the candidates indices will be used instead.
by default None
candidate_encoder : Union[EncoderBlock, tf.keras.layers.Layer], optional
candidate_encoder : Union[Encoder, tf.keras.layers.Layer], optional
The encoder layer to use for computing the candidates embeddings.
If not specified, the candidate_encoder set in the constructor
will be used instead.
@@ -339,3 +432,49 @@ def fit(self, *args, **kwargs):
"This block is not meant to be trained by itself. ",
"It can only be trained as part of a model.",
)


@tf.keras.utils.register_keras_serializable(package="merlin.models")
class EmbeddingEncoder(Encoder):
def __init__(
self,
schema: Union[ColumnSchema, Schema],
dim: int,
embeddings_initializer="uniform",
embeddings_regularizer=None,
activity_regularizer=None,
embeddings_constraint=None,
mask_zero=False,
input_length=None,
sequence_combiner: Optional[CombinerType] = None,
trainable=True,
name=None,
dtype=None,
dynamic=False,
):
if isinstance(schema, ColumnSchema):
col = schema
else:
col = schema.first
col_name = col.name

table = EmbeddingTable(
dim,
col,
embeddings_initializer=embeddings_initializer,
embeddings_regularizer=embeddings_regularizer,
activity_regularizer=activity_regularizer,
embeddings_constraint=embeddings_constraint,
mask_zero=mask_zero,
input_length=input_length,
sequence_combiner=sequence_combiner,
trainable=trainable,
name=name,
dtype=dtype,
dynamic=dynamic,
)

super().__init__(table, tf.keras.layers.Lambda(lambda x: x[col_name]))

def to_dataset(self, gpu=None) -> merlin.io.Dataset:
return self.blocks[0].to_dataset(gpu=gpu)
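A hedged sketch of the new `Encoder.encode` flow and the `EmbeddingEncoder` helper defined above. The dataset `train`, the schema tags, and the layer sizes are assumptions; `encode` is expected to be called on an encoder that has already been trained inside a larger model, since `Encoder` itself is not trainable on its own.

```python
# Assumed example: `train` is a merlin.io.Dataset whose schema carries USER /
# USER_ID / ITEM_ID tags; the encoder is assumed to be already trained.
import merlin.models.tf as mm
from merlin.schema import Tags

query_tower = mm.Encoder(
    mm.InputBlockV2(train.schema.select_by_tag(Tags.USER)),
    mm.MLPBlock([128, 64]),
)

# encode() resolves `index` (str, ColumnSchema, Schema, or Tags) to a single index
# column, then delegates to batch_predict(), which maps the model over the
# dataset's Dask partitions and returns a merlin.io.Dataset of embeddings.
user_embeddings = query_tower.encode(train, index=Tags.USER_ID, batch_size=1024)

# EmbeddingEncoder wraps one EmbeddingTable and selects that column's embeddings;
# to_dataset() exports the table weights as a merlin.io.Dataset.
item_encoder = mm.EmbeddingEncoder(train.schema.select_by_tag(Tags.ITEM_ID), dim=64)
item_embeddings = item_encoder.to_dataset()
```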
3 changes: 2 additions & 1 deletion merlin/models/tf/core/index.py
@@ -23,7 +23,6 @@
from merlin.core.dispatch import DataFrameType
from merlin.models.tf.core.base import Block, PredictionOutput
from merlin.models.tf.utils import tf_utils
from merlin.models.tf.utils.batch_utils import TFModelEncode
from merlin.models.utils.constants import MIN_FLOAT
from merlin.schema import Tags

@@ -104,6 +103,8 @@ def extract_ids_embeddings(cls, data: merlin.io.Dataset, check_unique_ids: bool
def get_candidates_dataset(
cls, block: Block, data: merlin.io.Dataset, id_column: Optional[str] = None
):
from merlin.models.tf.utils.batch_utils import TFModelEncode

if not id_column and getattr(block, "schema", None):
tagged = block.schema.select_by_tag(Tags.ITEM_ID)
if tagged.column_schemas:
39 changes: 36 additions & 3 deletions merlin/models/tf/inputs/embedding.py
@@ -42,7 +42,12 @@
# https://github.com/PyCQA/pylint/issues/3613
# pylint: disable=no-value-for-parameter, unexpected-keyword-arg
from merlin.models.tf.typing import TabularData
from merlin.models.tf.utils.tf_utils import call_layer, df_to_tensor, list_col_to_ragged
from merlin.models.tf.utils.tf_utils import (
call_layer,
df_to_tensor,
list_col_to_ragged,
tensor_to_df,
)
from merlin.models.utils import schema_utils
from merlin.models.utils.doc_utils import docstring_parameter
from merlin.models.utils.schema_utils import (
@@ -281,9 +286,8 @@ def from_pretrained(
name=None,
col_schema=None,
**kwargs,
):
) -> "EmbeddingTable":
"""Create From pre-trained embeddings from a Dataset or DataFrame.

Parameters
----------
data : Union[Dataset, DataFrameType]
@@ -313,6 +317,35 @@
**kwargs,
)

@classmethod
def from_dataset(
cls,
data: Union[Dataset, DataFrameType],
trainable=True,
name=None,
col_schema=None,
**kwargs,
) -> "EmbeddingTable":
"""Create From pre-trained embeddings from a Dataset or DataFrame.
Parameters
----------
data : Union[Dataset, DataFrameType]
A dataset containing the pre-trained embedding weights
trainable : bool
Whether the layer should be trained or not.
name : str
The name of the layer.
"""
return cls.from_pretrained(
data, trainable=trainable, name=name, col_schema=col_schema, **kwargs
)

def to_dataset(self, gpu=None) -> merlin.io.Dataset:
return merlin.io.Dataset(self.to_df(gpu=gpu))

def to_df(self, gpu=None):
return tensor_to_df(self.table.embeddings, gpu=gpu)

def _maybe_build(self, inputs):
"""Creates state between layer instantiation and layer call.
Invoked automatically before the first execution of `call()`.
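Finally, a hedged sketch of the pretrained-embedding round trip added to `EmbeddingTable`. The input layout (one row per id, one column per embedding dimension, passed as a plain DataFrame) is an assumption based on the docstring rather than something shown in this diff.

```python
# Assumed round trip for EmbeddingTable.from_dataset / to_df / to_dataset; the
# DataFrame layout and the "item_id" name are placeholders, not from the PR.
import numpy as np
import pandas as pd
import merlin.models.tf as mm

pretrained = pd.DataFrame(np.random.rand(1000, 16))  # 1000 ids x 16 dims (assumed)

# from_dataset() is a thin wrapper over from_pretrained() added in this PR.
table = mm.EmbeddingTable.from_dataset(pretrained, trainable=False, name="item_id")

# Once the table has been built from the pre-trained weights, they can be
# exported again for inspection or for building an ANN index.
embeddings_df = table.to_df()        # DataFrame built from the embedding tensor
embeddings_ds = table.to_dataset()   # same weights wrapped in a merlin.io.Dataset
```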