ModelHubMixin overwrite config if preexistant (#2142)
Wauplin authored Mar 25, 2024 · 1 parent 3252e27 · commit a42c629
Showing 3 changed files with 31 additions and 7 deletions.
src/huggingface_hub/hub_mixin.py (6 additions & 1 deletion)
@@ -262,6 +262,12 @@ def save_pretrained(
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)
 
+        # Remove config.json if it already exists. After `_save_pretrained` we don't want to overwrite
+        # config.json as it might have been saved by the custom `_save_pretrained` already. However, we
+        # do want to overwrite an existing config.json if it was not saved by `_save_pretrained`.
+        config_path = save_directory / CONFIG_NAME
+        config_path.unlink(missing_ok=True)
+
         # save model weights/files (framework-specific)
         self._save_pretrained(save_directory)
 
@@ -271,7 +277,6 @@ def save_pretrained(
         if config is not None:
             if is_dataclass(config):
                 config = asdict(config)  # type: ignore[arg-type]
-            config_path = save_directory / CONFIG_NAME
             if not config_path.exists():
                 config_str = json.dumps(config, sort_keys=True, indent=2)
                 config_path.write_text(config_str)
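The control flow of the patched method is easier to read outside the diff: delete any stale config.json up front, let the framework-specific hook write whatever it wants, then persist the in-memory config only if the hook did not already claim the file. Below is a minimal standalone sketch of that logic; the names `CONFIG_NAME`, `save_directory`, `config`, and `_save_pretrained` mirror the patched method, but the wrapper function itself is illustrative, not the real API:

```python
import json
from dataclasses import asdict, is_dataclass
from pathlib import Path

CONFIG_NAME = "config.json"


def save_pretrained_sketch(model, save_directory, config=None):
    """Illustrative re-implementation of the patched flow (not the real API)."""
    save_directory = Path(save_directory)
    save_directory.mkdir(parents=True, exist_ok=True)

    # 1. Drop any pre-existing (legacy) config so it cannot survive by accident.
    config_path = save_directory / CONFIG_NAME
    config_path.unlink(missing_ok=True)

    # 2. Framework-specific save; a custom hook may create its own config.json here.
    model._save_pretrained(save_directory)

    # 3. Write the in-memory config only if step 2 did not already do so.
    if config is not None:
        if is_dataclass(config):
            config = asdict(config)
        if not config_path.exists():
            config_path.write_text(json.dumps(config, sort_keys=True, indent=2))
```

Unlinking first is what makes the legacy case work: a config.json left over from an earlier save can no longer leak through the `config_path.exists()` check.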
tests/test_hub_mixin.py (21 additions & 2 deletions)
@@ -302,10 +302,29 @@ def test_push_to_hub(self):
         # Delete repo
         self._api.delete_repo(repo_id=repo_id)
 
-    def test_save_pretrained_do_not_overwrite_config(self):
-        """Regression test for https://github.com/huggingface/huggingface_hub/issues/2102."""
+    def test_save_pretrained_do_not_overwrite_new_config(self):
+        """Regression test for https://github.com/huggingface/huggingface_hub/issues/2102.
+
+        If `_save_pretrained` does save a config file, we should not overwrite it.
+        """
         model = DummyModelSavingConfig()
         model.save_pretrained(self.cache_dir)
         # config.json is not overwritten
         with open(self.cache_dir / "config.json") as f:
             assert json.load(f) == {"custom_config": "custom_config"}
+
+    def test_save_pretrained_does_overwrite_legacy_config(self):
+        """Regression test for https://github.com/huggingface/huggingface_hub/issues/2142.
+
+        If a config file already exists in the save directory, it should be overwritten.
+        """
+        # A pre-existing config in the cache dir
+        (self.cache_dir / "config.json").write_text(json.dumps({"something_legacy": 123}))
+
+        # Save model
+        model = DummyModelWithKwargs(a=1, b=2)
+        model.save_pretrained(self.cache_dir)
+
+        # config.json IS overwritten
+        with open(self.cache_dir / "config.json") as f:
+            assert json.load(f) == {"a": 1, "b": 2}
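The dummy classes are defined earlier in the test module and are not shown in this diff. For orientation, here is a plausible minimal sketch of the first one — a hypothetical reconstruction inferred from the assertion above; the real helper may differ:

```python
import json
from pathlib import Path

from huggingface_hub import ModelHubMixin


class DummyModelSavingConfig(ModelHubMixin):
    """Model whose custom `_save_pretrained` writes its own config.json."""

    def _save_pretrained(self, save_directory: Path) -> None:
        # This hand-written file is exactly what save_pretrained() must not clobber.
        (save_directory / "config.json").write_text(
            json.dumps({"custom_config": "custom_config"})
        )
```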
tests/test_hub_mixin_pytorch.py (4 additions & 4 deletions)
@@ -346,10 +346,10 @@ def forward(self, x):
 
         # Linear layers should share weights and biases in memory
         state_dict = reloaded.state_dict()
-        a_weight_ptr = state_dict["a.weight"].storage().data_ptr()
-        b_weight_ptr = state_dict["b.weight"].storage().data_ptr()
-        a_bias_ptr = state_dict["a.bias"].storage().data_ptr()
-        b_bias_ptr = state_dict["b.bias"].storage().data_ptr()
+        a_weight_ptr = state_dict["a.weight"].untyped_storage().data_ptr()
+        b_weight_ptr = state_dict["b.weight"].untyped_storage().data_ptr()
+        a_bias_ptr = state_dict["a.bias"].untyped_storage().data_ptr()
+        b_bias_ptr = state_dict["b.bias"].untyped_storage().data_ptr()
         assert a_weight_ptr == b_weight_ptr
         assert a_bias_ptr == b_bias_ptr
 
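The change itself is mechanical: `Tensor.storage()` is deprecated in recent PyTorch releases in favor of `untyped_storage()`, and both expose the `data_ptr()` the test compares. A self-contained sketch of the same storage-identity check — the tiny `Net` module is illustrative, not the test's actual model:

```python
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = self.a  # tied: "a" and "b" are the same underlying module


net = Net()
state_dict = net.state_dict()
# Tensors backed by the same allocation report the same base data pointer.
assert (
    state_dict["a.weight"].untyped_storage().data_ptr()
    == state_dict["b.weight"].untyped_storage().data_ptr()
)
```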
