
SentencePieceTokenizer inside a keras.models.Model fails to be reconstructed during keras.saving.load_model() #1522

Closed
briango28 opened this issue Mar 25, 2024 · 2 comments
Labels
type:Bug Something isn't working

Comments

@briango28 (Contributor)

Describe the bug

When a SentencePieceTokenizer is integrated into a model using the functional API, a.k.a. keras.models.Model, it cannot be properly reconstructed from a saved model.keras file.

While untested, I would expect any other custom Keras object that relies on load_assets() before it can compute an output spec for a given input tensor to exhibit the same behavior.

To Reproduce

https://colab.research.google.com/drive/1XMNYLQrJo25_BkIv8GT02bMJZMjw5RoC?usp=sharing
Refer to cell no. 6

Expected behavior

Proper reconstruction of SentencePieceTokenizer.

Additional context

When keras.saving.load_model() is called on a saved Functional model, the model graph is reconstructed by running a KerasTensor through the model. Because this happens before the vocabulary is loaded via SentencePieceTokenizer.load_assets(), an error is raised as soon as reconstruction reaches the tokenizer.

The above functionality can be found in keras.saving.saving_lib: _load_state(), which is responsible for calling load_assets(), runs on L178, after deserialize_keras_object() on L155.

Would you like to help us fix it?
Defining SentencePieceTokenizer.compute_output_spec() seems to be sufficient to construct the model graph, allowing the loading function to continue to _load_state().

Cell no. 3 in the colab notebook is a working example.

@briango28 (Contributor, Author) commented Mar 25, 2024

After a quick skim of the repository: BytePairTokenizer, WordPieceTokenizer, and SentencePieceTokenizer all have their vocabularies saved and loaded via save_assets()/load_assets(), and so all appear to be affected by this issue.

The aforementioned compute_output_spec() method (copied below) should work for each of them.

def compute_output_spec(self, input_spec) -> keras.KerasTensor:
    # Append the sequence dimension to the input spec's shape; the output
    # is dense when sequence_length is set, and sparse otherwise.
    return keras.KerasTensor(
        input_spec.shape + (self.sequence_length,),
        dtype=self.compute_dtype,
        sparse=not self.sequence_length,
    )

-> #1523

@mattdangerw (Member)

@briango28 can this be marked as fixed?
