diff --git a/merlin/models/tf/__init__.py b/merlin/models/tf/__init__.py index e03f965a9f..3c64960e13 100644 --- a/merlin/models/tf/__init__.py +++ b/merlin/models/tf/__init__.py @@ -75,7 +75,7 @@ ResidualBlock, SequentialBlock, ) -from merlin.models.tf.core.encoder import EncoderBlock, TopKEncoder +from merlin.models.tf.core.encoder import EmbeddingEncoder, Encoder, TopKEncoder from merlin.models.tf.inputs.base import InputBlock, InputBlockV2 from merlin.models.tf.inputs.continuous import Continuous, ContinuousFeatures, ContinuousProjection from merlin.models.tf.inputs.embedding import ( @@ -100,7 +100,7 @@ TopKMetricsAggregator, ) from merlin.models.tf.models import benchmark -from merlin.models.tf.models.base import BaseModel, Model, RetrievalModel +from merlin.models.tf.models.base import BaseModel, Model, RetrievalModel, RetrievalModelV2 from merlin.models.tf.models.ranking import DCNModel, DeepFMModel, DLRMModel, WideAndDeepModel from merlin.models.tf.models.retrieval import ( MatrixFactorizationModel, @@ -179,8 +179,9 @@ "SequentialBlock", "ResidualBlock", "DualEncoderBlock", - "EncoderBlock", "TopKEncoder", + "Encoder", + "EmbeddingEncoder", "CrossBlock", "DLRMBlock", "MLPBlock", @@ -251,6 +252,7 @@ "TopKMetricsAggregator", "Model", "RetrievalModel", + "RetrievalModelV2", "InputBlock", "InputBlockV2", "PredictionTasks", diff --git a/merlin/models/tf/core/encoder.py b/merlin/models/tf/core/encoder.py index 4f84502ceb..6177f91232 100644 --- a/merlin/models/tf/core/encoder.py +++ b/merlin/models/tf/core/encoder.py @@ -1,5 +1,6 @@ from typing import Optional, Union +import numpy as np import tensorflow as tf from packaging import version @@ -7,6 +8,7 @@ from merlin.models.tf.core import combinators from merlin.models.tf.core.prediction import TopKPrediction from merlin.models.tf.inputs.base import InputBlockV2 +from merlin.models.tf.inputs.embedding import CombinerType, EmbeddingTable from merlin.models.tf.models.base import BaseModel from merlin.models.tf.outputs.topk import TopKOutput from merlin.models.tf.utils import tf_utils @@ -14,7 +16,7 @@ @tf.keras.utils.register_keras_serializable(package="merlin.models") -class EncoderBlock(tf.keras.Model): +class Encoder(tf.keras.Model): """Block that can be used for prediction & evaluation but not for training Parameters @@ -53,11 +55,99 @@ def __init__( self.pre = pre self.post = post - def call(self, inputs, **kwargs): - if "features" not in kwargs: - kwargs["features"] = inputs + def encode( + self, + dataset: merlin.io.Dataset, + index: Union[str, ColumnSchema, Schema, Tags], + batch_size: int, + **kwargs, + ) -> merlin.io.Dataset: + if isinstance(index, Schema): + output_schema = index + elif isinstance(index, ColumnSchema): + output_schema = Schema([index]) + elif isinstance(index, str): + output_schema = Schema([self.schema[index]]) + elif isinstance(index, Tags): + output_schema = self.schema.select_by_tag(index) + else: + raise ValueError(f"Invalid index: {index}") + + return self.batch_predict( + dataset, + batch_size=batch_size, + output_schema=output_schema, + index=index, + output_concat_func=np.concatenate, + **kwargs, + ) + + def batch_predict( + self, + dataset: merlin.io.Dataset, + batch_size: int, + output_schema: Optional[Schema] = None, + index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, + **kwargs, + ) -> merlin.io.Dataset: + """Batched prediction using Dask. + Parameters + ---------- + dataset: merlin.io.Dataset + Dataset to predict on. + batch_size: int + Batch size to use for prediction. + Returns + ------- + merlin.io.Dataset + """ + + if index: + if isinstance(index, ColumnSchema): + index = Schema([index]) + elif isinstance(index, str): + index = Schema([self.schema[index]]) + elif isinstance(index, Tags): + index = self.schema.select_by_tag(index) + elif not isinstance(index, Schema): + raise ValueError(f"Invalid index: {index}") + + if len(index) != 1: + raise ValueError("Only one column can be used as index") + index = index.first.name + + if hasattr(dataset, "schema"): + if not set(self.schema.column_names).issubset(set(dataset.schema.column_names)): + raise ValueError( + f"Model schema {self.schema.column_names} does not match dataset schema" + + f" {dataset.schema.column_names}" + ) + + # Check if merlin-dataset is passed + if hasattr(dataset, "to_ddf"): + dataset = dataset.to_ddf() + + from merlin.models.tf.utils.batch_utils import TFModelEncode + + model_encode = TFModelEncode(self, batch_size=batch_size, **kwargs) + encode_kwargs = {} + if output_schema: + encode_kwargs["filter_input_columns"] = output_schema.column_names + predictions = dataset.map_partitions(model_encode, **encode_kwargs) + if index: + predictions = predictions.set_index(index) - return combinators.call_sequentially(list(self.to_call), inputs=inputs, **kwargs) + return merlin.io.Dataset(predictions) + + def call(self, inputs, training=False, testing=False, targets=None): + return combinators.call_sequentially( + list(self.to_call), + inputs=inputs, + features=inputs, + targets=targets, + training=training, + testing=testing, + ) def build(self, input_shape): combinators.build_sequentially(self, list(self.to_call), input_shape=input_shape) @@ -141,7 +231,10 @@ def from_config(cls, config, custom_objects=None): if post is not None: post = tf.keras.layers.deserialize(post, custom_objects=custom_objects) - return cls(*layers, pre=pre, post=post) + output = Encoder(*layers, pre=pre, post=post) + output.__class__ = cls + + return output def get_config(self): config = tf_utils.maybe_serialize_keras_objects(self, {}, ["pre", "post"]) @@ -152,13 +245,13 @@ def get_config(self): @tf.keras.utils.register_keras_serializable(package="merlin.models") -class TopKEncoder(EncoderBlock, BaseModel): +class TopKEncoder(Encoder, BaseModel): """Block that can be used for top-k prediction & evaluation, initialized from a trained retrieval model Parameters ---------- - query_encoder: Union[EncoderBlock, tf.keras.layers.Layer], + query_encoder: Union[Encoder, tf.keras.layers.Layer], The layer to use for encoding the query features topk_layer: Union[str, tf.keras.layers.Layer, TopKOutput] The layer to use for computing the top-k predictions. @@ -172,7 +265,7 @@ class TopKEncoder(EncoderBlock, BaseModel): the candidates ids. This is required when `topk_layer` is a string By default None - candidate_encoder: Union[EncoderBlock, tf.keras.layers.Layer], + candidate_encoder: Union[Encoder, tf.keras.layers.Layer], The layer to use for encoding the item features k: int, Optional Number of candidates to return, by default 10 @@ -186,10 +279,10 @@ class TopKEncoder(EncoderBlock, BaseModel): def __init__( self, - query_encoder: Union[EncoderBlock, tf.keras.layers.Layer], + query_encoder: Union[Encoder, tf.keras.layers.Layer], topk_layer: Union[str, tf.keras.layers.Layer, TopKOutput] = "brute-force-topk", candidates: Union[tf.Tensor, merlin.io.Dataset] = None, - candidate_encoder: Union[EncoderBlock, tf.keras.layers.Layer] = None, + candidate_encoder: Union[Encoder, tf.keras.layers.Layer] = None, k: int = 10, pre: Optional[tf.keras.layers.Layer] = None, post: Optional[tf.keras.layers.Layer] = None, @@ -201,15 +294,15 @@ def __init__( topk_output = TopKOutput(to_call=topk_layer, candidates=candidates, k=k, **kwargs) self.k = k - EncoderBlock.__init__(self, query_encoder, topk_output, pre=pre, post=post, **kwargs) + Encoder.__init__(self, query_encoder, topk_output, pre=pre, post=post, **kwargs) # The base model is required for the evaluation step: BaseModel.__init__(self, **kwargs) @classmethod def from_candidate_dataset( cls, - query_encoder: Union[EncoderBlock, tf.keras.layers.Layer], - candidate_encoder: Union[EncoderBlock, tf.keras.layers.Layer], + query_encoder: Union[Encoder, tf.keras.layers.Layer], + candidate_encoder: Union[Encoder, tf.keras.layers.Layer], dataset: merlin.io.Dataset, top_k: int = 10, index_column: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, @@ -220,9 +313,9 @@ def from_candidate_dataset( Parameters ---------- - query_encoder : Union[EncoderBlock, tf.keras.layers.Layer] + query_encoder : Union[Encoder, tf.keras.layers.Layer] The encoder layer to use for computing the query embeddings. - candidate_encoder : Union[EncoderBlock, tf.keras.layers.Layer] + candidate_encoder : Union[Encoder, tf.keras.layers.Layer] The encoder layer to use for computing the candidates embeddings. dataset : merlin.io.Dataset Raw candidate features dataset @@ -259,7 +352,7 @@ def encode_candidates( self, dataset: merlin.io.Dataset, index_column: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, - candidate_encoder: Optional[Union[EncoderBlock, tf.keras.layers.Layer]] = None, + candidate_encoder: Optional[Union[Encoder, tf.keras.layers.Layer]] = None, **kwargs, ) -> merlin.io.Dataset: """Method to generate candidates embeddings @@ -273,7 +366,7 @@ def encode_candidates( for returning the topk ids of candidates with the highest scores. If not specified, the candidates indices will be used instead. by default None - candidate_encoder : Union[EncoderBlock, tf.keras.layers.Layer], optional + candidate_encoder : Union[Encoder, tf.keras.layers.Layer], optional The encoder layer to use for computing the candidates embeddings. If not specified, the candidate_encoder set in the constructor will be used instead. @@ -339,3 +432,49 @@ def fit(self, *args, **kwargs): "This block is not meant to be trained by itself. ", "It can only be trained as part of a model.", ) + + +@tf.keras.utils.register_keras_serializable(package="merlin.models") +class EmbeddingEncoder(Encoder): + def __init__( + self, + schema: Union[ColumnSchema, Schema], + dim: int, + embeddings_initializer="uniform", + embeddings_regularizer=None, + activity_regularizer=None, + embeddings_constraint=None, + mask_zero=False, + input_length=None, + sequence_combiner: Optional[CombinerType] = None, + trainable=True, + name=None, + dtype=None, + dynamic=False, + ): + if isinstance(schema, ColumnSchema): + col = schema + else: + col = schema.first + col_name = col.name + + table = EmbeddingTable( + dim, + col, + embeddings_initializer=embeddings_initializer, + embeddings_regularizer=embeddings_regularizer, + activity_regularizer=activity_regularizer, + embeddings_constraint=embeddings_constraint, + mask_zero=mask_zero, + input_length=input_length, + sequence_combiner=sequence_combiner, + trainable=trainable, + name=name, + dtype=dtype, + dynamic=dynamic, + ) + + super().__init__(table, tf.keras.layers.Lambda(lambda x: x[col_name])) + + def to_dataset(self, gpu=None) -> merlin.io.Dataset: + return self.blocks[0].to_dataset(gpu=gpu) diff --git a/merlin/models/tf/core/index.py b/merlin/models/tf/core/index.py index 12000b835d..f73be16cda 100644 --- a/merlin/models/tf/core/index.py +++ b/merlin/models/tf/core/index.py @@ -23,7 +23,6 @@ from merlin.core.dispatch import DataFrameType from merlin.models.tf.core.base import Block, PredictionOutput from merlin.models.tf.utils import tf_utils -from merlin.models.tf.utils.batch_utils import TFModelEncode from merlin.models.utils.constants import MIN_FLOAT from merlin.schema import Tags @@ -104,6 +103,8 @@ def extract_ids_embeddings(cls, data: merlin.io.Dataset, check_unique_ids: bool def get_candidates_dataset( cls, block: Block, data: merlin.io.Dataset, id_column: Optional[str] = None ): + from merlin.models.tf.utils.batch_utils import TFModelEncode + if not id_column and getattr(block, "schema", None): tagged = block.schema.select_by_tag(Tags.ITEM_ID) if tagged.column_schemas: diff --git a/merlin/models/tf/inputs/embedding.py b/merlin/models/tf/inputs/embedding.py index de888b4ee6..165f1035ad 100644 --- a/merlin/models/tf/inputs/embedding.py +++ b/merlin/models/tf/inputs/embedding.py @@ -42,7 +42,12 @@ # https://github.com/PyCQA/pylint/issues/3613 # pylint: disable=no-value-for-parameter, unexpected-keyword-arg from merlin.models.tf.typing import TabularData -from merlin.models.tf.utils.tf_utils import call_layer, df_to_tensor, list_col_to_ragged +from merlin.models.tf.utils.tf_utils import ( + call_layer, + df_to_tensor, + list_col_to_ragged, + tensor_to_df, +) from merlin.models.utils import schema_utils from merlin.models.utils.doc_utils import docstring_parameter from merlin.models.utils.schema_utils import ( @@ -281,9 +286,8 @@ def from_pretrained( name=None, col_schema=None, **kwargs, - ): + ) -> "EmbeddingTable": """Create From pre-trained embeddings from a Dataset or DataFrame. - Parameters ---------- data : Union[Dataset, DataFrameType] @@ -313,6 +317,35 @@ def from_pretrained( **kwargs, ) + @classmethod + def from_dataset( + cls, + data: Union[Dataset, DataFrameType], + trainable=True, + name=None, + col_schema=None, + **kwargs, + ) -> "EmbeddingTable": + """Create From pre-trained embeddings from a Dataset or DataFrame. + Parameters + ---------- + data : Union[Dataset, DataFrameType] + A dataset containing the pre-trained embedding weights + trainable : bool + Whether the layer should be trained or not. + name : str + The name of the layer. + """ + return cls.from_pretrained( + data, trainable=trainable, name=name, col_schema=col_schema, **kwargs + ) + + def to_dataset(self, gpu=None) -> merlin.io.Dataset: + return merlin.io.Dataset(self.to_df(gpu=gpu)) + + def to_df(self, gpu=None): + return tensor_to_df(self.table.embeddings, gpu=gpu) + def _maybe_build(self, inputs): """Creates state between layer instantiation and layer call. Invoked automatically before the first execution of `call()`. diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index bc0f961cf0..394675a8e5 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -15,7 +15,7 @@ import merlin.io from merlin.models.tf.core.base import Block, ModelContext, PredictionOutput, is_input_block -from merlin.models.tf.core.combinators import SequentialBlock +from merlin.models.tf.core.combinators import ParallelBlock, SequentialBlock from merlin.models.tf.core.prediction import Prediction, PredictionContext from merlin.models.tf.core.tabular import TabularBlock from merlin.models.tf.inputs.base import InputBlock @@ -24,6 +24,7 @@ from merlin.models.tf.metrics.topk import TopKMetricsAggregator, filter_topk_metrics, split_metrics from merlin.models.tf.models.utils import parse_prediction_tasks from merlin.models.tf.outputs.base import ModelOutput +from merlin.models.tf.outputs.contrastive import ContrastiveOutput from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask from merlin.models.tf.transforms.tensor import ListToRagged from merlin.models.tf.typing import TabularData @@ -34,9 +35,10 @@ maybe_serialize_keras_objects, ) from merlin.models.utils.dataset import unique_rows_by_features -from merlin.schema import Schema, Tags +from merlin.schema import ColumnSchema, Schema, Tags if TYPE_CHECKING: + from merlin.models.tf.core.encoder import Encoder from merlin.models.tf.core.index import TopKIndexBlock @@ -1397,6 +1399,149 @@ def to_top_k_recommender( return recommender +@tf.keras.utils.register_keras_serializable(package="merlin_models") +class RetrievalModelV2(Model): + def __init__( + self, + *, + query: Union[Encoder, tf.keras.layers.Layer], + output: Union[ModelOutput, tf.keras.layers.Layer], + candidate: Optional[Union[Encoder, tf.keras.layers.Layer]] = None, + query_name="query", + candidate_name="candidate", + pre: Optional[tf.keras.layers.Layer] = None, + post: Optional[tf.keras.layers.Layer] = None, + **kwargs, + ): + if isinstance(output, ContrastiveOutput): + query_name = output.query_name + candidate_name = output.candidate_name + + if query and candidate: + encoder = ParallelBlock({query_name: query, candidate_name: candidate}) + else: + encoder = query + + super().__init__(encoder, output, pre=pre, post=post, **kwargs) + + self._query_name = query_name + self._candidate_name = candidate_name + self._encoder = encoder + self._output = output + + def query_embeddings( + self, + dataset: Optional[merlin.io.Dataset] = None, + index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, + **kwargs, + ) -> merlin.io.Dataset: + query = self.query_encoder if self.has_candidate_encoder else self.encoder + + if dataset is not None and hasattr(query, "encode"): + return query.encode(dataset, index=index, **kwargs) + + if hasattr(query, "to_dataset"): + return query.to_dataset(**kwargs) + + return query.encode(dataset, index=index, **kwargs) + + def candidate_embeddings( + self, + dataset: Optional[merlin.io.Dataset] = None, + index: Optional[Union[str, ColumnSchema, Schema, Tags]] = None, + **kwargs, + ) -> merlin.io.Dataset: + if self.has_candidate_encoder: + candidate = self.candidate_encoder + + if dataset is not None and hasattr(candidate, "encode"): + return candidate.encode(dataset, index=index, **kwargs) + + if hasattr(candidate, "to_dataset"): + return candidate.to_dataset(**kwargs) + + return candidate.encode(dataset, index=index, **kwargs) + + if isinstance(self.last, ContrastiveOutput): + return self.last.to_dataset() + + raise Exception(...) + + @property + def encoder(self): + return self._encoder + + @property + def has_candidate_encoder(self): + return ( + isinstance(self.encoder, ParallelBlock) + and self._candidate_name in self.encoder.parallel_dict + ) + + @property + def query_encoder(self) -> Encoder: + if self.has_candidate_encoder: + output = self.encoder[self._query_name] + else: + output = self.encoder + + output = self._check_encoder(output) + + return output + + @property + def candidate_encoder(self) -> Encoder: + output = None + if self.has_candidate_encoder: + output = self.encoder[self._candidate_name] + + if output: + return self._check_encoder(output) + + raise ValueError("No candidate encoder found.") + + def _check_encoder(self, maybe_encoder): + output = maybe_encoder + + from merlin.models.tf.core.encoder import Encoder + + if isinstance(output, SequentialBlock): + output = Encoder(*maybe_encoder.layers) + + if not isinstance(output, Encoder): + raise ValueError(f"Query encoder should be an Encoder, got {type(output)}") + + return output + + @classmethod + def from_config(cls, config, custom_objects=None): + pre = config.pop("pre", None) + if pre is not None: + pre = tf.keras.layers.deserialize(pre, custom_objects=custom_objects) + + post = config.pop("post", None) + if post is not None: + post = tf.keras.layers.deserialize(post, custom_objects=custom_objects) + + encoder = config.pop("_encoder", None) + if encoder is not None: + encoder = tf.keras.layers.deserialize(encoder, custom_objects=custom_objects) + + output = config.pop("_output", None) + if output is not None: + output = tf.keras.layers.deserialize(output, custom_objects=custom_objects) + + output = RetrievalModelV2(query=encoder, output=output, pre=pre, post=post) + output.__class__ = cls + + return output + + def get_config(self): + config = maybe_serialize_keras_objects(self, {}, ["pre", "post", "_encoder", "_output"]) + + return config + + def _maybe_convert_merlin_dataset(data, batch_size, shuffle=True, **kwargs): # Check if merlin-dataset is passed if hasattr(data, "to_ddf"): diff --git a/merlin/models/tf/outputs/classification.py b/merlin/models/tf/outputs/classification.py index d8759b1782..648d8c5337 100644 --- a/merlin/models/tf/outputs/classification.py +++ b/merlin/models/tf/outputs/classification.py @@ -20,12 +20,14 @@ from tensorflow.keras.layers import Layer from tensorflow.python.ops import embedding_ops +import merlin.io from merlin.models.tf.inputs.embedding import EmbeddingTable from merlin.models.tf.metrics.topk import AvgPrecisionAt, MRRAt, NDCGAt, PrecisionAt, RecallAt from merlin.models.tf.outputs.base import MetricsFn, ModelOutput from merlin.models.tf.utils.tf_utils import ( maybe_deserialize_keras_objects, maybe_serialize_keras_objects, + tensor_to_df, ) from merlin.schema import ColumnSchema, Schema @@ -189,6 +191,9 @@ def __init__( **kwargs, ) + def to_dataset(self, gpu=True) -> merlin.io.Dataset: + return merlin.io.Dataset(tensor_to_df(self.to_call.embeddings, gpu=gpu)) + def get_config(self): config = super().get_config() config["max_num_samples"] = self.max_num_samples @@ -281,7 +286,11 @@ def embedding_lookup(self, inputs: tf.Tensor, **kwargs): tf.Tensor Tensor of hidden representation vectors. """ - return embedding_ops.embedding_lookup(tf.transpose(self.kernel), inputs, **kwargs) + return embedding_ops.embedding_lookup(self.embeddings, inputs, **kwargs) + + @property + def embeddings(self): + return tf.transpose(self.kernel) @tf.keras.utils.register_keras_serializable(package="merlin.models") @@ -323,6 +332,10 @@ def call(self, inputs, training=False, **kwargs) -> tf.Tensor: return logits + @property + def embeddings(self): + return self.table.table.embeddings + def embedding_lookup(self, inputs, **kwargs): return self.table.table(inputs, **kwargs) diff --git a/merlin/models/tf/outputs/contrastive.py b/merlin/models/tf/outputs/contrastive.py index ee360f373f..9f712d11d0 100644 --- a/merlin/models/tf/outputs/contrastive.py +++ b/merlin/models/tf/outputs/contrastive.py @@ -19,6 +19,7 @@ import tensorflow as tf from tensorflow.keras.layers import Layer +import merlin.io from merlin.models.tf.core.prediction import Prediction from merlin.models.tf.inputs.embedding import EmbeddingTable from merlin.models.tf.outputs.base import DotProduct, MetricsFn, ModelOutput @@ -292,6 +293,9 @@ def sample_negatives( def embedding_lookup(self, ids: tf.Tensor): return self.to_call.embedding_lookup(tf.squeeze(ids)) + def to_dataset(self, gpu=None) -> merlin.io.Dataset: + return merlin.io.Dataset(tf_utils.tensor_to_df(self.to_call.embeddings, gpu=gpu)) + @property def has_candidate_weights(self) -> bool: if isinstance(self.to_call, DotProduct): @@ -338,6 +342,10 @@ def from_config(cls, config): class LookUpProtocol(Protocol): """Protocol for embedding lookup layers""" + @property + def embeddings(self): + pass + def embedding_lookup(self, inputs, **kwargs): pass diff --git a/merlin/models/tf/utils/tf_utils.py b/merlin/models/tf/utils/tf_utils.py index 3fdca0f029..fabde52af3 100644 --- a/merlin/models/tf/utils/tf_utils.py +++ b/merlin/models/tf/utils/tf_utils.py @@ -20,6 +20,7 @@ import tensorflow as tf from keras.utils.tf_inspect import getfullargspec from packaging import version +from tensorflow.python import to_dlpack from merlin.core.dispatch import DataFrameType from merlin.io import Dataset @@ -274,6 +275,40 @@ def df_to_tensor(gdf, dtype=None): return x +def tensor_to_df(tensor, index=None, gpu=None): + if gpu is None: + try: + import cudf # noqa: F401 + import cupy + + gpu = True + except ImportError: + gpu = False + + if gpu: + # Note: It is not possible to convert Tensorflow tensors to the cudf dataframe + # directly using dlPack (as the example commented below) because cudf.from_dlpack() + # expects the 2D tensor to be in Fortran order (column-major), which is not + # supported by TF (https://github.com/rapidsai/cudf/issues/10754). + # df = cudf.from_dlpack(to_dlpack(tf.convert_to_tensor(embeddings))) + tensor_cupy = cupy.fromDlpack(to_dlpack(tf.convert_to_tensor(tensor))) + df = cudf.DataFrame(tensor_cupy) + df.columns = [str(col) for col in list(df.columns)] + if not index: + index = cudf.RangeIndex(0, tensor.shape[0]) + df.set_index(index) + else: + import pandas as pd + + df = pd.DataFrame(tensor.numpy()) + df.columns = [str(col) for col in list(df.columns)] + if not index: + index = pd.RangeIndex(0, tensor.shape[0]) + df.set_index(index) + + return df + + def add_epsilon_to_zeros(tensor: tf.Tensor, epsilon: float = 1e-24) -> tf.Tensor: """Replaces zeros by adding a small epsilon value to them. This is useful to avoid inf and nan errors on math ops diff --git a/tests/unit/tf/core/test_encoder.py b/tests/unit/tf/core/test_encoder.py index 04ca14f450..998b648abd 100644 --- a/tests/unit/tf/core/test_encoder.py +++ b/tests/unit/tf/core/test_encoder.py @@ -15,9 +15,9 @@ def test_encoder_block(music_streaming_data: Dataset): schema = music_streaming_data.schema user_schema = schema.select_by_name(["user_id", "user_genres"]) - user_encoder = mm.EncoderBlock(user_schema, mm.MLPBlock([4]), name="query") + user_encoder = mm.Encoder(user_schema, mm.MLPBlock([4]), name="query") item_schema = schema.select_by_name(["item_id"]) - item_encoder = mm.EncoderBlock(item_schema, mm.MLPBlock([4]), name="candidate") + item_encoder = mm.Encoder(item_schema, mm.MLPBlock([4]), name="candidate") model = mm.Model( mm.ParallelBlock(user_encoder, item_encoder), @@ -27,7 +27,7 @@ def test_encoder_block(music_streaming_data: Dataset): assert model.blocks[0]["query"] == user_encoder assert model.blocks[0]["candidate"] == item_encoder - testing_utils.model_test(model, music_streaming_data) + testing_utils.model_test(model, music_streaming_data, reload_model=True) with pytest.raises(Exception) as excinfo: user_encoder.compile("adam") @@ -53,9 +53,9 @@ def test_topk_encoder(music_streaming_data: Dataset): # 1. Train a retrieval model schema = music_streaming_data.schema user_schema = schema.select_by_name(["user_id", "country", "user_age"]) - user_encoder = mm.EncoderBlock(user_schema, mm.MLPBlock([4]), name="query") + user_encoder = mm.Encoder(user_schema, mm.MLPBlock([4]), name="query") item_schema = schema.select_by_name(["item_id"]) - item_encoder = mm.EncoderBlock(item_schema, mm.MLPBlock([4]), name="candidate") + item_encoder = mm.Encoder(item_schema, mm.MLPBlock([4]), name="candidate") retrieval_model = mm.Model( mm.ParallelBlock(user_encoder, item_encoder), mm.ContrastiveOutput(item_schema, "in-batch"), @@ -111,6 +111,7 @@ def _item_id_as_target(inputs, targets): topk_encoder.save(tmpdir) loaded_topk_encoder = tf.keras.models.load_model(tmpdir) batch_output = loaded_topk_encoder(batch[0]) + assert list(batch_output.scores.shape) == [32, TOP_K] tf.debugging.assert_equal( topk_encoder.topk_layer._candidates, diff --git a/tests/unit/tf/models/test_base.py b/tests/unit/tf/models/test_base.py index cd7f80d185..8f2488e62d 100644 --- a/tests/unit/tf/models/test_base.py +++ b/tests/unit/tf/models/test_base.py @@ -669,3 +669,64 @@ def test_unfreeze_all_blocks(ecommerce_data): model.compile(run_eagerly=True, optimizer=tf.keras.optimizers.SGD(lr=0.1)) model.fit(ecommerce_data, batch_size=128, epochs=1) + + +def test_retrieval_model_query(ecommerce_data: Dataset, run_eagerly=True): + query = ecommerce_data.schema.select_by_tag(Tags.USER_ID) + candidate = ecommerce_data.schema.select_by_tag(Tags.ITEM_ID) + + def item_id_as_target(features, targets): + targets[candidate.first.name] = features.pop(candidate.first.name) + + return features, targets + + loader = mm.Loader(ecommerce_data, batch_size=50, transform=item_id_as_target) + + model = mm.RetrievalModelV2( + query=mm.EmbeddingEncoder(query, dim=8), + output=mm.ContrastiveOutput(candidate, "in-batch"), + ) + + model, _ = testing_utils.model_test(model, loader, reload_model=True, run_eagerly=run_eagerly) + + assert isinstance(model.query_encoder, mm.EmbeddingEncoder) + assert isinstance(model.last, mm.ContrastiveOutput) + + queries = model.query_embeddings().compute() + _check_embeddings(queries, 1001) + + candidates = model.candidate_embeddings().compute() + _check_embeddings(candidates, 1001) + + +def test_retrieval_model_query_candidate(ecommerce_data: Dataset, run_eagerly=True): + query = ecommerce_data.schema.select_by_tag(Tags.USER_ID) + candidate = ecommerce_data.schema.select_by_tag(Tags.ITEM_ID) + model = mm.RetrievalModelV2( + query=mm.EmbeddingEncoder(query, dim=8), + candidate=mm.EmbeddingEncoder(candidate, dim=8), + output=mm.ContrastiveOutput(candidate, "in-batch"), + ) + + reloaded_model, _ = testing_utils.model_test(model, ecommerce_data, reload_model=True) + + assert isinstance(reloaded_model.query_encoder, mm.EmbeddingEncoder) + assert isinstance(reloaded_model.candidate_encoder, mm.EmbeddingEncoder) + + queries = model.query_embeddings(ecommerce_data, batch_size=10, index=Tags.USER_ID).compute() + _check_embeddings(queries, 100, "user_id") + + candidates = model.candidate_embeddings( + ecommerce_data, batch_size=10, index=candidate + ).compute() + _check_embeddings(candidates, 100, "item_id") + + +def _check_embeddings(embeddings, extected_len, index_name=None): + if not isinstance(embeddings, pd.DataFrame): + embeddings = embeddings.to_pandas() + + assert isinstance(embeddings, pd.DataFrame) + assert list(embeddings.columns) == [str(i) for i in range(8)] + assert len(embeddings.index) == extected_len + assert embeddings.index.name == index_name