Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update UniformNegativeSampling to handle targets and add optional control for testing #583

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 36 additions & 27 deletions merlin/models/tf/data_augmentation/negative_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,27 @@ class UniformNegativeSampling(tf.keras.layers.Layer):
Only works with positive-only binary-target batches.
"""

def __init__(self, schema: Schema, n_per_positive: int, seed: Optional[int] = None, **kwargs):
def __init__(
self,
schema: Schema,
n_per_positive: int,
seed: Optional[int] = None,
run_when_testing: bool = True,
**kwargs
):
"""Instantiate a sampling block."""
super(UniformNegativeSampling, self).__init__(**kwargs)
self.n_per_positive = n_per_positive
self.item_id_col = schema.select_by_tag(Tags.ITEM_ID).column_names[0]
self.schema = schema.select_by_tag(Tags.ITEM)
self.seed = seed
self.run_when_testing = run_when_testing

def call(self, inputs: TabularData, targets=None) -> Prediction:
def call(self, inputs: TabularData, targets=None, testing=False, **kwargs) -> Prediction:
"""Extend batch of inputs and targets with negatives."""
if targets is None or (testing and not self.run_when_testing):
return Prediction(inputs, targets)

# 1. Select item-features
fist_input = list(inputs.values())[0]
batch_size = (
Expand Down Expand Up @@ -91,31 +102,28 @@ def call(self, inputs: TabularData, targets=None) -> Prediction:
outputs[name] = tf.concat([val, negatives], axis=0)
outputs[name] = tf.ragged.boolean_mask(outputs[name], mask)

# update targets if present
if targets is not None:

def mask_targets(target_tensor):
out = tf.concat(
[
target_tensor,
tf.zeros((sampled_num_negatives, 1), dtype=target_tensor.dtype),
],
0,
)
out = tf.boolean_mask(out, mask)

return out

if isinstance(targets, dict):
targets = {k: mask_targets(v) for k, v in targets.items()}
elif isinstance(targets, list):
targets = [mask_targets(v) for v in targets]
elif isinstance(targets, tuple):
targets = tuple([mask_targets(v) for v in targets])
elif isinstance(targets, tf.Tensor):
targets = mask_targets(targets)
else:
raise ValueError("Unsupported target type: {}".format(type(targets)))
def mask_targets(target_tensor):
out = tf.concat(
[
target_tensor,
tf.zeros((sampled_num_negatives, 1), dtype=target_tensor.dtype),
],
0,
)
out = tf.boolean_mask(out, mask)

return out

if isinstance(targets, dict):
targets = {k: mask_targets(v) for k, v in targets.items()}
elif isinstance(targets, list):
targets = [mask_targets(v) for v in targets]
elif isinstance(targets, tuple):
targets = tuple([mask_targets(v) for v in targets])
elif isinstance(targets, tf.Tensor):
targets = mask_targets(targets)
else:
raise ValueError("Unsupported target type: {}".format(type(targets)))

return Prediction(outputs, targets)

Expand All @@ -125,6 +133,7 @@ def get_config(self):
config["schema"] = schema_utils.schema_to_tensorflow_metadata_json(self.schema)
config["n_per_positive"] = self.n_per_positive
config["seed"] = self.seed
config["run_when_testing"] = self.run_when_testing
return config

@classmethod
Expand Down
73 changes: 68 additions & 5 deletions tests/unit/tf/data_augmentation/test_negative_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,84 @@ def user_item_pairs(df):

assert input_user_item_pairs.intersection(negatives_user_item_pairs) == set()

def assert_outputs_batch_size(self, assert_fn, *outputs):
for values in zip(*outputs):
for value in values:
if isinstance(value, tuple):
assert_fn(value[1].shape[0])
else:
assert_fn(value.shape[0])

@pytest.mark.parametrize("to_dense", [True, False])
def test_calling(self, music_streaming_data: Dataset, to_dense: bool, tf_random_seed: int):
def test_calling_without_targets(
self, music_streaming_data: Dataset, to_dense: bool, tf_random_seed: int
):
schema = music_streaming_data.schema
batch_size, n_per_positive = 10, 5
features = mm.sample_batch(
music_streaming_data, batch_size=batch_size, include_targets=False, to_dense=to_dense
)

sampler = UniformNegativeSampling(schema, n_per_positive, seed=tf_random_seed)

with_negatives = sampler(features)
outputs = with_negatives.outputs

def assert_fn(output_batch_size):
assert output_batch_size == batch_size

self.assert_outputs_batch_size(assert_fn, outputs.values())

@pytest.mark.parametrize("to_dense", [True, False])
def test_calling(self, music_streaming_data: Dataset, to_dense: bool, tf_random_seed: int):
schema = music_streaming_data.schema
batch_size, n_per_positive = 10, 5
inputs, targets = mm.sample_batch(
music_streaming_data, batch_size=batch_size, include_targets=True, to_dense=to_dense
)

sampler = UniformNegativeSampling(schema, 5, seed=tf_random_seed)
with_negatives = sampler(features).outputs

with_negatives = sampler(inputs, targets=targets)
outputs = with_negatives.outputs
targets = with_negatives.targets

max_batch_size = batch_size + batch_size * n_per_positive
assert all(
f.shape[0] <= max_batch_size and f.shape[0] > batch_size
for f in with_negatives.values()

def assert_fn(output_batch_size):
assert batch_size < output_batch_size <= max_batch_size

self.assert_outputs_batch_size(
assert_fn,
outputs.values(),
targets.values(),
)

@pytest.mark.parametrize("to_dense", [True, False])
def test_run_when_testing(
self, music_streaming_data: Dataset, to_dense: bool, tf_random_seed: int
):
schema = music_streaming_data.schema
batch_size, n_per_positive = 10, 5
inputs, targets = mm.sample_batch(
music_streaming_data, batch_size=batch_size, include_targets=True, to_dense=to_dense
)

sampler = UniformNegativeSampling(
schema, n_per_positive, seed=tf_random_seed, run_when_testing=False
)

with_negatives = sampler(inputs, targets=targets, testing=True)
outputs = with_negatives.outputs
targets = with_negatives.targets

def assert_fn(output_batch_size):
assert output_batch_size == batch_size

self.assert_outputs_batch_size(
assert_fn,
outputs.values(),
targets.values(),
)

@pytest.mark.parametrize("run_eagerly", [True, False])
Expand Down