From 6eea6abd082312d33f3bce176f57b13f69fdde40 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Wed, 26 Oct 2022 19:45:46 +0100 Subject: [PATCH 01/17] Add test for transformer with RetrievalModelV2 --- merlin/models/tf/outputs/contrastive.py | 5 ++ tests/unit/tf/transformers/test_block.py | 69 ++++++++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/merlin/models/tf/outputs/contrastive.py b/merlin/models/tf/outputs/contrastive.py index c5efe25139..0ba419f087 100644 --- a/merlin/models/tf/outputs/contrastive.py +++ b/merlin/models/tf/outputs/contrastive.py @@ -197,6 +197,9 @@ def call_contrastive(self, inputs, features, targets, training=False, testing=Fa positive_id = features[self.col_schema.name] positive_embedding = inputs[self.candidate_name] + if isinstance(positive_id, tf.RaggedTensor): + positive_id = positive_id.to_tensor() + positive = Candidate(positive_id, {**features}).with_embedding(positive_embedding) negative = self.sample_negatives(positive, features, training=training, testing=testing) if self.has_candidate_weights and ( @@ -314,6 +317,8 @@ def sample_negatives( return negatives def embedding_lookup(self, ids: tf.Tensor): + if isinstance(ids, tf.RaggedTensor): + ids = ids.to_tensor() return self.to_call.embedding_lookup(tf.squeeze(ids)) def to_dataset(self, gpu=None) -> merlin.io.Dataset: diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index dd6e6603a7..b159044721 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -25,6 +25,75 @@ def test_import(): assert transformers is not None +@pytest.mark.parametrize("run_eagerly", [True]) +def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): + + seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag( + Tags.CATEGORICAL + ) + target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0] + + loader = Loader(sequence_testing_data, batch_size=8, shuffle=False) + + model = mm.RetrievalModelV2( + query=mm.SequentialBlock( + [ + mm.InputBlockV2( + seq_schema, + embeddings=mm.Embeddings( + seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None + ), + ), + GPT2Block(d_model=48, n_head=4, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()), + ] + ), + output=mm.ContrastiveOutput( + seq_schema.select_by_name(target), negative_samplers="in-batch" + ), + ) + seq_mask_random = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) + + inputs, targets = next(iter(loader)) + outputs = model(inputs, targets=targets, training=True) + assert list(outputs.shape) == [8, 4, 51997] + testing_utils.model_test( + model, + loader, + run_eagerly=run_eagerly, + reload_model=True, + fit_kwargs={"pre": seq_mask_random}, + metrics=[mm.RecallAt(5000), mm.NDCGAt(5000)], + ) + + +def test_retrieval(sequence_testing_data: Dataset): + dataset = sequence_testing_data + + seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag( + Tags.CATEGORICAL + ) + + transformer_block = GPT2Block(d_model=48, n_head=4, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()) + input_block = mm.InputBlockV2( + seq_schema, categorical=mm.Embeddings(seq_schema, sequence_combiner=None) + ) + candidate = dataset.schema.select_by_tag(Tags.ITEM_ID) + query = mm.SequentialBlock( + [ + input_block, + transformer_block, + ] + ) + output = mm.ContrastiveOutput(candidate, "in-batch") + model = mm.RetrievalModelV2(query=query, output=output) + + model.compile() + + target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0] + pre = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) + model.fit(dataset, batch_size=10, pre=pre) + + def test_transformer_encoder(): NUM_ROWS = 100 SEQ_LENGTH = 10 From 4f25bceaf2e8fc343d6ade35852b3082386891d1 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Thu, 27 Oct 2022 13:25:22 +0100 Subject: [PATCH 02/17] Update test for transformer retrieval model --- tests/unit/tf/transformers/test_block.py | 56 ++++++++++++++---------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index b159044721..79947fe086 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -31,40 +31,52 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag( Tags.CATEGORICAL ) + target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0] + predict_next = mm.SequencePredictNext(schema=seq_schema, target=target) + loader = Loader(sequence_testing_data, batch_size=8, shuffle=False, transform=predict_next) - loader = Loader(sequence_testing_data, batch_size=8, shuffle=False) + query_schema = seq_schema + output_schema = seq_schema.select_by_name(target) - model = mm.RetrievalModelV2( - query=mm.SequentialBlock( - [ - mm.InputBlockV2( - seq_schema, - embeddings=mm.Embeddings( - seq_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None - ), - ), - GPT2Block(d_model=48, n_head=4, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()), - ] - ), - output=mm.ContrastiveOutput( - seq_schema.select_by_name(target), negative_samplers="in-batch" + query_encoder = mm.Encoder( + mm.InputBlockV2( + query_schema, + embeddings=mm.Embeddings( + query_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None + ), ), + GPT2Block(d_model=48, n_head=4, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()), + tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1)), + ) + + model = mm.RetrievalModelV2( + query=query_encoder, + output=mm.ContrastiveOutput(output_schema, negative_samplers="in-batch"), ) - seq_mask_random = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) - inputs, targets = next(iter(loader)) - outputs = model(inputs, targets=targets, training=True) - assert list(outputs.shape) == [8, 4, 51997] testing_utils.model_test( model, loader, run_eagerly=run_eagerly, - reload_model=True, - fit_kwargs={"pre": seq_mask_random}, - metrics=[mm.RecallAt(5000), mm.NDCGAt(5000)], + reload_model=False, + metrics={}, ) + predictions = model.predict(loader) + assert list(predictions.shape) == [100, 51997] + + query_embeddings = query_encoder.predict(loader) + assert list(query_embeddings.shape) == [100, 48] + + # query_embeddings = model.query_embeddings(sequence_testing_data, batch_size=10).compute() + item_embeddings = model.candidate_embeddings().compute().to_numpy() + + assert list(item_embeddings.shape) == [51997, 48] + predicitons_2 = np.dot(query_embeddings, item_embeddings.T) + + np.testing.assert_allclose(predictions, predicitons_2, atol=1e-7) + def test_retrieval(sequence_testing_data: Dataset): dataset = sequence_testing_data From dcf594f5ea922f3e6fe819041e0c77fb5d2badb7 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Thu, 27 Oct 2022 13:25:49 +0100 Subject: [PATCH 03/17] Remove test_retrieval from test_block --- tests/unit/tf/transformers/test_block.py | 28 ------------------------ 1 file changed, 28 deletions(-) diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 79947fe086..2ae493df23 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -78,34 +78,6 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): np.testing.assert_allclose(predictions, predicitons_2, atol=1e-7) -def test_retrieval(sequence_testing_data: Dataset): - dataset = sequence_testing_data - - seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag( - Tags.CATEGORICAL - ) - - transformer_block = GPT2Block(d_model=48, n_head=4, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()) - input_block = mm.InputBlockV2( - seq_schema, categorical=mm.Embeddings(seq_schema, sequence_combiner=None) - ) - candidate = dataset.schema.select_by_tag(Tags.ITEM_ID) - query = mm.SequentialBlock( - [ - input_block, - transformer_block, - ] - ) - output = mm.ContrastiveOutput(candidate, "in-batch") - model = mm.RetrievalModelV2(query=query, output=output) - - model.compile() - - target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0] - pre = mm.SequenceMaskRandom(schema=seq_schema, target=target, masking_prob=0.3) - model.fit(dataset, batch_size=10, pre=pre) - - def test_transformer_encoder(): NUM_ROWS = 100 SEQ_LENGTH = 10 From 3c87d81766af542ee6cb4c55966a6025af279972 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Fri, 28 Oct 2022 12:17:38 +0100 Subject: [PATCH 04/17] Allow index param to be optional to `Encoder.encode` --- merlin/models/tf/core/encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/models/tf/core/encoder.py b/merlin/models/tf/core/encoder.py index 1478f2268c..511081d5cd 100644 --- a/merlin/models/tf/core/encoder.py +++ b/merlin/models/tf/core/encoder.py @@ -71,7 +71,7 @@ def encode( elif isinstance(index, Tags): output_schema = self.schema.select_by_tag(index) else: - raise ValueError(f"Invalid index: {index}") + output_schema = None return self.batch_predict( dataset, From d22d3baabd0d97f66f3f96da03056665d675154d Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Fri, 28 Oct 2022 12:18:10 +0100 Subject: [PATCH 05/17] Correct target extraction in `SequencePredictNext` --- merlin/models/tf/transforms/sequence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py index 8350a948d7..3c03bc102d 100644 --- a/merlin/models/tf/transforms/sequence.py +++ b/merlin/models/tf/transforms/sequence.py @@ -219,7 +219,7 @@ def call( self._check_seq_inputs_targets(inputs) # Shifts the target column to be the next item of corresponding input column - new_target = inputs[self.target_name][:, 1:] + new_target = inputs[self.target_name][:, -1:] if targets is None: targets = dict({self.target_name: new_target}) elif isinstance(targets, dict): From fada3a8b5ede5264b2f634154ee684733595a264 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Tue, 1 Nov 2022 10:51:06 +0000 Subject: [PATCH 06/17] Replace ragged coercion with axis aware tf.squeeze --- merlin/models/tf/outputs/contrastive.py | 9 +++------ merlin/models/tf/utils/tf_utils.py | 8 ++++++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/merlin/models/tf/outputs/contrastive.py b/merlin/models/tf/outputs/contrastive.py index 0ba419f087..0a73da3542 100644 --- a/merlin/models/tf/outputs/contrastive.py +++ b/merlin/models/tf/outputs/contrastive.py @@ -197,9 +197,6 @@ def call_contrastive(self, inputs, features, targets, training=False, testing=Fa positive_id = features[self.col_schema.name] positive_embedding = inputs[self.candidate_name] - if isinstance(positive_id, tf.RaggedTensor): - positive_id = positive_id.to_tensor() - positive = Candidate(positive_id, {**features}).with_embedding(positive_embedding) negative = self.sample_negatives(positive, features, training=training, testing=testing) if self.has_candidate_weights and ( @@ -317,9 +314,9 @@ def sample_negatives( return negatives def embedding_lookup(self, ids: tf.Tensor): - if isinstance(ids, tf.RaggedTensor): - ids = ids.to_tensor() - return self.to_call.embedding_lookup(tf.squeeze(ids)) + if ids.shape.rank == 2: + ids = tf.squeeze(ids, axis=1) + return self.to_call.embedding_lookup(ids) def to_dataset(self, gpu=None) -> merlin.io.Dataset: return merlin.io.Dataset(tf_utils.tensor_to_df(self.to_call.embeddings, gpu=gpu)) diff --git a/merlin/models/tf/utils/tf_utils.py b/merlin/models/tf/utils/tf_utils.py index d09afbbed1..6faff59f01 100644 --- a/merlin/models/tf/utils/tf_utils.py +++ b/merlin/models/tf/utils/tf_utils.py @@ -114,8 +114,12 @@ def rescore_false_negatives( Zeroes the logits of accidental negatives. """ # Removing dimensions of size 1 from the shape of the item ids, if applicable - positive_item_ids = tf.cast(tf.squeeze(positive_item_ids), neg_samples_item_ids.dtype) - neg_samples_item_ids = tf.squeeze(neg_samples_item_ids) + if positive_item_ids.shape.rank == 2: + positive_item_ids = tf.squeeze(positive_item_ids, axis=1) + positive_item_ids = tf.cast(positive_item_ids, neg_samples_item_ids.dtype) + + if neg_samples_item_ids.shape.rank == 2: + neg_samples_item_ids = tf.squeeze(neg_samples_item_ids, axis=1) # Reshapes positive and negative ids so that false_negatives_mask matches the scores shape false_negatives_mask = tf.equal( From 3e7904551e5f4b6b7c187f8231bdfd915cf81ce1 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Fri, 4 Nov 2022 13:28:17 +0000 Subject: [PATCH 07/17] Revert change to predict next --- merlin/models/tf/transforms/sequence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/models/tf/transforms/sequence.py b/merlin/models/tf/transforms/sequence.py index 3c03bc102d..8350a948d7 100644 --- a/merlin/models/tf/transforms/sequence.py +++ b/merlin/models/tf/transforms/sequence.py @@ -219,7 +219,7 @@ def call( self._check_seq_inputs_targets(inputs) # Shifts the target column to be the next item of corresponding input column - new_target = inputs[self.target_name][:, -1:] + new_target = inputs[self.target_name][:, 1:] if targets is None: targets = dict({self.target_name: new_target}) elif isinstance(targets, dict): From 7fd5ae4a2b9979eb3090f99541528ff702e1c092 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Fri, 4 Nov 2022 13:38:39 +0000 Subject: [PATCH 08/17] Remove unused ReplaceMaskedEmbeddings (only required for MLM model) --- tests/unit/tf/transformers/test_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 5e65598906..00600fbce9 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -46,7 +46,7 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): query_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None ), ), - GPT2Block(d_model=48, n_head=4, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()), + GPT2Block(d_model=48, n_head=4, n_layer=2), tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1)), ) From 3a41f793cbcebf26369681e04a4be5e7ebddb92d Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 7 Nov 2022 14:15:17 +0000 Subject: [PATCH 09/17] Support tuple return type from model.fit `pre` argument --- merlin/models/tf/models/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index db221f998d..9d3529c571 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -671,6 +671,11 @@ def train_step(self, data): out = call_layer(self.train_pre, x, targets=y, features=x, training=True) if isinstance(out, Prediction): x, y = out.outputs, out.targets + elif isinstance(out, tuple): + assert ( + len(out) == 2 + ), "output of `pre` must be a 2-tuple of x, y or `Prediction` name tuple" + x, y = out else: x = out From a091fd560b26f5a59240b14af831cb1fd3f9dcc9 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 7 Nov 2022 14:15:51 +0000 Subject: [PATCH 10/17] Use predict last and use as pre instead of transform --- tests/unit/tf/transformers/test_block.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 00600fbce9..d4fde7b815 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -33,8 +33,8 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): ) target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0] - predict_next = mm.SequencePredictNext(schema=seq_schema, target=target) - loader = Loader(sequence_testing_data, batch_size=8, shuffle=False, transform=predict_next) + predict_last = mm.SequencePredictLast(schema=seq_schema, target=target) + loader = Loader(sequence_testing_data, batch_size=8, shuffle=False) query_schema = seq_schema output_schema = seq_schema.select_by_name(target) @@ -61,6 +61,7 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): run_eagerly=run_eagerly, reload_model=False, metrics={}, + fit_kwargs={"pre": predict_last}, ) predictions = model.predict(loader) From 83b87c8419e1c2c33a568bf956249048d03a5a31 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 7 Nov 2022 19:49:46 +0000 Subject: [PATCH 11/17] Revert changes to contrastive output --- merlin/models/tf/outputs/contrastive.py | 4 +--- merlin/models/tf/utils/tf_utils.py | 8 ++------ 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/merlin/models/tf/outputs/contrastive.py b/merlin/models/tf/outputs/contrastive.py index 0a73da3542..c5efe25139 100644 --- a/merlin/models/tf/outputs/contrastive.py +++ b/merlin/models/tf/outputs/contrastive.py @@ -314,9 +314,7 @@ def sample_negatives( return negatives def embedding_lookup(self, ids: tf.Tensor): - if ids.shape.rank == 2: - ids = tf.squeeze(ids, axis=1) - return self.to_call.embedding_lookup(ids) + return self.to_call.embedding_lookup(tf.squeeze(ids)) def to_dataset(self, gpu=None) -> merlin.io.Dataset: return merlin.io.Dataset(tf_utils.tensor_to_df(self.to_call.embeddings, gpu=gpu)) diff --git a/merlin/models/tf/utils/tf_utils.py b/merlin/models/tf/utils/tf_utils.py index 6faff59f01..d09afbbed1 100644 --- a/merlin/models/tf/utils/tf_utils.py +++ b/merlin/models/tf/utils/tf_utils.py @@ -114,12 +114,8 @@ def rescore_false_negatives( Zeroes the logits of accidental negatives. """ # Removing dimensions of size 1 from the shape of the item ids, if applicable - if positive_item_ids.shape.rank == 2: - positive_item_ids = tf.squeeze(positive_item_ids, axis=1) - positive_item_ids = tf.cast(positive_item_ids, neg_samples_item_ids.dtype) - - if neg_samples_item_ids.shape.rank == 2: - neg_samples_item_ids = tf.squeeze(neg_samples_item_ids, axis=1) + positive_item_ids = tf.cast(tf.squeeze(positive_item_ids), neg_samples_item_ids.dtype) + neg_samples_item_ids = tf.squeeze(neg_samples_item_ids) # Reshapes positive and negative ids so that false_negatives_mask matches the scores shape false_negatives_mask = tf.equal( From 9704c3e0fe2418a8f2baaad759d9ea6d4d28eb80 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Tue, 8 Nov 2022 19:54:12 +0000 Subject: [PATCH 12/17] Set process_lists default value to False --- merlin/models/tf/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/models/tf/loader.py b/merlin/models/tf/loader.py index 60fc5fe262..bca51ca515 100644 --- a/merlin/models/tf/loader.py +++ b/merlin/models/tf/loader.py @@ -575,7 +575,7 @@ def sample_batch( include_targets: bool = True, to_ragged: bool = False, to_dense: bool = False, - process_lists=True, + process_lists=False, ): """Util function to generate a batch of input tensors from a merlin.io.Dataset instance From 22eb8f621de015b1fa7dc3872bef5043e40c98ae Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Fri, 11 Nov 2022 11:46:52 +0000 Subject: [PATCH 13/17] Add d_model and MLPBlock --- tests/unit/tf/transformers/test_block.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index d4fde7b815..c39b44d339 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -39,6 +39,7 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): query_schema = seq_schema output_schema = seq_schema.select_by_name(target) + d_model = 48 query_encoder = mm.Encoder( mm.InputBlockV2( query_schema, @@ -46,7 +47,8 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): query_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None ), ), - GPT2Block(d_model=48, n_head=4, n_layer=2), + mm.MLPBlock([d_model]), + GPT2Block(d_model=d_model, n_head=2, n_layer=2), tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1)), ) @@ -68,12 +70,12 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): assert list(predictions.shape) == [100, 51997] query_embeddings = query_encoder.predict(loader) - assert list(query_embeddings.shape) == [100, 48] + assert list(query_embeddings.shape) == [100, d_model] # query_embeddings = model.query_embeddings(sequence_testing_data, batch_size=10).compute() item_embeddings = model.candidate_embeddings().compute().to_numpy() - assert list(item_embeddings.shape) == [51997, 48] + assert list(item_embeddings.shape) == [51997, d_model] predicitons_2 = np.dot(query_embeddings, item_embeddings.T) np.testing.assert_allclose(predictions, predicitons_2, atol=1e-7) From f9a485744974fdac7e3d4c3f371c8d49c654c224 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 14 Nov 2022 10:12:34 +0000 Subject: [PATCH 14/17] Revert change to `Encoder.encode` --- merlin/models/tf/core/encoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/models/tf/core/encoder.py b/merlin/models/tf/core/encoder.py index f9d278119f..75931a4823 100644 --- a/merlin/models/tf/core/encoder.py +++ b/merlin/models/tf/core/encoder.py @@ -87,7 +87,7 @@ def encode( elif isinstance(index, Tags): output_schema = self.schema.select_by_tag(index) else: - output_schema = None + raise ValueError(f"Invalid index: {index}") return self.batch_predict( dataset, From 53320b0ab49901e711ceee866830f8df446c50bf Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 14 Nov 2022 10:12:51 +0000 Subject: [PATCH 15/17] Revert change to default value of `process_lists` in `sample_batch` --- merlin/models/tf/loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/models/tf/loader.py b/merlin/models/tf/loader.py index b3cd721e63..771cd339ba 100644 --- a/merlin/models/tf/loader.py +++ b/merlin/models/tf/loader.py @@ -580,7 +580,7 @@ def sample_batch( include_targets: bool = True, to_ragged: bool = False, to_dense: bool = False, - process_lists=False, + process_lists=True, ): """Util function to generate a batch of input tensors from a merlin.io.Dataset instance From 8bb8683d610f8ef82600bf427507d3f0dbde3cb2 Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 14 Nov 2022 10:13:52 +0000 Subject: [PATCH 16/17] Remove commented query_embeddings line --- tests/unit/tf/transformers/test_block.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 4f1026cfb9..7f01a962b2 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -72,7 +72,6 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): query_embeddings = query_encoder.predict(loader) assert list(query_embeddings.shape) == [100, d_model] - # query_embeddings = model.query_embeddings(sequence_testing_data, batch_size=10).compute() item_embeddings = model.candidate_embeddings().compute().to_numpy() assert list(item_embeddings.shape) == [51997, d_model] From 0feea3c996c59583c91ad66824acd291f53e14eb Mon Sep 17 00:00:00 2001 From: Oliver Holworthy Date: Mon, 14 Nov 2022 10:16:23 +0000 Subject: [PATCH 17/17] update comment about prediction tuple --- merlin/models/tf/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/merlin/models/tf/models/base.py b/merlin/models/tf/models/base.py index fa46829215..6d0b2dd24a 100644 --- a/merlin/models/tf/models/base.py +++ b/merlin/models/tf/models/base.py @@ -710,7 +710,7 @@ def train_step(self, data): elif isinstance(out, tuple): assert ( len(out) == 2 - ), "output of `pre` must be a 2-tuple of x, y or `Prediction` name tuple" + ), "output of `pre` must be a 2-tuple of x, y or `Prediction` tuple" x, y = out else: x = out