Skip to content

Commit 84a6b66

Browse files
authored
Only load a full task config when load_task_extras is passed (keras-team#1812)
* Only load a full task config when `load_task` is passed This switches the way we load task configuration and "head weights" to better accommodate upcoming vision models. For many vision models, like resnet trained on imagenet, or deeplabv3, we have head weights that some users may want but others will not. We need to add an option for loading head weights. With this change, we will be able to do the following... ```python classifier = ImageClassifier.from_preset("resnet50", num_classes=2) classifier = ImageClassifier.from_preset("resnet50", load_task=True) ``` We could do this other ways as well, or flip the default, but I think we need to add an option to control wether to load just the backbone with random weights, or loading the full task. * Try changing the name of the argument * Address review comments
1 parent a806571 commit 84a6b66

File tree

5 files changed

+108
-65
lines changed

5 files changed

+108
-65
lines changed

keras_nlp/src/models/preprocessor.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def presets(cls):
8989
def from_preset(
9090
cls,
9191
preset,
92+
load_task_extras=False,
9293
**kwargs,
9394
):
9495
"""Instantiate a `keras_nlp.models.Preprocessor` from a model preset.
@@ -112,6 +113,9 @@ def from_preset(
112113
Args:
113114
preset: string. A built-in preset identifier, a Kaggle Models
114115
handle, a Hugging Face handle, or a path to a local directory.
116+
load_task_extras: bool. If `True`, load the saved task preprocessing
117+
configuration from a `preprocessing.json`. You might use this to
118+
restore the sequence length a model was fine-tuned with.
115119
116120
Examples:
117121
```python
@@ -138,7 +142,7 @@ def from_preset(
138142
# Detect the correct subclass if we need to.
139143
if cls.backbone_cls != backbone_cls:
140144
cls = find_subclass(preset, cls, backbone_cls)
141-
return loader.load_preprocessor(cls, **kwargs)
145+
return loader.load_preprocessor(cls, load_task_extras, **kwargs)
142146

143147
def save_to_preset(self, preset_dir):
144148
"""Save preprocessor to a preset directory.

keras_nlp/src/models/preprocessor_test.py

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
import pathlib
1516

1617
import pytest
1718
from absl.testing import parameterized
@@ -31,10 +32,11 @@
3132
RobertaTextClassifierPreprocessor,
3233
)
3334
from keras_nlp.src.tests.test_case import TestCase
34-
from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
35+
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
36+
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
37+
SentencePieceTokenizer,
38+
)
3539
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
36-
from keras_nlp.src.utils.preset_utils import check_config_class
37-
from keras_nlp.src.utils.preset_utils import load_json
3840

3941

4042
class TestPreprocessor(TestCase):
@@ -80,45 +82,40 @@ def test_from_preset_errors(self):
8082
# TODO: Add more tests when we added a model that has `preprocessor.json`.
8183

8284
@parameterized.parameters(
83-
(
84-
AlbertTextClassifierPreprocessor,
85-
"albert_base_en_uncased",
86-
"sentencepiece",
87-
),
88-
(RobertaTextClassifierPreprocessor, "roberta_base_en", "bytepair"),
89-
(BertTextClassifierPreprocessor, "bert_tiny_en_uncased", "wordpiece"),
85+
(AlbertTextClassifierPreprocessor, "albert_base_en_uncased"),
86+
(RobertaTextClassifierPreprocessor, "roberta_base_en"),
87+
(BertTextClassifierPreprocessor, "bert_tiny_en_uncased"),
9088
)
9189
@pytest.mark.large
92-
def test_save_to_preset(self, cls, preset_name, tokenizer_type):
90+
def test_save_to_preset(self, cls, preset_name):
9391
save_dir = self.get_temp_dir()
94-
preprocessor = cls.from_preset(preset_name)
92+
preprocessor = cls.from_preset(preset_name, sequence_length=100)
93+
tokenizer = preprocessor.tokenizer
9594
preprocessor.save_to_preset(save_dir)
95+
# Save a backbone so the preset is valid.
96+
backbone = cls.backbone_cls.from_preset(preset_name, load_weights=False)
97+
backbone.save_to_preset(save_dir)
9698

97-
if tokenizer_type == "bytepair":
99+
if isinstance(tokenizer, BytePairTokenizer):
98100
vocab_filename = "vocabulary.json"
99-
expected_assets = [
100-
"vocabulary.json",
101-
"merges.txt",
102-
]
103-
elif tokenizer_type == "sentencepiece":
101+
expected_assets = ["vocabulary.json", "merges.txt"]
102+
elif isinstance(tokenizer, SentencePieceTokenizer):
104103
vocab_filename = "vocabulary.spm"
105104
expected_assets = ["vocabulary.spm"]
106105
else:
107106
vocab_filename = "vocabulary.txt"
108107
expected_assets = ["vocabulary.txt"]
109108

110109
# Check existence of vocab file.
111-
vocab_path = os.path.join(
112-
save_dir, os.path.join(TOKENIZER_ASSET_DIR, vocab_filename)
113-
)
110+
path = pathlib.Path(save_dir)
111+
vocab_path = path / TOKENIZER_ASSET_DIR / vocab_filename
114112
self.assertTrue(os.path.exists(vocab_path))
115113

116114
# Check assets.
117-
self.assertEqual(
118-
set(preprocessor.tokenizer.file_assets),
119-
set(expected_assets),
120-
)
115+
self.assertEqual(set(tokenizer.file_assets), set(expected_assets))
121116

122-
# Check config class.
123-
preprocessor_config = load_json(save_dir, PREPROCESSOR_CONFIG_FILE)
124-
self.assertEqual(cls, check_config_class(preprocessor_config))
117+
# Check restore.
118+
restored = cls.from_preset(save_dir, load_task_extras=True)
119+
self.assertEqual(preprocessor.get_config(), restored.get_config())
120+
restored = cls.from_preset(save_dir, load_task_extras=False)
121+
self.assertNotEqual(preprocessor.get_config(), restored.get_config())

keras_nlp/src/models/task.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def from_preset(
146146
cls,
147147
preset,
148148
load_weights=True,
149+
load_task_extras=False,
149150
**kwargs,
150151
):
151152
"""Instantiate a `keras_nlp.models.Task` from a model preset.
@@ -171,9 +172,13 @@ def from_preset(
171172
Args:
172173
preset: string. A built-in preset identifier, a Kaggle Models
173174
handle, a Hugging Face handle, or a path to a local directory.
174-
load_weights: bool. If `True`, the weights will be loaded into the
175-
model architecture. If `False`, the weights will be randomly
176-
initialized.
175+
load_weights: bool. If `True`, saved weights will be loaded into
176+
the model architecture. If `False`, all weights will be
177+
randomly initialized.
178+
load_task_extras: bool. If `True`, load the saved task configuration
179+
from a `task.json` and any task specific weights from
180+
`task.weights`. You might use this to load a classification
181+
head for a model that has been saved with it.
177182
178183
Examples:
179184
```python
@@ -201,13 +206,14 @@ def from_preset(
201206
# Detect the correct subclass if we need to.
202207
if cls.backbone_cls != backbone_cls:
203208
cls = find_subclass(preset, cls, backbone_cls)
204-
return loader.load_task(cls, load_weights, **kwargs)
209+
return loader.load_task(cls, load_weights, load_task_extras, **kwargs)
205210

206211
def load_task_weights(self, filepath):
207212
"""Load only the tasks specific weights not in the backbone."""
208213
if not str(filepath).endswith(".weights.h5"):
209214
raise ValueError(
210-
"The filename must end in `.weights.h5`. Received: filepath={filepath}"
215+
"The filename must end in `.weights.h5`. "
216+
f"Received: filepath={filepath}"
211217
)
212218
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
213219
keras.saving.load_weights(

keras_nlp/src/models/task_test.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
import pathlib
1617

1718
import keras
1819
import pytest
@@ -109,23 +110,16 @@ def test_summary_without_preprocessor(self):
109110
@pytest.mark.large
110111
def test_save_to_preset(self):
111112
save_dir = self.get_temp_dir()
112-
model = TextClassifier.from_preset(
113-
"bert_tiny_en_uncased", num_classes=2
114-
)
115-
model.save_to_preset(save_dir)
113+
task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
114+
task.save_to_preset(save_dir)
116115

117116
# Check existence of files.
118-
self.assertTrue(os.path.exists(os.path.join(save_dir, CONFIG_FILE)))
119-
self.assertTrue(
120-
os.path.exists(os.path.join(save_dir, MODEL_WEIGHTS_FILE))
121-
)
122-
self.assertTrue(os.path.exists(os.path.join(save_dir, METADATA_FILE)))
123-
self.assertTrue(
124-
os.path.exists(os.path.join(save_dir, TASK_CONFIG_FILE))
125-
)
126-
self.assertTrue(
127-
os.path.exists(os.path.join(save_dir, TASK_WEIGHTS_FILE))
128-
)
117+
path = pathlib.Path(save_dir)
118+
self.assertTrue(os.path.exists(path / CONFIG_FILE))
119+
self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
120+
self.assertTrue(os.path.exists(path / METADATA_FILE))
121+
self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
122+
self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))
129123

130124
# Check the task config (`task.json`).
131125
task_config = load_json(save_dir, TASK_CONFIG_FILE)
@@ -138,13 +132,30 @@ def test_save_to_preset(self):
138132
self.assertEqual(BertTextClassifier, check_config_class(task_config))
139133

140134
# Try loading the model from preset directory.
141-
restored_model = TextClassifier.from_preset(save_dir)
135+
restored_task = TextClassifier.from_preset(
136+
save_dir, load_task_extras=True
137+
)
142138

143139
# Check the model output.
144140
data = ["the quick brown fox.", "the slow brown fox."]
145-
ref_out = model.predict(data)
146-
new_out = restored_model.predict(data)
147-
self.assertAllEqual(ref_out, new_out)
141+
ref_out = task.predict(data)
142+
new_out = restored_task.predict(data)
143+
self.assertAllClose(ref_out, new_out)
144+
145+
# Load without head weights.
146+
restored_task = TextClassifier.from_preset(
147+
save_dir, load_task_extras=False, num_classes=2
148+
)
149+
data = ["the quick brown fox.", "the slow brown fox."]
150+
# Full output unequal.
151+
ref_out = task.predict(data)
152+
new_out = restored_task.predict(data)
153+
self.assertNotAllClose(ref_out, new_out)
154+
# Backbone output equal.
155+
data = task.preprocessor(data)
156+
ref_out = task.backbone.predict(data)
157+
new_out = restored_task.backbone.predict(data)
158+
self.assertAllClose(ref_out, new_out)
148159

149160
@pytest.mark.large
150161
def test_none_preprocessor(self):

keras_nlp/src/utils/preset_utils.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ def load_tokenizer(self, cls, **kwargs):
656656
"""Load a tokenizer layer from the preset."""
657657
raise NotImplementedError
658658

659-
def load_task(self, cls, load_weights, **kwargs):
659+
def load_task(self, cls, load_weights, load_task_extras, **kwargs):
660660
"""Load a task model from the preset.
661661
662662
By default, we create a task from a backbone and preprocessor with
@@ -671,11 +671,12 @@ def load_task(self, cls, load_weights, **kwargs):
671671
)
672672
if "preprocessor" not in kwargs:
673673
kwargs["preprocessor"] = self.load_preprocessor(
674-
cls.preprocessor_cls
674+
cls.preprocessor_cls,
675+
load_task_extras=load_task_extras,
675676
)
676677
return cls(**kwargs)
677678

678-
def load_preprocessor(self, cls, **kwargs):
679+
def load_preprocessor(self, cls, load_task_extras, **kwargs):
679680
"""Load a prepocessor layer from the preset.
680681
681682
By default, we create a preprocessor from a tokenizer with default
@@ -704,35 +705,59 @@ def load_tokenizer(self, cls, **kwargs):
704705
tokenizer.load_preset_assets(self.preset)
705706
return tokenizer
706707

707-
def load_task(self, cls, load_weights, **kwargs):
708+
def load_task(self, cls, load_weights, load_task_extras, **kwargs):
708709
# If there is no `task.json` or it's for the wrong class delegate to the
709710
# super class loader.
711+
if not load_task_extras:
712+
return super().load_task(
713+
cls, load_weights, load_task_extras, **kwargs
714+
)
710715
if not check_file_exists(self.preset, TASK_CONFIG_FILE):
711-
return super().load_task(cls, load_weights, **kwargs)
716+
raise ValueError(
717+
"Saved preset has no `task.json`, cannot load the task config "
718+
"from a file. Call `from_preset()` with "
719+
"`load_task_extras=False` to load the task from a backbone "
720+
"with library defaults."
721+
)
712722
task_config = load_json(self.preset, TASK_CONFIG_FILE)
713723
if not issubclass(check_config_class(task_config), cls):
714-
return super().load_task(cls, load_weights, **kwargs)
724+
raise ValueError(
725+
f"Saved `task.json`does not match calling cls {cls}. Call "
726+
"`from_preset()` with `load_task_extras=False` to load the "
727+
"task from a backbone with library defaults."
728+
)
715729
# We found a `task.json` with a complete config for our class.
716730
task = load_serialized_object(task_config, **kwargs)
717731
if task.preprocessor is not None:
718732
task.preprocessor.tokenizer.load_preset_assets(self.preset)
719733
if load_weights:
720-
jax_memory_cleanup(task)
721734
if check_file_exists(self.preset, TASK_WEIGHTS_FILE):
735+
jax_memory_cleanup(task)
722736
task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
723737
task.load_task_weights(task_weights)
738+
else:
739+
jax_memory_cleanup(task.backbone)
724740
backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
725741
task.backbone.load_weights(backbone_weights)
726742
return task
727743

728-
def load_preprocessor(self, cls, **kwargs):
729-
# If there is no `preprocessing.json` or it's for the wrong class,
730-
# delegate to the super class loader.
744+
def load_preprocessor(self, cls, load_task_extras, **kwargs):
745+
if not load_task_extras:
746+
return super().load_preprocessor(cls, load_task_extras, **kwargs)
731747
if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
732-
return super().load_preprocessor(cls, **kwargs)
748+
raise ValueError(
749+
"Saved preset has no `preprocessor.json`, cannot load the task "
750+
"preprocessing config from a file. Call `from_preset()` with "
751+
"`load_task_extras=False` to load the preprocessor with "
752+
"library defaults."
753+
)
733754
preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
734755
if not issubclass(check_config_class(preprocessor_json), cls):
735-
return super().load_preprocessor(cls, **kwargs)
756+
raise ValueError(
757+
f"Saved `preprocessor.json`does not match calling cls {cls}. "
758+
"Call `from_preset()` with `load_task_extras=False` to "
759+
"load the the preprocessor with library defaults."
760+
)
736761
# We found a `preprocessing.json` with a complete config for our class.
737762
preprocessor = load_serialized_object(preprocessor_json, **kwargs)
738763
preprocessor.tokenizer.load_preset_assets(self.preset)

0 commit comments

Comments
 (0)