Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Encoder & Predictor #1112

Merged
merged 27 commits into main from torch/batch-predict
Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
375ae40
Some commit
marcromeyn May 19, 2023
84fab78
Some commit
marcromeyn May 23, 2023
5a1924a
Some commit
marcromeyn May 23, 2023
d24c019
Adding better doc-strings + increase test-coverage
marcromeyn May 25, 2023
494db22
Merge torch/utils/schema_utils.py into utils/schema_utils.py
marcromeyn May 25, 2023
96b2545
Removing merlin/models/predict.py since it's un-used
marcromeyn May 25, 2023
6c5fad9
Merge branch 'main' into torch/batch-predict
marcromeyn May 29, 2023
385f566
Merge branch 'main' into torch/batch-predict
marcromeyn May 29, 2023
fb50333
Merge branch 'main' into torch/batch-predict
marcromeyn May 30, 2023
042479c
Merge branch 'main' into torch/batch-predict
marcromeyn May 30, 2023
101d45d
Adding output-schema propagation
marcromeyn Jun 1, 2023
f2f8ca2
Making test-classes for functions camel-case
marcromeyn Jun 1, 2023
c384f86
Merge branch 'main' into torch/batch-predict
edknv Jun 2, 2023
4ef30d2
Merge branch 'main' into torch/batch-predict
edknv Jun 2, 2023
3244cc3
Merge branch 'main' into torch/batch-predict
marcromeyn Jun 7, 2023
e96ae35
Merge branch 'main' into torch/batch-predict
edknv Jun 7, 2023
fc77896
Merge branch 'main' into torch/batch-predict
edknv Jun 12, 2023
5c4808a
Merge branch 'main' into torch/batch-predict
edknv Jun 12, 2023
22f28ac
Merge branch 'main' into torch/batch-predict
edknv Jun 13, 2023
764a3bf
Merge branch 'main' into torch/batch-predict
marcromeyn Jun 19, 2023
ee174b4
Pass `meta` argument to map_partitions to avoid dtype issues in dask
oliverholworthy Jun 19, 2023
0dcb7fc
Merge branch 'main' into torch/batch-predict
marcromeyn Jun 20, 2023
f8fa1ca
Merge branch 'main' into torch/batch-predict
marcromeyn Jun 21, 2023
26cd4d0
Merge branch 'main' into torch/batch-predict
marcromeyn Jun 21, 2023
66e36a5
Merge branch 'main' into torch/batch-predict
marcromeyn Jun 21, 2023
74185b9
Remove SchemaTrackingMixin from test_predict
marcromeyn Jun 22, 2023
ed878ed
Remove select_schema from schema_utils
marcromeyn Jun 22, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 346 additions & 0 deletions merlin/models/torch/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,346 @@
from functools import partial, reduce
from typing import Dict, Optional, TypeVar, Union, overload

import torch
from torch import nn

from merlin.core.dispatch import DataFrameLike, concat, concat_columns
from merlin.dataloader.torch import Loader
from merlin.io import Dataset
from merlin.models.utils.schema_utils import Selection, select_schema
from merlin.schema import Schema
from merlin.table import TensorTable

# Default column key used by `to_tensor_table` when a bare tensor
# (rather than a dictionary of tensors) is wrapped into a TensorTable.
OUT_KEY = "output"
# Generic DataFrame type, bounded by DataFrameLike, used to annotate the
# DataFrame arguments and return values of the functions in this module.
DFType = TypeVar("DFType", bound=DataFrameLike)


class Encoder:
    """Encode various forms of data using a specified PyTorch module.

    Supports multiple data formats: Datasets, Loaders, DataFrames,
    PyTorch tensors and dictionaries of PyTorch tensors.

    Example usage for encoding with an index & selection::
        >>> dataset = Dataset(...)
        >>> model = mm.TwoTowerModel(dataset.schema)

        # `selection=Tags.USER` ensures that only the sub-module(s) of the model
        # that processes features tagged as user is used during encoding.
        # Additionally, it filters out all other features that aren't tagged as user.
        >>> user_encoder = Encoder(model[0], selection=Tags.USER)

        # The index is used in the resulting DataFrame after encoding
        # Setting unique=True (default value) ensures that any duplicate rows
        # in the DataFrame, based on the index, are dropped, leaving only the
        # first occurrence.
        >>> user_embs = user_encoder(dataset, batch_size=128, index=Tags.USER_ID)
        >>> print(user_embs.compute())
        user_id    0       1       2       ...    37      38      39      40
        0          ...     0.1231  0.4132  0.5123 ...     0.9132  0.8123  0.1123
        1          ...     0.1521  0.5123  0.6312 ...     0.7321  0.6123  0.2213
        ...        ...     ...     ...     ...    ...     ...     ...     ...

    Parameters
    ----------
    module : nn.Module
        The PyTorch module used for encoding.
    selection : Optional[Selection], optional
        The data selection used for encoding, by default None.
    """

    def __init__(self, module: nn.Module, selection: Optional[Selection] = None):
        self.module = module
        self.selection = selection

    @overload  # pragma: no cover
    def __call__(
        self, data: Dataset, batch_size=None, index: Optional[Selection] = None, unique: bool = True
    ):
        ...

    @overload  # pragma: no cover
    def __call__(self, data: Loader, index: Optional[Selection] = None, unique: bool = True):
        ...

    @overload  # pragma: no cover
    def __call__(self, data: DataFrameLike, batch_size=None):
        ...

    @overload  # pragma: no cover
    def __call__(self, data: torch.Tensor):
        ...

    @overload  # pragma: no cover
    def __call__(self, data: Dict[str, torch.Tensor]):
        ...

    def __call__(self, data, batch_size=None, index=None, unique=True):
        """Encode a Dataset, Loader, DataFrame, or Tensor(s).

        Dispatches on the type of ``data`` to `encode_dataset`, `encode_df`
        or `encode_tensors`.

        Parameters
        ----------
        data : Dataset, Loader, DataFrameLike, torch.Tensor or Dict[str, torch.Tensor]
            The data to be encoded.
        batch_size : int, optional
            The batch size for the encoding, by default None.
            Only forwarded for Dataset, Loader and DataFrame inputs.
        index : Optional[Selection], optional
            Selection of the column(s) to use as index of the encoded output,
            by default None. Only used for Dataset and Loader inputs.
        unique : bool, optional
            If True and ``index`` is provided, rows with a duplicated index are
            dropped (the first occurrence is kept), by default True.
            Only used for Dataset and Loader inputs.

        Raises
        ------
        ValueError
            If ``data`` is of none of the supported types.
        """
        if isinstance(data, (Dataset, Loader)):
            return self.encode_dataset(data, batch_size, index=index, unique=unique)
        if isinstance(data, DataFrameLike):
            return self.encode_df(data, batch_size)
        if isinstance(data, (dict, torch.Tensor)):
            return self.encode_tensors(data)

        # The message lists every accepted type, including tensors and dicts.
        raise ValueError(
            "data must be a Dataset, a Loader, a DataFrameLike, a torch.Tensor "
            f"or a dict of tensors, got: {type(data).__name__}"
        )

    def encode_dataset(
        self,
        data: Union[Dataset, Loader],
        batch_size: Optional[int] = None,
        index: Optional[Selection] = None,
        unique: bool = True,
    ) -> Dataset:
        """Encode a Dataset or Loader through Dask.

        Encoding happens in 3 steps:
        1. Partition Mapping
            This step uses Dask to break down the DataFrame into several partitions,
            making large datasets computationally manageable.
            The `encode_df` method is applied to each partition independently,
            facilitating efficient distributed computation.
        2. DataFrame Processing
            In this step, each partition, which is a DataFrame, is transformed directly
            into a Loader with a determined batch size. This Loader then efficiently
            converts the data into batches of PyTorch tensors, which are subsequently
            processed by the PyTorch module using the `encode_tensors` method.
        3. Tensor Processing
            Here, each batch derived from the Loader is processed by a PyTorch module for encoding.
            If the inputs are dictionary-like and a passthrough_schema is provided, supplementary
            columns might be included in the output DataFrame.

        Parameters
        ----------
        data : Union[Dataset, Loader]
            The data to be encoded. For a Loader, its own batch size,
            input schema and dataset are used.
        batch_size : Optional[int], optional
            The batch size for the encoding, by default None.
            Required when ``data`` is a Dataset, ignored when it is a Loader.
        index : Optional[Selection], optional
            Selection of the column(s) that are passed through the encoding
            and set as index of the output, by default None.
        unique : bool, optional
            If True and ``index`` is provided, rows with duplicated values in
            the index column(s) are dropped before encoding, keeping the first
            occurrence, by default True. Has no effect without ``index``.

        Returns
        -------
        Dataset
            A Dataset wrapping the (lazy) Dask DataFrame with the encoded output.

        Raises
        ------
        ValueError
            If ``data`` is a Dataset and no ``batch_size`` is provided, or if
            ``data`` is neither a Dataset nor a Loader.
        """
        dataset: Dataset
        if isinstance(data, Loader):
            batch_size = data.batch_size
            schema = data.input_schema
            dataset = data.dataset
        elif isinstance(data, Dataset):
            if not batch_size:
                raise ValueError("batch_size must be provided if a Dataset is passed")
            schema = data.schema
            dataset = data
        else:
            # This method only handles Dataset and Loader inputs.
            raise ValueError(f"data must be a Dataset or a Loader, got: {type(data).__name__}")

        if self.selection:
            # Restrict both the schema and the columns to the selected features.
            schema = select_schema(schema, self.selection)
            dataset = Dataset(dataset.to_ddf(), schema=schema)
            ddf = dataset.to_ddf()[schema.column_names]
        else:
            ddf = dataset.to_ddf()

        index_schema = None
        if index:
            index_schema = select_schema(schema, index)

            # De-duplication is defined in terms of the index columns,
            # so it only applies when an index is requested.
            if unique:
                ddf = ddf.drop_duplicates(index_schema.column_names, keep="first")

        output = ddf.map_partitions(
            self.encode_df,
            batch_size=batch_size,
            input_schema=schema,
            passthrough_schema=index_schema,
        )
        if index:
            output = output.set_index(index_schema.column_names)

        return Dataset(output)

    def encode_df(
        self,
        df: DFType,
        batch_size: Optional[int] = None,
        input_schema: Optional[Schema] = None,
        passthrough_schema: Optional[Schema] = None,
    ) -> DFType:
        """Encode a DataFrame, either from Pandas or CuDF.

        The DataFrame is wrapped in a Loader, every batch is encoded with
        `encode_tensors` and the per-batch results are combined with `reduce`.

        Parameters
        ----------
        df : DFType
            The DataFrame to be encoded.
        batch_size : Optional[int], optional
            The batch size for the encoding, by default None.
            When not provided, the whole DataFrame is used as a single batch.
        input_schema : Optional[Schema], optional
            The schema of the input DataFrame, by default None.
        passthrough_schema : Optional[Schema], optional
            The schema that should pass through the encoding, by default None.

        Returns
        -------
        DFType
            The encoded output of all batches, combined into one DataFrame.
        """
        dataset = Dataset(df, schema=input_schema)
        # Fall back to a single batch covering the whole DataFrame.
        loader = Loader(dataset, batch_size=batch_size or len(df))
        apply = partial(self.encode_tensors, passthrough_schema=passthrough_schema)
        output_df = reduce(self.reduce, loader.map(apply))

        return output_df

    def encode_tensors(
        self, inputs, targets=None, passthrough_schema: Optional[Schema] = None
    ) -> DFType:
        """Encode a batch of Pytorch tensor(s).

        Parameters
        ----------
        inputs
            The inputs to be encoded.
        targets, optional
            The targets of the batch, by default None.
            Accepted for compatibility with Loader batches, but not used.
        passthrough_schema : Optional[Schema], optional
            The schema that should pass through the encoding, by default None.
            Only applied when ``inputs`` is a dictionary.

        Returns
        -------
        DFType
            DataFrame holding the module output, preceded by the
            pass-through columns when those are requested.
        """
        del targets
        output_df = to_tensor_table(self.module(inputs)).to_df()

        if passthrough_schema and isinstance(inputs, dict):
            col_names = passthrough_schema.column_names
            index_dict = {n: inputs[n] for n in col_names}
            index_df = to_tensor_table(index_dict).to_df()

            output_df = concat_columns([index_df, output_df])

        return output_df

    def reduce(self, left: DFType, right: DFType):
        """
        Concatenate two DataFrames along the index axis.

        Parameters
        ----------
        left : DFType
            The first DataFrame.
        right : DFType
            The second DataFrame.
        """
        return concat([left, right])


class Predictor(Encoder):
    """Run predictions over various forms of data with a PyTorch module.

    Unlike the plain encoder, the output keeps the inputs (and the targets,
    when available) next to the predictions. This is handy when the original
    data and the predictions are needed in one place, or when follow-up
    computations require both.

    Example usage::
        >>> dataset = Dataset(...)
        >>> model = mm.TwoTowerModel(dataset.schema)
        >>> predictor = Predictor(model)
        >>> predictions = predictor(dataset, batch_size=128)
        >>> print(predictions.compute())
        user_id  user_age  item_id  item_category  click  click_prediction
        0        24        101      1              1      0.6312
        1        35        102      2              0      0.7321
        ...      ...       ...      ...            ...    ...


    Parameters
    ----------
    module : nn.Module
        The PyTorch module that is applied to the input tensors.
    selection : Selection, optional
        Selection of features to encode; all features are encoded when omitted.
    prediction_suffix : str, optional
        Suffix appended to prediction columns whose name clashes with
        an input/target column in the output DataFrame.
    """

    def __init__(
        self,
        module: nn.Module,
        selection: Optional[Selection] = None,
        prediction_suffix: str = "_prediction",
    ):
        super().__init__(module, selection)
        self.prediction_suffix = prediction_suffix

    def encode_tensors(self, inputs, targets=None, **kwargs) -> DFType:
        """Encode a batch of tensors, returning inputs, targets and predictions.

        Parameters
        ----------
        inputs :
            Input tensors that are fed to the module.
        targets : optional
            Target tensors; added to the output when provided.

        Returns
        -------
        output_df : DFType
            DataFrame with the input columns, the target columns (if any)
            and the module output.
        """

        # Extra keyword arguments are irrelevant here: everything is passed through.
        del kwargs

        combined_df = to_tensor_table(inputs, "input").to_df()
        if targets is not None:
            combined_df = safe_concat_columns(
                combined_df,
                to_tensor_table(targets, "target").to_df(),
                rename_suffix="_target",
            )

        predictions_df = to_tensor_table(self.module(inputs)).to_df()

        return safe_concat_columns(combined_df, predictions_df, self.prediction_suffix)


def to_tensor_table(
    data: Union[torch.Tensor, Dict[str, torch.Tensor]], default_key: str = OUT_KEY
) -> TensorTable:
    """Wrap a tensor or a dictionary of tensors into a TensorTable.

    Parameters
    ----------
    data : Union[torch.Tensor, Dict[str, torch.Tensor]]
        A dictionary is used as-is; anything else is stored
        as a single column named ``default_key``.
    default_key : str, optional
        Column name used for non-dictionary data, by default OUT_KEY.

    Returns
    -------
    TensorTable
        The table built from ``data``.
    """
    columns = data if isinstance(data, dict) else {default_key: data}

    return TensorTable(columns)


def safe_concat_columns(left: DFType, right: DFType, rename_suffix: str = "_") -> DFType:
    """Safely concatenate columns from two dataframes.

    Columns of ``right`` whose name already exists in ``left`` are renamed
    by appending ``rename_suffix``. When the suffixed name is already taken
    as well (by a column of ``left``, a column of ``right`` or an earlier
    rename), the suffix is appended again until the name is free, so the
    renaming itself never introduces a duplicated column name.

    Parameters
    ----------
    left : DFType
        Dataframe whose columns come first and keep their names.
    right : DFType
        Dataframe whose columns are appended; clashing columns are renamed.
    rename_suffix : str, optional
        Suffix to append to the clashing column names of ``right``,
        by default "_"

    Returns
    -------
    DFType
        Concatenated dataframe.
    """

    left_col_set = set(left.columns)
    # Every name that a renamed column must not collide with.
    reserved = left_col_set.union(right.columns)

    _to_rename = {}
    for col in right.columns:
        if col not in left_col_set:
            continue

        new_name = f"{col}{rename_suffix}"
        # An empty suffix can never produce a new name, so only retry
        # when appending the suffix actually changes the candidate.
        while rename_suffix and new_name in reserved:
            new_name = f"{new_name}{rename_suffix}"

        _to_rename[col] = new_name
        reserved.add(new_name)

    if _to_rename:
        right = right.rename(columns=_to_rename)

    return concat_columns([left, right])
Loading