From 05585741416faf1ca5364da63c450ae56adbfb9b Mon Sep 17 00:00:00 2001 From: Lucain Date: Tue, 19 Mar 2024 14:33:04 +0100 Subject: [PATCH] Fix ModelHubMixin: save config only if doesn't exist (#2105) * Fix ModelHubMixin: save config only if doesn't exist * sutff --- src/huggingface_hub/hub_mixin.py | 7 +++++-- tests/test_hub_mixin.py | 17 +++++++++++++++++ tests/test_hub_mixin_pytorch.py | 2 ++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index 31c7ace27b..41f74b0a01 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -248,13 +248,16 @@ def save_pretrained( # save model weights/files (framework-specific) self._save_pretrained(save_directory) - # save config (if provided) + # save config (if provided and if not serialized yet in `_save_pretrained`) if config is None: config = self.config if config is not None: if is_dataclass(config): config = asdict(config) # type: ignore[arg-type] - (save_directory / CONFIG_NAME).write_text(json.dumps(config, sort_keys=True, indent=2)) + config_path = save_directory / CONFIG_NAME + if not config_path.exists(): + config_str = json.dumps(config, sort_keys=True, indent=2) + config_path.write_text(config_str) # save model card model_card_path = save_directory / "README.md" diff --git a/tests/test_hub_mixin.py b/tests/test_hub_mixin.py index 4ced211876..12eb9ac7d9 100644 --- a/tests/test_hub_mixin.py +++ b/tests/test_hub_mixin.py @@ -76,6 +76,15 @@ def __init__(self, **kwargs): pass +class DummyModelSavingConfig(ModelHubMixin): + def _save_pretrained(self, save_directory: Path) -> None: + """Implementation that uses `config.json` to serialize the config. + + This file must not be overwritten by the default config saved by `ModelHubMixin`. + """ + (save_directory / "config.json").write_text(json.dumps({"custom_config": "custom_config"})) + + @pytest.mark.usefixtures("fx_cache_dir") class HubMixinTest(unittest.TestCase): cache_dir: Path @@ -262,3 +271,11 @@ def test_push_to_hub(self): # Delete repo self._api.delete_repo(repo_id=repo_id) + + def test_save_pretrained_do_not_overwrite_config(self): + """Regression test for https://github.com/huggingface/huggingface_hub/issues/2102.""" + model = DummyModelSavingConfig() + model.save_pretrained(self.cache_dir) + # config.json is not overwritten + with open(self.cache_dir / "config.json") as f: + assert json.load(f) == {"custom_config": "custom_config"} diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py index 0f0a640464..f70074140e 100644 --- a/tests/test_hub_mixin_pytorch.py +++ b/tests/test_hub_mixin_pytorch.py @@ -290,6 +290,7 @@ def test_load_no_config(self): assert reloaded_with_default.state == "other" assert reloaded_with_default.config == {"num_classes": 50, "state": "other"} + config_file.unlink() # Remove config file reloaded_with_default.save_pretrained(self.cache_dir) assert json.loads(config_file.read_text()) == {"num_classes": 50, "state": "other"} @@ -307,6 +308,7 @@ def test_save_with_non_jsonable_config(self): assert "not_jsonable" not in model.config # If jsonable value passed by user, it's saved in the config + (self.cache_dir / "config.json").unlink() new_model = DummyModelNoConfig(not_jsonable=123) new_model.save_pretrained(self.cache_dir) assert new_model.config["not_jsonable"] == 123