
Add test for retrieval model with transformer block #833

Merged

20 commits
6eea6ab
Add test for transformer with RetrievalModelV2
oliverholworthy Oct 26, 2022
4f25bce
Update test for transformer retrieval model
oliverholworthy Oct 27, 2022
dcf594f
Remove test_retrieval from test_block
oliverholworthy Oct 27, 2022
3c87d81
Allow index param to be optional to `Encoder.encode`
oliverholworthy Oct 28, 2022
d22d3ba
Correct target extraction in `SequencePredictNext`
oliverholworthy Oct 28, 2022
fada3a8
Replace ragged coercion with axis aware tf.squeeze
oliverholworthy Nov 1, 2022
fd09644
Merge branch 'main' into transformer-retrieval-model
oliverholworthy Nov 3, 2022
3e79045
Revert change to predict next
oliverholworthy Nov 4, 2022
7fd5ae4
Remove unused ReplaceMaskedEmbeddings (only required for MLM model)
oliverholworthy Nov 4, 2022
3a41f79
Support tuple return type from model.fit `pre` argument
oliverholworthy Nov 7, 2022
a091fd5
Use predict last and use as pre instead of transform
oliverholworthy Nov 7, 2022
83b87c8
Revert changes to contrastive output
oliverholworthy Nov 7, 2022
9704c3e
Set process_lists default value to False
oliverholworthy Nov 8, 2022
22eb8f6
Add d_model and MLPBlock
oliverholworthy Nov 11, 2022
e233529
Merge branch 'main' into transformer-retrieval-model
oliverholworthy Nov 11, 2022
f9a4857
Revert change to `Encoder.encode`
oliverholworthy Nov 14, 2022
53320b0
Revert change to default value of `process_lists` in `sample_batch`
oliverholworthy Nov 14, 2022
8bb8683
Remove commented query_embeddings line
oliverholworthy Nov 14, 2022
0feea3c
update comment about prediction tuple
oliverholworthy Nov 14, 2022
e13b517
Merge branch 'main' into transformer-retrieval-model
marcromeyn Nov 14, 2022
2 changes: 1 addition & 1 deletion merlin/models/tf/core/encoder.py
@@ -71,7 +71,7 @@ def encode(
         elif isinstance(index, Tags):
             output_schema = self.schema.select_by_tag(index)
         else:
-            raise ValueError(f"Invalid index: {index}")
+            output_schema = None

         return self.batch_predict(
             dataset,
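The hunk above makes the `index` argument to `Encoder.encode` effectively optional: instead of raising on an unrecognized index, the output schema is left unset and all output columns are kept. A minimal standalone sketch of that branching, using a plain dict in place of Merlin's `Schema` and a `frozenset` in place of a `Tags` selector (both stand-ins are assumptions, not Merlin code):

```python
def resolve_output_schema(schema, index):
    """Sketch of the index handling in Encoder.encode after this change.

    `schema` is a dict mapping column names to tag sets, standing in for
    Merlin's Schema; `index` may be a column name, a tag selector, or None.
    """
    if isinstance(index, str):
        # Select a single column by name.
        return {index: schema[index]}
    elif isinstance(index, frozenset):
        # Stand-in for a Tags selector: keep columns carrying all the tags.
        return {name: tags for name, tags in schema.items() if index <= tags}
    else:
        # Previously this branch raised ValueError; now the output schema is
        # simply left unset and batch_predict keeps every output column.
        return None


schema = {
    "item_id": frozenset({"item", "id"}),
    "user_id": frozenset({"user", "id"}),
}
assert resolve_output_schema(schema, "item_id") == {"item_id": frozenset({"item", "id"})}
assert resolve_output_schema(schema, None) is None
```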
4 changes: 3 additions & 1 deletion merlin/models/tf/outputs/contrastive.py
@@ -314,7 +314,9 @@ def sample_negatives(
         return negatives

     def embedding_lookup(self, ids: tf.Tensor):
-        return self.to_call.embedding_lookup(tf.squeeze(ids))
+        if ids.shape.rank == 2:
+            ids = tf.squeeze(ids, axis=1)
+        return self.to_call.embedding_lookup(ids)

     def to_dataset(self, gpu=None) -> merlin.io.Dataset:
         return merlin.io.Dataset(tf_utils.tensor_to_df(self.to_call.embeddings, gpu=gpu))
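The axis-aware squeeze matters when the batch dimension itself has size 1: an axis-free squeeze removes every size-1 dimension, collapsing a batch of one id to a 0-d scalar. NumPy is used below to illustrate the same semantics as `tf.squeeze`; this is a sketch of the failure mode, not Merlin code:

```python
import numpy as np

ids = np.array([[42]])  # shape (1, 1): a batch of one id with a trailing dim

# An axis-free squeeze removes *every* size-1 dimension, turning the batch
# of one into a 0-d scalar and breaking downstream embedding lookups.
assert np.squeeze(ids).shape == ()

# Squeezing only axis=1 preserves the batch dimension.
assert np.squeeze(ids, axis=1).shape == (1,)

# The rank guard in the patch leaves already-1-D inputs untouched.
ids_1d = np.array([1, 2, 3])
if ids_1d.ndim == 2:
    ids_1d = np.squeeze(ids_1d, axis=1)
assert ids_1d.shape == (3,)
```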
2 changes: 1 addition & 1 deletion merlin/models/tf/transforms/sequence.py
@@ -219,7 +219,7 @@ def call(
         self._check_seq_inputs_targets(inputs)

         # Shifts the target column to be the next item of corresponding input column
-        new_target = inputs[self.target_name][:, 1:]
+        new_target = inputs[self.target_name][:, -1:]
sararb (Contributor) — Nov 3, 2022

The transform SequencePredictNext aims to perform a sliding-window prediction of size 1 (we referred to it as Causal Language Modeling in T4Rec, to align with the NLP domain). More specifically, the session-based model will use the hidden representation at position N to predict the target at position N+1. So each row in the batch will be linked to multiple targets (from position 2 up to the last non-padded position).

If you want to predict only the last item of the sequence, you should use SequencePredictLast instead.

oliverholworthy (Member, Author)

Thanks for the explanation. I think I got confused here between predict next and predict last.

         if targets is None:
             targets = dict({self.target_name: new_target})
         elif isinstance(targets, dict):
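The distinction sararb describes can be shown on a toy batch: predict-next pairs every position with the item one step ahead, while predict-last keeps a single target per row. A NumPy sketch under the simplifying assumption of unpadded, equal-length sequences:

```python
import numpy as np

# A batch of two item-id sequences (no padding, for simplicity).
seq = np.array([[10, 11, 12, 13],
                [20, 21, 22, 23]])

# SequencePredictNext (causal LM): the inputs drop the last step and every
# remaining position predicts the item one step ahead, so each row yields
# multiple targets (positions 2..N).
inputs_next = seq[:, :-1]
targets_next = seq[:, 1:]
assert inputs_next.shape == targets_next.shape == (2, 3)

# SequencePredictLast: the whole prefix predicts only the final item,
# which is what the [:, -1:] slice in the diff above would compute.
targets_last = seq[:, -1:]
assert targets_last.shape == (2, 1)
```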
8 changes: 6 additions & 2 deletions merlin/models/tf/utils/tf_utils.py
@@ -114,8 +114,12 @@ def rescore_false_negatives(
     Zeroes the logits of accidental negatives.
     """
     # Removing dimensions of size 1 from the shape of the item ids, if applicable
-    positive_item_ids = tf.cast(tf.squeeze(positive_item_ids), neg_samples_item_ids.dtype)
-    neg_samples_item_ids = tf.squeeze(neg_samples_item_ids)
+    if positive_item_ids.shape.rank == 2:
+        positive_item_ids = tf.squeeze(positive_item_ids, axis=1)
+    positive_item_ids = tf.cast(positive_item_ids, neg_samples_item_ids.dtype)
+
+    if neg_samples_item_ids.shape.rank == 2:
+        neg_samples_item_ids = tf.squeeze(neg_samples_item_ids, axis=1)

     # Reshapes positive and negative ids so that false_negatives_mask matches the scores shape
     false_negatives_mask = tf.equal(
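For context on what `rescore_false_negatives` is doing around the squeeze: a sampled negative that happens to equal the row's positive item is an "accidental negative", and its logit is pushed to a large negative value so it cannot win against the true item. A NumPy sketch of that masking idea (the shapes, example ids, and `MIN_FLOAT` stand-in are assumptions for illustration, not the Merlin implementation):

```python
import numpy as np

positive_item_ids = np.array([3, 7])        # (batch,) true item per row
neg_samples_item_ids = np.array([5, 3, 9])  # (num_negatives,) shared negatives
negative_scores = np.array([[0.2, 0.9, 0.1],
                            [0.4, 0.3, 0.8]])

# Broadcast-compare every row's positive id against every sampled negative id.
false_negatives_mask = positive_item_ids[:, None] == neg_samples_item_ids[None, :]

# Replace accidental hits (the true item drawn as a negative) with a very
# negative score so the softmax effectively ignores them.
MIN_FLOAT = np.finfo(np.float32).min / 100.0
rescored = np.where(false_negatives_mask, MIN_FLOAT, negative_scores)

assert rescored[0, 1] == MIN_FLOAT  # row 0's positive id 3 collides with negative 3
assert rescored[1, 2] == 0.8        # row 1 has no collision; scores unchanged
```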
53 changes: 53 additions & 0 deletions tests/unit/tf/transformers/test_block.py
@@ -25,6 +25,59 @@ def test_import():
assert transformers is not None


@pytest.mark.parametrize("run_eagerly", [True])
def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly):

    seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag(
        Tags.CATEGORICAL
    )

    target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
    predict_next = mm.SequencePredictNext(schema=seq_schema, target=target)
    loader = Loader(sequence_testing_data, batch_size=8, shuffle=False, transform=predict_next)

    query_schema = seq_schema
    output_schema = seq_schema.select_by_name(target)

    query_encoder = mm.Encoder(
        mm.InputBlockV2(
            query_schema,
            embeddings=mm.Embeddings(
                query_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
            ),
        ),
        GPT2Block(d_model=48, n_head=4, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()),
gabrielspmoreira (Member)

The ReplaceMaskedEmbeddings() is only needed when SequenceMaskRandom() is used (Masked Language Modeling). As we are using SequencePredictNext here for Causal Language Modeling, there will be no masked items to have embeddings replaced.

oliverholworthy (Member, Author)

Thanks @gabrielspmoreira, that makes sense. I was copying this from another test that used GPT2Block and didn't make the connection with SequenceMaskRandom.

I guess including it in a case like this where there are no masks doesn't affect the model? (In other words, does it act as a no-op?) Or would it be worth adding something to ReplaceMaskedEmbeddings to raise an exception if no masks are found?
        tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1)),
    )

    model = mm.RetrievalModelV2(
        query=query_encoder,
        output=mm.ContrastiveOutput(output_schema, negative_samplers="in-batch"),
    )

    testing_utils.model_test(
        model,
        loader,
        run_eagerly=run_eagerly,
        reload_model=False,
        metrics={},
    )

    predictions = model.predict(loader)
    assert list(predictions.shape) == [100, 51997]

    query_embeddings = query_encoder.predict(loader)
    assert list(query_embeddings.shape) == [100, 48]

    # query_embeddings = model.query_embeddings(sequence_testing_data, batch_size=10).compute()
    item_embeddings = model.candidate_embeddings().compute().to_numpy()

    assert list(item_embeddings.shape) == [51997, 48]
    predictions_2 = np.dot(query_embeddings, item_embeddings.T)

    np.testing.assert_allclose(predictions, predictions_2, atol=1e-7)

def test_transformer_encoder():
    NUM_ROWS = 100
    SEQ_LENGTH = 10