ModelHubMixin: overwrite config if pre-existing #2142

Merged 1 commit on Mar 25, 2024
7 changes: 6 additions & 1 deletion src/huggingface_hub/hub_mixin.py
@@ -262,6 +262,12 @@ def save_pretrained(
         save_directory = Path(save_directory)
         save_directory.mkdir(parents=True, exist_ok=True)

+        # Remove config.json if it already exists. After `_save_pretrained` we don't want to overwrite
+        # config.json, as it might already have been saved by the custom `_save_pretrained`. However, we
+        # do want to overwrite an existing config.json if it was not saved by `_save_pretrained`.
+        config_path = save_directory / CONFIG_NAME
+        config_path.unlink(missing_ok=True)
+
         # save model weights/files (framework-specific)
         self._save_pretrained(save_directory)

@@ -271,7 +277,6 @@ def save_pretrained(
         if config is not None:
             if is_dataclass(config):
                 config = asdict(config)  # type: ignore[arg-type]
-            config_path = save_directory / CONFIG_NAME
             if not config_path.exists():
                 config_str = json.dumps(config, sort_keys=True, indent=2)
                 config_path.write_text(config_str)
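Taken together, the two hunks above give `save_pretrained` the following control flow. Below is a minimal standalone sketch, not the library's actual method (simplified: the real method also handles model cards and `push_to_hub`, and `CONFIG_NAME` is imported from huggingface_hub's constants):

    import json
    from dataclasses import asdict, is_dataclass
    from pathlib import Path

    CONFIG_NAME = "config.json"  # mirrors huggingface_hub's constant

    def save_pretrained_sketch(model, save_directory, config=None):
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Delete any stale config.json *before* the framework-specific save, so a
        # config.json found afterwards must have been written by `_save_pretrained`.
        config_path = save_directory / CONFIG_NAME
        config_path.unlink(missing_ok=True)

        model._save_pretrained(save_directory)  # may or may not write config.json

        if config is not None:
            if is_dataclass(config):
                config = asdict(config)
            # Respect a config written by `_save_pretrained`; otherwise write ours.
            if not config_path.exists():
                config_path.write_text(json.dumps(config, sort_keys=True, indent=2))

The unlink-then-check pattern is what lets one `config_path.exists()` test distinguish a fresh file written by `_save_pretrained` from a leftover of a previous save.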
23 changes: 21 additions & 2 deletions tests/test_hub_mixin.py
@@ -302,10 +302,29 @@ def test_push_to_hub(self):
         # Delete repo
         self._api.delete_repo(repo_id=repo_id)

-    def test_save_pretrained_do_not_overwrite_config(self):
-        """Regression test for https://github.com/huggingface/huggingface_hub/issues/2102."""
+    def test_save_pretrained_do_not_overwrite_new_config(self):
+        """Regression test for https://github.com/huggingface/huggingface_hub/issues/2102.
+
+        If `_save_pretrained` does save a config file, we should not overwrite it.
+        """
         model = DummyModelSavingConfig()
         model.save_pretrained(self.cache_dir)
         # config.json is not overwritten
         with open(self.cache_dir / "config.json") as f:
             assert json.load(f) == {"custom_config": "custom_config"}
+
+    def test_save_pretrained_does_overwrite_legacy_config(self):
+        """Regression test for https://github.com/huggingface/huggingface_hub/issues/2142.
+
+        If a config.json already exists before saving, it should be overwritten.
+        """
+        # Something pre-existing in the cache dir
+        (self.cache_dir / "config.json").write_text(json.dumps({"something_legacy": 123}))
+
+        # Save model
+        model = DummyModelWithKwargs(a=1, b=2)
+        model.save_pretrained(self.cache_dir)
+
+        # config.json IS overwritten
+        with open(self.cache_dir / "config.json") as f:
+            assert json.load(f) == {"a": 1, "b": 2}
8 changes: 4 additions & 4 deletions tests/test_hub_mixin_pytorch.py
@@ -346,10 +346,10 @@ def forward(self, x):

         # Linear layers should share weights and biases in memory
         state_dict = reloaded.state_dict()
-        a_weight_ptr = state_dict["a.weight"].storage().data_ptr()
-        b_weight_ptr = state_dict["b.weight"].storage().data_ptr()
-        a_bias_ptr = state_dict["a.bias"].storage().data_ptr()
-        b_bias_ptr = state_dict["b.bias"].storage().data_ptr()
+        a_weight_ptr = state_dict["a.weight"].untyped_storage().data_ptr()
+        b_weight_ptr = state_dict["b.weight"].untyped_storage().data_ptr()
+        a_bias_ptr = state_dict["a.bias"].untyped_storage().data_ptr()
+        b_bias_ptr = state_dict["b.bias"].untyped_storage().data_ptr()
         assert a_weight_ptr == b_weight_ptr
         assert a_bias_ptr == b_bias_ptr

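The pytorch test change is purely mechanical: `Tensor.storage()` is deprecated in recent PyTorch releases in favor of `Tensor.untyped_storage()`, and both expose the `data_ptr()` the test uses to check that tied layers still share memory after reload. A small illustrative sketch of the pointer check, using hypothetical layers rather than the test's model:

    import torch.nn as nn

    a = nn.Linear(4, 4)
    b = nn.Linear(4, 4)
    b.weight = a.weight  # tie the weights: both layers now hold the same Parameter

    # Tied tensors report the same underlying data pointer; untied ones do not.
    assert a.weight.untyped_storage().data_ptr() == b.weight.untyped_storage().data_ptr()
    assert a.bias.untyped_storage().data_ptr() != b.bias.untyped_storage().data_ptr()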