Add test for retrieval model with transformer block #833
Changes from 7 commits
```python
@@ -25,6 +25,59 @@ def test_import():
    assert transformers is not None


@pytest.mark.parametrize("run_eagerly", [True])
def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly):

    seq_schema = sequence_testing_data.schema.select_by_tag(Tags.SEQUENCE).select_by_tag(
        Tags.CATEGORICAL
    )

    target = sequence_testing_data.schema.select_by_tag(Tags.ITEM_ID).column_names[0]
    predict_next = mm.SequencePredictNext(schema=seq_schema, target=target)
    loader = Loader(sequence_testing_data, batch_size=8, shuffle=False, transform=predict_next)

    query_schema = seq_schema
    output_schema = seq_schema.select_by_name(target)

    query_encoder = mm.Encoder(
        mm.InputBlockV2(
            query_schema,
            embeddings=mm.Embeddings(
                query_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
            ),
        ),
        GPT2Block(d_model=48, n_head=4, n_layer=2, pre=mm.ReplaceMaskedEmbeddings()),
```
Review comment on the `GPT2Block(..., pre=mm.ReplaceMaskedEmbeddings())` line, replying to @gabrielspmoreira:

> Thanks @gabrielspmoreira, that makes sense. I was copying this from another test that used GPT2Block and didn't make the connection with the masking. I guess including it in a case like this, where there are no masks, doesn't affect the model? (In other words, does it act as a no-op?) Or would it be worth adding something to …
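For context on the question above: a layer like `ReplaceMaskedEmbeddings` conceptually swaps the embeddings at masked positions for a trainable vector. The following is a hypothetical sketch of that idea (not the Merlin implementation, and `ReplaceMaskedEmbeddingsSketch` is an invented name) showing why, with no mask present, such a layer would behave as a no-op:

```python
import tensorflow as tf


class ReplaceMaskedEmbeddingsSketch(tf.keras.layers.Layer):
    """Hypothetical sketch, not the Merlin implementation: swap the
    embeddings at masked positions for one trainable vector."""

    def build(self, input_shape):
        # A single learned d_model-sized vector used at every masked position.
        self.masked_embedding = self.add_weight(
            name="masked_embedding", shape=(input_shape[-1],)
        )

    def call(self, inputs, mask=None):
        if mask is None:
            return inputs  # no mask supplied: the layer acts as a no-op
        # mask == True marks positions whose embeddings get replaced.
        replace = tf.cast(mask[..., tf.newaxis], inputs.dtype)
        return inputs * (1.0 - replace) + self.masked_embedding * replace


# With no mask, the output equals the input exactly.
x = tf.random.uniform((2, 4, 8))
layer = ReplaceMaskedEmbeddingsSketch()
assert bool(tf.reduce_all(layer(x) == x))
```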
The diff continues:

```python
        tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1)),
    )

    model = mm.RetrievalModelV2(
        query=query_encoder,
        output=mm.ContrastiveOutput(output_schema, negative_samplers="in-batch"),
    )

    testing_utils.model_test(
        model,
        loader,
        run_eagerly=run_eagerly,
        reload_model=False,
        metrics={},
    )

    predictions = model.predict(loader)
    assert list(predictions.shape) == [100, 51997]

    query_embeddings = query_encoder.predict(loader)
    assert list(query_embeddings.shape) == [100, 48]

    # query_embeddings = model.query_embeddings(sequence_testing_data, batch_size=10).compute()
    item_embeddings = model.candidate_embeddings().compute().to_numpy()

    assert list(item_embeddings.shape) == [51997, 48]
    predictions_2 = np.dot(query_embeddings, item_embeddings.T)

    np.testing.assert_allclose(predictions, predictions_2, atol=1e-7)
```
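A note on the final assertion: it verifies that the scores returned by `model.predict` are just dot products between the query and candidate embeddings. A minimal standalone illustration of that shape arithmetic with random data (plain numpy, shapes taken from the asserts above, not Merlin code):

```python
import numpy as np

rng = np.random.default_rng(0)
query_embeddings = rng.normal(size=(100, 48))   # 100 sessions, d_model=48
item_embeddings = rng.normal(size=(51997, 48))  # one row per candidate item

# One retrieval score per (query, candidate) pair.
scores = np.dot(query_embeddings, item_embeddings.T)
assert scores.shape == (100, 51997)
```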
The remaining lines of the hunk begin a second test:

```python
def test_transformer_encoder():
    NUM_ROWS = 100
    SEQ_LENGTH = 10
```
Review comment on the `mm.SequencePredictNext(...)` line:

> The transform `SequencePredictNext` aims to perform a sliding-window prediction of size 1 (we referred to it as Causal Language Modeling in T4Rec to align with the NLP domain). More specifically, the session-based model will use the hidden representation at position N to predict the target at position N+1, so each row in the batch will be linked to multiple targets (from position 2 up to the last non-padded position). If you want to predict only the last item of the sequence, you should use `SequencePredictLast` instead.

Reply:

> Thanks for the explanation. I think I got confused here between predict next and predict last.
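To make the distinction concrete, here is a toy illustration of the two target layouts described above; this is plain Python for exposition, not the actual Merlin transforms:

```python
# One session's item ids, in order.
session = [10, 11, 12, 13]

# SequencePredictNext-style: sliding window of size 1 — every position
# from 2..N is a target for the hidden state one step earlier.
predict_next_inputs = session[:-1]   # [10, 11, 12]
predict_next_targets = session[1:]   # [11, 12, 13]

# SequencePredictLast-style: only the final item is a target.
predict_last_inputs = session[:-1]   # [10, 11, 12]
predict_last_target = session[-1]    # 13
```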