Skip to content

Commit 20c56dc

Browse files
committed
Try changing the name of the argument
1 parent cd6d840 commit 20c56dc

File tree

5 files changed

+27
-26
lines changed

5 files changed

+27
-26
lines changed

keras_nlp/src/models/preprocessor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def presets(cls):
8989
def from_preset(
9090
cls,
9191
preset,
92-
load_task=False,
92+
load_task_extras=False,
9393
**kwargs,
9494
):
9595
"""Instantiate a `keras_nlp.models.Preprocessor` from a model preset.
@@ -113,7 +113,7 @@ def from_preset(
113113
Args:
114114
preset: string. A built-in preset identifier, a Kaggle Models
115115
handle, a Hugging Face handle, or a path to a local directory.
116-
load_task: bool. If `True`, load the saved task preprocessing
116+
load_task_extras: bool. If `True`, load the saved task preprocessing
117117
configuration from a `preprocessing.json`. You might use this to
118118
restore the sequence length a model was fine-tuned with.
119119
@@ -142,7 +142,7 @@ def from_preset(
142142
# Detect the correct subclass if we need to.
143143
if cls.backbone_cls != backbone_cls:
144144
cls = find_subclass(preset, cls, backbone_cls)
145-
return loader.load_preprocessor(cls, load_task, **kwargs)
145+
return loader.load_preprocessor(cls, load_task_extras, **kwargs)
146146

147147
def save_to_preset(self, preset_dir):
148148
"""Save preprocessor to a preset directory.

keras_nlp/src/models/preprocessor_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_save_to_preset(self, cls, preset_name):
115115
self.assertEqual(set(tokenizer.file_assets), set(expected_assets))
116116

117117
# Check restore.
118-
restored = cls.from_preset(save_dir, load_task=True)
118+
restored = cls.from_preset(save_dir, load_task_extras=True)
119119
self.assertEqual(preprocessor.get_config(), restored.get_config())
120-
restored = cls.from_preset(save_dir, load_task=False)
120+
restored = cls.from_preset(save_dir, load_task_extras=False)
121121
self.assertNotEqual(preprocessor.get_config(), restored.get_config())

keras_nlp/src/models/task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def from_preset(
146146
cls,
147147
preset,
148148
load_weights=True,
149-
load_task=False,
149+
load_task_extras=False,
150150
**kwargs,
151151
):
152152
"""Instantiate a `keras_nlp.models.Task` from a model preset.
@@ -175,7 +175,7 @@ def from_preset(
175175
load_weights: bool. If `True`, the backbone weights will be loaded
176176
into the model architecture. If `False`, the weights will be
177177
randomly initialized.
178-
load_task: bool. If `True`, load the saved task configuration
178+
load_task_extras: bool. If `True`, load the saved task configuration
179179
from a `task.json` and any task specific weights from
180180
`task.weights`. You might use this to load a classification
181181
head for a model that has been saved with it.
@@ -206,7 +206,7 @@ def from_preset(
206206
# Detect the correct subclass if we need to.
207207
if cls.backbone_cls != backbone_cls:
208208
cls = find_subclass(preset, cls, backbone_cls)
209-
return loader.load_task(cls, load_weights, load_task, **kwargs)
209+
return loader.load_task(cls, load_weights, load_task_extras, **kwargs)
210210

211211
def load_task_weights(self, filepath):
212212
"""Load only the tasks specific weights not in the backbone."""

keras_nlp/src/models/task_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def test_save_to_preset(self):
132132
self.assertEqual(BertTextClassifier, check_config_class(task_config))
133133

134134
# Try loading the model from preset directory.
135-
restored_task = TextClassifier.from_preset(save_dir, load_task=True)
135+
restored_task = TextClassifier.from_preset(save_dir, load_task_extras=True)
136136

137137
# Check the model output.
138138
data = ["the quick brown fox.", "the slow brown fox."]
@@ -142,7 +142,7 @@ def test_save_to_preset(self):
142142

143143
# Load without head weights.
144144
restored_task = TextClassifier.from_preset(
145-
save_dir, load_task=False, num_classes=2
145+
save_dir, load_task_extras=False, num_classes=2
146146
)
147147
data = ["the quick brown fox.", "the slow brown fox."]
148148
# Full output unequal.

keras_nlp/src/utils/preset_utils.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ def load_tokenizer(self, cls, **kwargs):
656656
"""Load a tokenizer layer from the preset."""
657657
raise NotImplementedError
658658

659-
def load_task(self, cls, load_weights, load_task, **kwargs):
659+
def load_task(self, cls, load_weights, load_task_extras, **kwargs):
660660
"""Load a task model from the preset.
661661
662662
By default, we create a task from a backbone and preprocessor with
@@ -672,11 +672,11 @@ def load_task(self, cls, load_weights, load_task, **kwargs):
672672
if "preprocessor" not in kwargs:
673673
kwargs["preprocessor"] = self.load_preprocessor(
674674
cls.preprocessor_cls,
675-
load_task=load_task,
675+
load_task_extras=load_task_extras,
676676
)
677677
return cls(**kwargs)
678678

679-
def load_preprocessor(self, cls, load_task, **kwargs):
679+
def load_preprocessor(self, cls, load_task_extras, **kwargs):
680680
"""Load a prepocessor layer from the preset.
681681
682682
By default, we create a preprocessor from a tokenizer with default
@@ -705,23 +705,24 @@ def load_tokenizer(self, cls, **kwargs):
705705
tokenizer.load_preset_assets(self.preset)
706706
return tokenizer
707707

708-
def load_task(self, cls, load_weights, load_task, **kwargs):
708+
def load_task(self, cls, load_weights, load_task_extras, **kwargs):
709709
# If there is no `task.json` or it's for the wrong class delegate to the
710710
# super class loader.
711-
if not load_task:
712-
return super().load_task(cls, load_weights, load_task, **kwargs)
711+
if not load_task_extras:
712+
return super().load_task(cls, load_weights, load_task_extras, **kwargs)
713713
if not check_file_exists(self.preset, TASK_CONFIG_FILE):
714714
raise ValueError(
715715
"Saved preset has no `task.json`, cannot load the task config "
716-
"from a file. Call `from_preset()` with `load_task=False` to "
717-
"load the task from a backbone with library defaults."
716+
"from a file. Call `from_preset()` with "
717+
"`load_task_extras=False` to load the task from a backbone "
718+
"with library defaults."
718719
)
719720
task_config = load_json(self.preset, TASK_CONFIG_FILE)
720721
if not issubclass(check_config_class(task_config), cls):
721722
raise ValueError(
722723
f"Saved `task.json`does not match calling cls {cls}. Call "
723-
"`from_preset()` with `load_task=False` to load the task from "
724-
"a backbone with library defaults."
724+
"`from_preset()` with `load_task_extras=False` to load the "
725+
"task from a backbone with library defaults."
725726
)
726727
# We found a `task.json` with a complete config for our class.
727728
task = load_serialized_object(task_config, **kwargs)
@@ -736,21 +737,21 @@ def load_task(self, cls, load_weights, load_task, **kwargs):
736737
task.backbone.load_weights(backbone_weights)
737738
return task
738739

739-
def load_preprocessor(self, cls, load_task, **kwargs):
740-
if not load_task:
741-
return super().load_preprocessor(cls, load_task, **kwargs)
740+
def load_preprocessor(self, cls, load_task_extras, **kwargs):
741+
if not load_task_extras:
742+
return super().load_preprocessor(cls, load_task_extras, **kwargs)
742743
if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
743744
raise ValueError(
744745
"Saved preset has no `preprocessor.json`, cannot load the task "
745746
"preprocessing config from a file. Call `from_preset()` with "
746-
"`load_task=False` to load the preprocessor with library "
747-
"defaults."
747+
"`load_task_extras=False` to load the preprocessor with "
748+
"library defaults."
748749
)
749750
preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
750751
if not issubclass(check_config_class(preprocessor_json), cls):
751752
raise ValueError(
752753
f"Saved `preprocessor.json`does not match calling cls {cls}. "
753-
"Call `from_preset()` with `load_task=False` to "
754+
"Call `from_preset()` with `load_task_extras=False` to "
754755
"load the the preprocessor with library defaults."
755756
)
756757
# We found a `preprocessing.json` with a complete config for our class.

0 commit comments

Comments
 (0)