
Add Encoder & Predictor #1112

Merged
merged 27 commits into from
Jun 22, 2023

Conversation

marcromeyn
Contributor

@marcromeyn marcromeyn commented May 25, 2023

Goals ⚽

This PR introduces the Encoder and Predictor classes to add batch-prediction capabilities to the PyTorch backend.

Implementation Details 🚧

Encoder

The Encoder is meant to be used for things like embedding extraction.

```python
>>> dataset = Dataset(...)
>>> model = mm.TwoTowerModel(dataset.schema)
# `selection=Tags.USER` ensures that only the sub-module(s) of the model
# that process features tagged as user are used during encoding.
# Additionally, it filters out all features that aren't tagged as user.
>>> user_encoder = Encoder(model[0], selection=Tags.USER)
# The index is used in the resulting DataFrame after encoding.
# Setting unique=True (the default) ensures that any rows that are
# duplicates with respect to the index are dropped, keeping only the
# first occurrence.
>>> user_embs = user_encoder(dataset, batch_size=128, index=Tags.USER_ID)
>>> print(user_embs.compute())
user_id       0       1       2  ...      38      39      40
0        0.1231  0.4132  0.5123  ...  0.9132  0.8123  0.1123
1        0.1521  0.5123  0.6312  ...  0.7321  0.6123  0.2213
...         ...     ...     ...  ...     ...     ...     ...
```
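The `unique=True` behavior described above amounts to an index-based drop-duplicates. A minimal pandas sketch of that semantics (the column name and data here are illustrative, not taken from the PR):

```python
import pandas as pd

# Hypothetical encoder output containing a duplicate user_id (illustrative data).
embs = pd.DataFrame({
    "user_id": [0, 1, 0],
    "dim_0": [0.1, 0.2, 0.3],
})

# unique=True corresponds to dropping rows that share the index column,
# keeping only the first occurrence of each index value.
deduped = embs.drop_duplicates(subset="user_id", keep="first")
print(deduped["user_id"].tolist())  # [0, 1]
```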

Predictor

The Predictor class, on the other hand, returns both the original input data and the corresponding predictions in the output DataFrame.

```python
>>> dataset = Dataset(...)
>>> model = mm.TwoTowerModel(dataset.schema)
>>> predictor = Predictor(model)
>>> predictions = predictor(dataset, batch_size=128)
>>> print(predictions.compute())
user_id  user_age  item_id  item_category  click  click_prediction
0        24        101      1              1      0.6312
1        35        102      2              0      0.7321
...      ...       ...      ...            ...    ...
```
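Conceptually, the Predictor output is the input batch with a prediction column appended. A minimal sketch of that contract (the column names, scores, and data below are illustrative, not from the PR):

```python
import pandas as pd

# Illustrative input batch (invented for this sketch).
batch = pd.DataFrame({"user_id": [0, 1], "item_id": [101, 102]})

# Stand-in model scores for the batch.
scores = [0.6312, 0.7321]

# The output keeps every input column and appends the prediction column,
# mirroring the click_prediction column in the example above.
out = batch.assign(click_prediction=scores)
print(list(out.columns))  # ['user_id', 'item_id', 'click_prediction']
```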

@marcromeyn marcromeyn added area/pytorch enhancement New feature or request labels May 25, 2023
@github-actions

Documentation preview

https://nvidia-merlin.github.io/models/review/pr-1112

@marcromeyn marcromeyn force-pushed the torch/batch-predict branch from ce776e2 to 96b2545 Compare May 29, 2023 07:48
@marcromeyn marcromeyn self-assigned this May 29, 2023
@marcromeyn marcromeyn marked this pull request as ready for review May 29, 2023 11:59
merlin/models/torch/predict.py
@edknv
Contributor

edknv commented Jun 13, 2023

Some CPU tests seem to be failing (sample run), related to dask and/or dlpack, e.g.:

```
E           ValueError: Metadata inference failed in `encode_df`.
E
E           You have supplied a custom function and Dask is unable to
E           determine the type of output that that function returns.
E
E           To resolve this please provide a meta= keyword.
E           The docstring of the Dask function you ran should have more information.
E
E           Original error is below:
E           ------------------------
E           BufferError('DLPack only supports signed/unsigned integers, float and complex dtypes.')
```

@marcromeyn marcromeyn merged commit aa501f0 into main Jun 22, 2023
@marcromeyn marcromeyn deleted the torch/batch-predict branch June 22, 2023 11:45