Skip to content

Commit

Permalink
Backwards compatible fix for functional model saving (#1378)
Browse files Browse the repository at this point in the history
The functional model attribute path should always take precedence, so
we can have stable checkpoints for functional subclassed models.

See keras-team/keras#18982

Note that this will cause a bunch of failures until we re-upload weights.

We should apply this workaround to Keras 3 and Keras 2 until we release
Keras 3. Then restrict to only Keras 2. Then finally delete entirely
when we drop Keras 2 support.
  • Loading branch information
mattdangerw committed Jan 4, 2024
1 parent 29a0ae5 commit 401e569
Show file tree
Hide file tree
Showing 15 changed files with 83 additions and 98 deletions.
8 changes: 4 additions & 4 deletions keras_nlp/models/albert/albert_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"path": "albert",
"model_card": "https://github.com/google-research/albert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/albert/albert_base_en_uncased/1",
"kaggle_handle": "kaggle://keras/albert/albert_base_en_uncased/2",
},
"albert_large_en_uncased": {
"metadata": {
Expand All @@ -39,7 +39,7 @@
"path": "albert",
"model_card": "https://github.com/google-research/albert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/albert/albert_large_en_uncased/1",
"kaggle_handle": "kaggle://keras/albert/albert_large_en_uncased/2",
},
"albert_extra_large_en_uncased": {
"metadata": {
Expand All @@ -52,7 +52,7 @@
"path": "albert",
"model_card": "https://github.com/google-research/albert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/albert/albert_extra_large_en_uncased/1",
"kaggle_handle": "kaggle://keras/albert/albert_extra_large_en_uncased/2",
},
"albert_extra_extra_large_en_uncased": {
"metadata": {
Expand All @@ -65,6 +65,6 @@
"path": "albert",
"model_card": "https://github.com/google-research/albert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/albert/albert_extra_extra_large_en_uncased/1",
"kaggle_handle": "kaggle://keras/albert/albert_extra_extra_large_en_uncased/2",
},
}
13 changes: 13 additions & 0 deletions keras_nlp/models/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ class Backbone(keras.Model):
def __init__(self, *args, **kwargs):
    """Initialize the backbone and snapshot the ids of its layers.

    All positional and keyword arguments are forwarded unchanged to the
    parent initializer.
    """
    super().__init__(*args, **kwargs)
    self._token_embedding = None
    # Remember the identity of every layer reported by `_flatten_layers()`
    # right after construction, so `__dir__` can later recognize attributes
    # that merely alias one of these layers.
    # NOTE(review): presumably these are the layers of the functional
    # graph built by the parent initializer -- confirm.
    self._functional_layer_ids = {
        id(layer) for layer in self._flatten_layers()
    }

def __dir__(self):
    """List attribute names, hiding aliases of already-tracked layers.

    The functional model attribute path should always take precedence
    when saving, so that checkpoints of functional subclassed models stay
    stable. To get that, any attribute whose value is one of the layers
    recorded in `self._functional_layer_ids` is dropped from the listing,
    as is `_layer_checkpoint_dependencies`.

    Returns:
        A lazy `filter` iterator over the names from `super().__dir__()`
        that survive the filtering.
    """

    # Temporary fix for weight saving. This mimics the following PR for
    # older versions of Keras: https://github.com/keras-team/keras/pull/18982
    def filter_fn(attr):
        if attr == "_layer_checkpoint_dependencies":
            return False
        try:
            value = getattr(self, attr)
        except Exception:
            # Some properties raise when read. An attribute we cannot read
            # cannot be matched against a tracked layer, so keep it listed
            # instead of letting `dir()` itself fail.
            return True
        return id(value) not in self._functional_layer_ids

    return filter(filter_fn, super().__dir__())

def __setattr__(self, name, value):
# Work around torch setattr for properties.
Expand Down
6 changes: 3 additions & 3 deletions keras_nlp/models/bart/bart_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"path": "bart",
"model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.md",
},
"kaggle_handle": "kaggle://keras/bart/bart_base_en/1",
"kaggle_handle": "kaggle://keras/bart/bart_base_en/2",
},
"bart_large_en": {
"metadata": {
Expand All @@ -47,7 +47,7 @@
"dropout": 0.1,
"max_sequence_length": 1024,
},
"kaggle_handle": "kaggle://keras/bart/bart_large_en/1",
"kaggle_handle": "kaggle://keras/bart/bart_large_en/2",
},
"bart_large_en_cnn": {
"metadata": {
Expand All @@ -69,6 +69,6 @@
"dropout": 0.1,
"max_sequence_length": 1024,
},
"kaggle_handle": "kaggle://keras/bart/bart_large_en_cnn/1",
"kaggle_handle": "kaggle://keras/bart/bart_large_en_cnn/2",
},
}
20 changes: 10 additions & 10 deletions keras_nlp/models/bert/bert_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_tiny_en_uncased/1",
"kaggle_handle": "kaggle://keras/bert/bert_tiny_en_uncased/2",
},
"bert_small_en_uncased": {
"metadata": {
Expand All @@ -38,7 +38,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_small_en_uncased/1",
"kaggle_handle": "kaggle://keras/bert/bert_small_en_uncased/2",
},
"bert_medium_en_uncased": {
"metadata": {
Expand All @@ -51,7 +51,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_medium_en_uncased/1",
"kaggle_handle": "kaggle://keras/bert/bert_medium_en_uncased/2",
},
"bert_base_en_uncased": {
"metadata": {
Expand All @@ -64,7 +64,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_base_en_uncased/1",
"kaggle_handle": "kaggle://keras/bert/bert_base_en_uncased/2",
},
"bert_base_en": {
"metadata": {
Expand All @@ -77,7 +77,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_base_en/1",
"kaggle_handle": "kaggle://keras/bert/bert_base_en/2",
},
"bert_base_zh": {
"metadata": {
Expand All @@ -89,7 +89,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_base_zh/1",
"kaggle_handle": "kaggle://keras/bert/bert_base_zh/2",
},
"bert_base_multi": {
"metadata": {
Expand All @@ -101,7 +101,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_base_multi/1",
"kaggle_handle": "kaggle://keras/bert/bert_base_multi/2",
},
"bert_large_en_uncased": {
"metadata": {
Expand All @@ -114,7 +114,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_large_en_uncased/1",
"kaggle_handle": "kaggle://keras/bert/bert_large_en_uncased/2",
},
"bert_large_en": {
"metadata": {
Expand All @@ -127,7 +127,7 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_large_en/1",
"kaggle_handle": "kaggle://keras/bert/bert_large_en/2",
},
}

Expand All @@ -142,6 +142,6 @@
"path": "bert",
"model_card": "https://github.com/google-research/bert/blob/master/README.md",
},
"kaggle_handle": "kaggle://keras/bert/bert_tiny_en_uncased_sst2/1",
"kaggle_handle": "kaggle://keras/bert/bert_tiny_en_uncased_sst2/3",
}
}
10 changes: 5 additions & 5 deletions keras_nlp/models/deberta_v3/deberta_v3_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"path": "deberta_v3",
"model_card": "https://huggingface.co/microsoft/deberta-v3-xsmall",
},
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_extra_small_en/1",
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_extra_small_en/2",
},
"deberta_v3_small_en": {
"metadata": {
Expand All @@ -38,7 +38,7 @@
"path": "deberta_v3",
"model_card": "https://huggingface.co/microsoft/deberta-v3-small",
},
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_small_en/1",
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_small_en/2",
},
"deberta_v3_base_en": {
"metadata": {
Expand All @@ -51,7 +51,7 @@
"path": "deberta_v3",
"model_card": "https://huggingface.co/microsoft/deberta-v3-base",
},
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_base_en/1",
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_base_en/2",
},
"deberta_v3_large_en": {
"metadata": {
Expand All @@ -64,7 +64,7 @@
"path": "deberta_v3",
"model_card": "https://huggingface.co/microsoft/deberta-v3-large",
},
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_large_en/1",
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_large_en/2",
},
"deberta_v3_base_multi": {
"metadata": {
Expand All @@ -77,6 +77,6 @@
"path": "deberta_v3",
"model_card": "https://huggingface.co/microsoft/mdeberta-v3-base",
},
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_base_multi/1",
"kaggle_handle": "kaggle://keras/deberta_v3/deberta_v3_base_multi/2",
},
}
6 changes: 3 additions & 3 deletions keras_nlp/models/distil_bert/distil_bert_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"path": "distil_bert",
"model_card": "https://huggingface.co/distilbert-base-uncased",
},
"kaggle_handle": "kaggle://keras/distil_bert/distil_bert_base_en_uncased/1",
"kaggle_handle": "kaggle://keras/distil_bert/distil_bert_base_en_uncased/2",
},
"distil_bert_base_en": {
"metadata": {
Expand All @@ -40,7 +40,7 @@
"path": "distil_bert",
"model_card": "https://huggingface.co/distilbert-base-cased",
},
"kaggle_handle": "kaggle://keras/distil_bert/distil_bert_base_en/1",
"kaggle_handle": "kaggle://keras/distil_bert/distil_bert_base_en/2",
},
"distil_bert_base_multi": {
"metadata": {
Expand All @@ -52,6 +52,6 @@
"path": "distil_bert",
"model_card": "https://huggingface.co/distilbert-base-multilingual-cased",
},
"kaggle_handle": "kaggle://keras/distil_bert/distil_bert_base_multi/1",
"kaggle_handle": "kaggle://keras/distil_bert/distil_bert_base_multi/2",
},
}
4 changes: 2 additions & 2 deletions keras_nlp/models/f_net/f_net_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"path": "f_net",
"model_card": "https://github.com/google-research/google-research/blob/master/f_net/README.md",
},
"kaggle_handle": "kaggle://keras/f_net/f_net_base_en/1",
"kaggle_handle": "kaggle://keras/f_net/f_net_base_en/2",
},
"f_net_large_en": {
"metadata": {
Expand All @@ -38,6 +38,6 @@
"path": "f_net",
"model_card": "https://github.com/google-research/google-research/blob/master/f_net/README.md",
},
"kaggle_handle": "kaggle://keras/f_net/f_net_large_en/1",
"kaggle_handle": "kaggle://keras/f_net/f_net_large_en/2",
},
}
10 changes: 5 additions & 5 deletions keras_nlp/models/gpt2/gpt2_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"path": "gpt2",
"model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
},
"kaggle_handle": "kaggle://keras/gpt2/gpt2_base_en/1",
"kaggle_handle": "kaggle://keras/gpt2/gpt2_base_en/2",
},
"gpt2_medium_en": {
"metadata": {
Expand All @@ -39,7 +39,7 @@
"path": "gpt2",
"model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
},
"kaggle_handle": "kaggle://keras/gpt2/gpt2_medium_en/1",
"kaggle_handle": "kaggle://keras/gpt2/gpt2_medium_en/2",
},
"gpt2_large_en": {
"metadata": {
Expand All @@ -52,7 +52,7 @@
"path": "gpt2",
"model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
},
"kaggle_handle": "kaggle://keras/gpt2/gpt2_large_en/1",
"kaggle_handle": "kaggle://keras/gpt2/gpt2_large_en/2",
},
"gpt2_extra_large_en": {
"metadata": {
Expand All @@ -65,7 +65,7 @@
"path": "gpt2",
"model_card": "https://github.com/openai/gpt-2/blob/master/model_card.md",
},
"kaggle_handle": "kaggle://keras/gpt2/gpt2_extra_large_en/1",
"kaggle_handle": "kaggle://keras/gpt2/gpt2_extra_large_en/2",
},
"gpt2_base_en_cnn_dailymail": {
"metadata": {
Expand All @@ -77,6 +77,6 @@
"official_name": "GPT-2",
"path": "gpt2",
},
"kaggle_handle": "kaggle://keras/gpt2/gpt2_base_en_cnn_dailymail/1",
"kaggle_handle": "kaggle://keras/gpt2/gpt2_base_en_cnn_dailymail/2",
},
}
8 changes: 4 additions & 4 deletions keras_nlp/models/opt/opt_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"path": "opt",
"model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
},
"kaggle_handle": "kaggle://keras/opt/opt_125m_en/1",
"kaggle_handle": "kaggle://keras/opt/opt_125m_en/2",
},
# We skip the 350m checkpoint because it does not match the structure of
# other checkpoints.
Expand All @@ -41,7 +41,7 @@
"path": "opt",
"model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
},
"kaggle_handle": "kaggle://keras/opt/opt_1.3b_en/1",
"kaggle_handle": "kaggle://keras/opt/opt_1.3b_en/2",
},
"opt_2.7b_en": {
"metadata": {
Expand All @@ -54,7 +54,7 @@
"path": "opt",
"model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
},
"kaggle_handle": "kaggle://keras/opt/opt_2.7b_en/1",
"kaggle_handle": "kaggle://keras/opt/opt_2.7b_en/2",
},
"opt_6.7b_en": {
"metadata": {
Expand All @@ -67,6 +67,6 @@
"path": "opt",
"model_card": "https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/model_card.md",
},
"kaggle_handle": "kaggle://keras/opt/opt_6.7b_en/1",
"kaggle_handle": "kaggle://keras/opt/opt_6.7b_en/2",
},
}
4 changes: 2 additions & 2 deletions keras_nlp/models/roberta/roberta_presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"path": "roberta",
"model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
},
"kaggle_handle": "kaggle://keras/roberta/roberta_base_en/1",
"kaggle_handle": "kaggle://keras/roberta/roberta_base_en/2",
},
"roberta_large_en": {
"metadata": {
Expand All @@ -38,6 +38,6 @@
"path": "roberta",
"model_card": "https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.md",
},
"kaggle_handle": "kaggle://keras/roberta/roberta_large_en/1",
"kaggle_handle": "kaggle://keras/roberta/roberta_large_en/2",
},
}
Loading

0 comments on commit 401e569

Please sign in to comment.