Skip to content

Commit

Permalink
Updates on MM to make it support new dataloader output
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielspmoreira committed Feb 24, 2023
1 parent ff8bd31 commit 29fc1f8
Show file tree
Hide file tree
Showing 10 changed files with 198 additions and 207 deletions.
75 changes: 25 additions & 50 deletions merlin/datasets/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import pathlib
from pathlib import Path
from random import randint
from typing import Dict, Optional, Sequence, Tuple, Union
from typing import Dict, Sequence, Tuple, Union

import numpy as np

import merlin.io
from merlin.models.utils import schema_utils
from merlin.schema import ColumnSchema, Schema, Tags
from merlin.schema import Schema, Tags
from merlin.schema.io.tensorflow_metadata import TensorflowMetadata

LOG = logging.getLogger("merlin-models")
Expand Down Expand Up @@ -92,9 +92,13 @@ def generate_data(
Example::
train, valid = generate_data(input, 10000, (0.8, 0.2))
min_session_length: int
The minimum number of events in a session.
The minimum number of events in a session. Overrides the
min sequence length information from the shape of list columns
schema (schema[col].shape.dims[1].min)
max_session_length: int
The maximum number of events in a session.
The maximum number of events in a session. Overrides the
max sequence length information from the shape of list columns
schema (schema[col].shape.dims[1].max)
device: str
The device to use for the data generation.
Supported values: {'cpu', 'gpu'}
Expand All @@ -119,23 +123,15 @@ def generate_data(
raise ValueError(f"Unknown input type: {type(input)}")

for col in schema.column_names:
if not schema[col].is_list:
continue
new_properties = schema[col].properties
new_properties["value_count"] = {"min": min_session_length}
if max_session_length:
new_properties["value_count"]["max"] = max_session_length
schema[col] = ColumnSchema(
name=schema[col].name,
tags=schema[col].tags,
properties=new_properties,
dtype=schema[col].dtype,
is_list=True,
)
if schema[col].shape.is_list:
min_session_length = min_session_length or schema[col].shape.dims[1].min
max_session_length = max_session_length or schema[col].shape.dims[1].max
# Overriding min and max session length from schema
schema[col] = schema[col].with_shape(
((0, None), (min_session_length, max_session_length))
)

df = generate_user_item_interactions(
schema, num_rows, min_session_length, max_session_length, device=device
)
df = generate_user_item_interactions(schema, num_rows, device=device)

if list(set_sizes) != [1.0]:
num_rows = df.shape[0]
Expand All @@ -156,8 +152,6 @@ def generate_data(
def generate_user_item_interactions(
schema: Schema,
num_interactions: int,
min_session_length: int = 5,
max_session_length: Optional[int] = None,
device: str = "cpu",
):
"""
Expand All @@ -177,10 +171,6 @@ def generate_user_item_interactions(
schema object describing the columns to generate.
num_interactions: int
number of interaction rows to generate.
max_session_length: Optional[int]
The maximum length of the multi-hot/sequence features
min_session_length: int
The minimum length of the multi-hot/sequence features
device: str
device to use for generating data.
Expand Down Expand Up @@ -215,8 +205,6 @@ def generate_user_item_interactions(
data,
features,
session_id_col,
min_session_length=min_session_length,
max_session_length=max_session_length,
device=device,
)
processed_cols += [f.name for f in features] + [session_id_col.name]
Expand All @@ -235,8 +223,6 @@ def generate_user_item_interactions(
data,
features,
user_id_col,
min_session_length=min_session_length,
max_session_length=max_session_length,
device=device,
)
processed_cols += [f.name for f in features] + [user_id_col.name]
Expand All @@ -247,11 +233,12 @@ def generate_user_item_interactions(
raise ValueError("Item ID column is required")
item_id_col = item_schema.first

is_list_feature = item_id_col.is_list
is_list_feature = item_id_col.shape.is_list
if not is_list_feature:
shape = num_interactions
else:
shape = (num_interactions, max_session_length or min_session_length) # type: ignore
seq_length = item_id_col.shape.dims[1].max or item_id_col.shape.dims[1].min
shape = (num_interactions, seq_length) # type: ignore
tmp = _array.clip(
_array.random.lognormal(3.0, 1.0, shape).astype(_array.int32),
1,
Expand All @@ -262,14 +249,7 @@ def generate_user_item_interactions(
else:
data[item_id_col.name] = list(tmp)
features = list(schema.select_by_tag(Tags.ITEM).remove_by_tag(Tags.ITEM_ID))
data = generate_conditional_features(
data,
features,
item_id_col,
min_session_length=min_session_length,
max_session_length=max_session_length,
device=device,
)
data = generate_conditional_features(data, features, item_id_col, device=device)
processed_cols += [f.name for f in features] + [item_id_col.name]

# Get remaining features
Expand All @@ -284,9 +264,7 @@ def generate_user_item_interactions(
is_int_feature = feature.dtype and np.issubdtype(feature.dtype.to_numpy, np.integer)
is_list_feature = feature.is_list
if is_list_feature:
data[feature.name] = generate_random_list_feature(
feature, num_interactions, min_session_length, max_session_length, device
)
data[feature.name] = generate_random_list_feature(feature, num_interactions, device)

elif is_int_feature:
domain = feature.int_domain
Expand All @@ -311,8 +289,6 @@ def generate_conditional_features(
data,
features,
parent_feature,
min_session_length: int = 5,
max_session_length: Optional[int] = None,
device="cpu",
):
"""
Expand All @@ -331,9 +307,7 @@ def generate_conditional_features(
is_list_feature = feature.is_list

if is_list_feature:
data[feature.name] = generate_random_list_feature(
feature, num_interactions, min_session_length, max_session_length, device
)
data[feature.name] = generate_random_list_feature(feature, num_interactions, device)

elif is_int_feature:
if not feature.int_domain:
Expand Down Expand Up @@ -364,15 +338,16 @@ def generate_conditional_features(
def generate_random_list_feature(
feature,
num_interactions,
min_session_length: int = 5,
max_session_length: Optional[int] = None,
device="cpu",
):
if device == "cpu":
import numpy as _array
else:
import cupy as _array

seq_length_dim = feature.shape.dims[1]
min_session_length, max_session_length = seq_length_dim.min, seq_length_dim.max

is_int_feature = np.issubdtype(feature.dtype.to_numpy, np.integer)
if is_int_feature:
if max_session_length:
Expand Down
9 changes: 4 additions & 5 deletions merlin/models/tf/inputs/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,14 +382,13 @@ def call(
return out

def _call_table(self, inputs, **kwargs):
# if isinstance(inputs, tuple) and len(inputs) == 2:
# inputs = list_col_to_ragged(inputs)

# Eliminating the last dim==1 of dense tensors before embedding lookup
if isinstance(inputs, tf.Tensor) or (
isinstance(inputs, tf.RaggedTensor) and inputs.shape[-1] == 1
):
inputs = tf.squeeze(inputs, axis=-1)
# Eliminating the last dim==1 of dense tensors before embedding lookup
inputs = tf.cond(
tf.shape(inputs)[-1] == 1, lambda: tf.squeeze(inputs, axis=-1), lambda: inputs
)

if isinstance(inputs, (tf.RaggedTensor, tf.SparseTensor)):
if self.sequence_combiner and isinstance(self.sequence_combiner, str):
Expand Down
7 changes: 6 additions & 1 deletion merlin/models/tf/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,12 @@ def sample_batch(
inputs, targets = batch[0], batch[1]

if prepare_features:
inputs = PrepareFeatures(loader.schema, list_to_dense)(inputs)
pf = PrepareFeatures(loader.schema, list_to_dense)
if targets:
inputs, targets = pf(inputs, targets)
else:
inputs = pf(inputs)

if not include_targets:
return inputs
return inputs, targets
Expand Down
11 changes: 8 additions & 3 deletions merlin/models/tf/models/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from merlin.models.tf.blocks.interaction import FMBlock
from merlin.models.tf.blocks.mlp import MLPBlock, RegularizerType
from merlin.models.tf.core.aggregation import ConcatFeatures
from merlin.models.tf.core.base import Block
from merlin.models.tf.core.base import Block, BlockType
from merlin.models.tf.core.combinators import ParallelBlock, TabularBlock
from merlin.models.tf.inputs.base import InputBlockV2
from merlin.models.tf.inputs.embedding import EmbeddingOptions, Embeddings
from merlin.models.tf.models.base import Model
from merlin.models.tf.models.utils import parse_prediction_blocks
from merlin.models.tf.outputs.base import ModelOutputType
from merlin.models.tf.prediction_tasks.base import ParallelPredictionBlock, PredictionTask
from merlin.models.tf.transforms.features import CategoryEncoding
from merlin.models.tf.transforms.features import CategoryEncoding, PrepareFeatures
from merlin.schema import Schema, Tags


Expand Down Expand Up @@ -285,6 +285,7 @@ def WideAndDeepModel(
prediction_tasks: Optional[
Union[PredictionTask, List[PredictionTask], ParallelPredictionBlock, ModelOutputType]
] = None,
pre: Optional[BlockType] = None,
**wide_body_kwargs,
) -> Model:
"""
Expand Down Expand Up @@ -551,7 +552,11 @@ def WideAndDeepModel(
" or wide part (wide_schema/wide_input_block) must be provided."
)

wide_and_deep_body = ParallelBlock(branches, aggregation="element-wise-sum")
_pre = PrepareFeatures(schema)
if pre:
_pre = _pre.connect(pre)

wide_and_deep_body = ParallelBlock(branches, pre=_pre, aggregation="element-wise-sum")
model = Model(wide_and_deep_body, prediction_blocks)

return model
Loading

0 comments on commit 29fc1f8

Please sign in to comment.