Skip to content

Commit

Permalink
Add test for retrieval model with transformer block (#833)
Browse files Browse the repository at this point in the history
* Add test for transformer with RetrievalModelV2

* Update test for transformer retrieval model

* Remove test_retrieval from test_block

* Allow index param to be optional to `Encoder.encode`

* Correct target extraction in `SequencePredictNext`

* Replace ragged coercion with axis aware tf.squeeze

* Revert change to predict next

* Remove unused ReplaceMaskedEmbeddings (only required for MLM model)

* Support tuple return type from model.fit `pre` argument

* Use predict last and use as pre instead of transform

* Revert changes to contrastive output

* Set process_lists default value to False

* Add d_model and MLPBlock

* Revert change to `Encoder.encode`

* Revert change to default value of `process_lists` in `sample_batch`

* Remove commented query_embeddings line

* Update comment about prediction tuple

Co-authored-by: Marc Romeyn <marcromeyn@gmail.com>
  • Loading branch information
2 people authored and radekosmulski committed Nov 16, 2022
1 parent ba38df0 commit 0899049
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
5 changes: 5 additions & 0 deletions merlin/models/tf/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,11 @@ def train_step(self, data):
out = call_layer(self.train_pre, x, targets=y, features=x, training=True)
if isinstance(out, Prediction):
x, y = out.outputs, out.targets
elif isinstance(out, tuple):
assert (
len(out) == 2
), "output of `pre` must be a 2-tuple of x, y or `Prediction` tuple"
x, y = out
else:
x = out

Expand Down
55 changes: 55 additions & 0 deletions tests/unit/tf/transformers/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,61 @@ def test_import():
assert transformers is not None


@pytest.mark.parametrize("run_eagerly", [True])
def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly):
    """Fit a transformer-backed ``RetrievalModelV2`` and cross-check its scores.

    The query tower embeds the categorical sequence features, projects them to
    ``d_model`` with an MLP, runs them through a GPT-2 block and mean-pools over
    axis 1. After fitting with ``SequencePredictLast`` passed as the ``pre``
    argument, the test asserts that ``model.predict`` equals the dot product of
    the query-encoder output with the candidate embeddings.
    """
    d_model = 48
    # Expected output dimensions asserted below.
    # NOTE(review): presumably the number of rows in the fixture and the
    # item-id cardinality, respectively -- confirm against the fixture.
    expected_rows = 100
    expected_items = 51997

    full_schema = sequence_testing_data.schema
    seq_schema = full_schema.select_by_tag(Tags.SEQUENCE).select_by_tag(Tags.CATEGORICAL)
    target = full_schema.select_by_tag(Tags.ITEM_ID).column_names[0]

    predict_last = mm.SequencePredictLast(schema=seq_schema, target=target)
    loader = Loader(sequence_testing_data, batch_size=8, shuffle=False)

    # The query tower consumes every categorical sequence column, while the
    # output side is restricted to the target column.
    query_schema = seq_schema
    output_schema = seq_schema.select_by_name(target)

    # sequence_combiner=None is passed so the embeddings are not combined
    # before they reach the transformer block.
    sequence_embeddings = mm.Embeddings(
        query_schema.select_by_tag(Tags.CATEGORICAL), sequence_combiner=None
    )
    query_encoder = mm.Encoder(
        mm.InputBlockV2(query_schema, embeddings=sequence_embeddings),
        mm.MLPBlock([d_model]),
        GPT2Block(d_model=d_model, n_head=2, n_layer=2),
        # Mean over axis 1 collapses the transformer output to one vector per row.
        tf.keras.layers.Lambda(lambda x: tf.reduce_mean(x, axis=1)),
    )

    contrastive_output = mm.ContrastiveOutput(output_schema, negative_samplers="in-batch")
    model = mm.RetrievalModelV2(query=query_encoder, output=contrastive_output)

    testing_utils.model_test(
        model,
        loader,
        run_eagerly=run_eagerly,
        reload_model=False,
        metrics={},
        fit_kwargs={"pre": predict_last},
    )

    predictions = model.predict(loader)
    assert list(predictions.shape) == [expected_rows, expected_items]

    query_embeddings = query_encoder.predict(loader)
    assert list(query_embeddings.shape) == [expected_rows, d_model]

    item_embeddings = model.candidate_embeddings().compute().to_numpy()
    assert list(item_embeddings.shape) == [expected_items, d_model]

    # Recompute the scores by hand: query vectors dotted with every item vector.
    manual_scores = np.dot(query_embeddings, item_embeddings.T)
    np.testing.assert_allclose(predictions, manual_scores, atol=1e-7)


def test_transformer_encoder():
NUM_ROWS = 100
SEQ_LENGTH = 10
Expand Down

0 comments on commit 0899049

Please sign in to comment.