[FEA] cannot properly save and load the TF Retrieval model #498

Closed
rnyak opened this issue Jun 7, 2022 · 4 comments · Fixed by #887
Labels: enhancement (New feature or request), P1

Comments

@rnyak
Contributor

rnyak commented Jun 7, 2022

Bug description

We want to be able to save the entire Two-Tower model and load it back so that we can run model.evaluate() and model.predict() on it. However, loading the saved model fails with the error below. To reproduce it, run the 05-Retrieval-Model.ipynb example, then save and reload the model with the following snippets.

First, save the model after the model.fit() step:
model.save('two_tower')

Then, loading the saved model back raises the following error:

reloaded = tf.keras.models.load_model('two_tower')
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [17], in <cell line: 1>()
----> 1 reloaded = tf.keras.models.load_model('two_tower')

File /usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py:67, in filter_traceback.<locals>.error_handler(*args, **kwargs)
     65 except Exception as e:  # pylint: disable=broad-except
     66   filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67   raise e.with_traceback(filtered_tb) from None
     68 finally:
     69   del filtered_tb

File /usr/local/lib/python3.8/dist-packages/keras/saving/saved_model/load.py:1000, in revive_custom_object(identifier, metadata)
    998   return revived_cls._init_from_metadata(metadata)  # pylint: disable=protected-access
    999 else:
-> 1000   raise ValueError(
   1001       f'Unable to restore custom object of type {identifier}. '
   1002       f'Please make sure that any custom layers are included in the '
   1003       f'`custom_objects` arg when calling `load_model()` and make sure that '
   1004       f'all layers implement `get_config` and `from_config`.')

ValueError: Unable to restore custom object of type _tf_keras_metric. Please make sure that any custom layers are included in the `custom_objects` arg when calling `load_model()` and make sure that all layers implement `get_config` and `from_config`.
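The error describes a name-based contract: the saved model stores an identifier and a config for each custom object, and the loader needs a mapping from that identifier back to a Python class that can be rebuilt from its config. The following is a minimal pure-Python sketch of that lookup. It is a simplified model, not Keras's implementation: `serialize`, `deserialize`, and this cut-down `RecallAt` class are hypothetical, and the real SavedModel loader reports the generic tag `_tf_keras_metric` rather than the class name.

```python
from typing import Any, Dict, Optional, Type


class RecallAt:
    """Hypothetical, cut-down stand-in for a custom metric such as mm.RecallAt."""

    def __init__(self, k: int = 10, name: str = "recall_at"):
        self.k = k
        self.name = name

    def get_config(self) -> Dict[str, Any]:
        # Constructor arguments needed to rebuild an equivalent object.
        return {"k": self.k, "name": self.name}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "RecallAt":
        return cls(**config)


def serialize(obj: Any) -> Dict[str, Any]:
    # Only the class name and its config are saved, never the Python class itself.
    return {"class_name": type(obj).__name__, "config": obj.get_config()}


def deserialize(
    data: Dict[str, Any], custom_objects: Optional[Dict[str, Type[Any]]] = None
) -> Any:
    # The loader can only turn a stored name back into a class via the mapping.
    cls = (custom_objects or {}).get(data["class_name"])
    if cls is None:
        raise ValueError(
            f"Unable to restore custom object of type {data['class_name']}. "
            "Include it in `custom_objects`."
        )
    return cls.from_config(data["config"])


saved = serialize(RecallAt(k=5))
try:
    deserialize(saved)  # no mapping supplied, so this raises ValueError
except ValueError as err:
    print(err)
restored = deserialize(saved, custom_objects={"RecallAt": RecallAt})
print(restored.k)  # 5
```

This is why both halves of the error message matter: without an entry in `custom_objects` the stored name cannot be resolved, and without `get_config`/`from_config` there is nothing to rebuild the object from.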

Expected behavior

Environment details

  • Merlin version:
  • Platform:
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?): TF 2.8.0

Using merlin-tensorflow-training:22.05 docker image with the latest main branches pulled.

rnyak added the bug (Something isn't working) and status/needs-triage labels Jun 7, 2022
viswa-nvidia added this to the Merlin 22.06 milestone Jun 9, 2022
@rnyak
Contributor Author

rnyak commented Jun 13, 2022

This is also required in this RMP NVIDIA-Merlin/Merlin#271

@oliverholworthy
Member

oliverholworthy commented Jun 15, 2022

It looks like we've got some custom objects (custom metrics in this case) that need to be specified when calling load_model.

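import tensorflow as tf
import merlin.models.tf as mm

# Map the custom metric class names to their Merlin Models classes
custom_objects = {
  "RecallAt": mm.RecallAt,
  "NDCGAt": mm.NDCGAt,
}
reloaded = tf.keras.models.load_model('two_tower', custom_objects=custom_objects)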

We could potentially create a load_model helper that adds all the known custom objects automatically.
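A helper along those lines could look roughly like the sketch below. Everything in it is hypothetical: `KNOWN_CUSTOM_OBJECTS`, `register_custom_object`, and this `load_model` wrapper are not existing Merlin Models APIs. It merges a registry of known classes with any caller-supplied `custom_objects` and forwards both to `tf.keras.models.load_model`; the `loader` parameter exists only so the merging logic can be exercised without TensorFlow.

```python
from typing import Any, Callable, Dict, Mapping, Optional

# Hypothetical registry of a library's known custom objects; a real helper
# would pre-populate it with classes such as mm.RecallAt and mm.NDCGAt.
KNOWN_CUSTOM_OBJECTS: Dict[str, Any] = {}


def register_custom_object(obj: Any, name: Optional[str] = None) -> Any:
    """Record a class in the registry under its own name (or `name`)."""
    KNOWN_CUSTOM_OBJECTS[name or obj.__name__] = obj
    return obj


def load_model(
    path: str,
    custom_objects: Optional[Mapping[str, Any]] = None,
    loader: Optional[Callable[..., Any]] = None,
    **kwargs: Any,
) -> Any:
    """Call a Keras-style loader with the known custom objects filled in.

    Entries passed by the caller override registry entries with the same name.
    """
    merged = {**KNOWN_CUSTOM_OBJECTS, **(custom_objects or {})}
    if loader is None:
        # Imported lazily so the merging logic above does not require TensorFlow.
        import tensorflow as tf

        loader = tf.keras.models.load_model
    return loader(path, custom_objects=merged, **kwargs)
```

With the metrics registered (for example by decorating them with `@register_custom_object`), the reload in the issue would become `load_model('two_tower')`. On its own this only supplies the name-to-class mapping; as the next comment reports, passing custom objects was not enough to reload this particular model.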

@oliverholworthy I tested that, but it did not work for me.

viswa-nvidia added P1 and removed P0 labels Jun 27, 2022
rnyak changed the title [BUG] cannot properly save and load the TF Retrieval model [FEA] cannot properly save and load the TF Retrieval model Jul 25, 2022
rnyak added the enhancement (New feature or request) label and removed the bug (Something isn't working) label Jul 25, 2022
@oliverholworthy
Member

Update: saving a TwoTowerModel now works (as of 22.08).

However, loading still fails because of an unbound query variable in ItemRetrievalScorer when the model is reloaded. Running the example from the ticket description should now give a different error, with a message similar to this:

AssertionError: Found 1 Python objects that were not bound to checkpointed values, likely due to changes in the Python program. Showing 1 of 1 unmatched objects: [<tf.Variable 'query:0' shape=(None, 4) dtype=float32, numpy=array([[0., 0., 0., 0.]], dtype=float32)>]

Ongoing work in #633 is moving us toward replacing this retrieval scorer with a new implementation that won't have this issue. We're aiming for the next release, 22.09.

@sararb
Contributor

sararb commented Oct 7, 2022

PR #790 introduces TwoTowerModelV2, which can be saved and loaded correctly.
