Skip to content

Commit

Permalink
Fix loading of TwoTowerModel with context query variable (#887)
Browse files Browse the repository at this point in the history
* Use `add_weight` instead of `add_variable` on `ModelContext`

* Remove the now unused method `add_variable` on `ModelContext`

* Update two tower save test to reload model

* Pad queries with zeros when the batch is smaller than the first batch (and truncate them when it is larger) so they match the cached query variable's size

* Use `tf.shape` instead of `.shape` to get dimensions

Co-authored-by: rnyak <ronayak@hotmail.com>
  • Loading branch information
oliverholworthy and rnyak authored Nov 22, 2022
1 parent dc69d42 commit 7a1227d
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
27 changes: 17 additions & 10 deletions merlin/models/tf/blocks/retrieval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,15 +195,12 @@ def __init__(
def build(self, input_shapes):
if isinstance(input_shapes, dict):
query_shape = input_shapes[self.query_name]
self.context.add_variable(
tf.Variable(
initial_value=tf.zeros([1, query_shape[-1]], dtype=tf.float32),
name="query",
trainable=False,
validate_shape=False,
dtype=tf.float32,
shape=tf.TensorShape([None, query_shape[-1]]),
)
self.context.add_weight(
name="query",
shape=query_shape,
dtype=tf.float32,
trainable=False,
initializer=tf.keras.initializers.Zeros(),
)

super().build(input_shapes)
Expand Down Expand Up @@ -241,7 +238,17 @@ def call(
"""
if self.cache_query:
# enabled only during top-k evaluation
self.context["query"].assign(tf.cast(inputs[self.query_name], tf.float32))

query = inputs[self.query_name]
context_query_size = tf.shape(self.context["query"])[0]
# pad with zeros to match shape of initial query variable
padding_size = context_query_size - tf.shape(query)[0]
if padding_size > 0:
query = tf.pad(query, [[0, padding_size], [0, 0]])

query = query[:context_query_size]

self.context["query"].assign(tf.cast(query, tf.float32))

if training or testing:
return inputs
Expand Down
3 changes: 0 additions & 3 deletions merlin/models/tf/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ class ModelContext(Layer):
(This is created automatically in the model and doesn't need to be created manually.)
"""

def add_variable(self, variable):
setattr(self, variable.name, variable)

def add_embedding_table(self, name, embedding_table):
embedding_tables = getattr(self, "embedding_tables", {})
embedding_tables[name] = embedding_table
Expand Down
5 changes: 5 additions & 0 deletions merlin/models/tf/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,12 @@ def call_outputs(
2D Tensors with the one-hot representation of true targets and
the scores for the top-k implicit negatives.
"""

n = tf.shape(outputs.positive_item_ids)[0]

queries = self.context["query"]
queries = queries[:n]

pred_top_scores, top_ids = self(queries, k=self._k)

targets_sorted = tf.cast(
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tf/models/test_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def test_two_tower_model_save(tmpdir, ecommerce_data: Dataset):
embedding_options=mm.EmbeddingOptions(infer_embedding_sizes=True),
)

testing_utils.model_test(model, dataset, reload_model=False)
testing_utils.model_test(model, dataset, reload_model=True)

query_tower = model.retrieval_block.query_block()
query_tower_path = Path(tmpdir) / "query_tower"
Expand Down

0 comments on commit 7a1227d

Please sign in to comment.