Skip to content

Commit

Permalink
Fix: correctly encode/decode config in ModelHubMixin if custom coders (
Browse files Browse the repository at this point in the history
…#2337)

* Fix: correctly encode/decode config in ModelHubMixin if custom coders

* make style

* make quality

* Update tests/test_hub_mixin_pytorch.py

Co-authored-by: Hafedh <70411813+not-lain@users.noreply.github.com>

---------

Co-authored-by: Hafedh <70411813+not-lain@users.noreply.github.com>
  • Loading branch information
Wauplin and not-lain authored Jun 14, 2024
1 parent a626cc3 commit c013756
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 19 deletions.
6 changes: 3 additions & 3 deletions docs/source/ko/guides/integrations.md
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,9 @@ from argparse import Namespace

class VoiceCraft(
nn.Module,
PytorchModelHubMixin, # 믹스인을 상속합니다.
coders: {
Namespace = (
PyTorchModelHubMixin, # 믹스인을 상속합니다.
coders={
Namespace: (
lambda x: vars(x), # Encoder: `Namespace`를 유효한 JSON 형태로 변환하는 방법은 무엇인가요?
lambda data: Namespace(**data), # Decoder: 딕셔너리에서 Namespace를 재구성하는 방법은 무엇인가요?
)
Expand Down
26 changes: 10 additions & 16 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
Type,
TypeVar,
Union,
get_args,
)

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
Expand Down Expand Up @@ -326,12 +325,11 @@ def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
if instance._is_jsonable(value) # Only if jsonable or we have a custom encoder
},
}
init_config.pop("config", {})
passed_config = init_config.pop("config", {})

# Populate `init_config` with provided config
provided_config = passed_values.get("config")
if isinstance(provided_config, dict):
init_config.update(provided_config)
if isinstance(passed_config, dict):
init_config.update(passed_config)

# Set `config` attribute and return
if init_config != {}:
Expand Down Expand Up @@ -362,9 +360,14 @@ def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T
if value is None:
return None
expected_type = unwrap_simple_optional_type(expected_type)
# Dataclass => handle it
if is_dataclass(expected_type):
return _load_dataclass(expected_type, value) # type: ignore[return-value]
# Otherwise => check custom decoders
for type_, (_, decoder) in cls._hub_mixin_coders.items():
if inspect.isclass(expected_type) and issubclass(expected_type, type_):
return decoder(value)
# Otherwise => don't decode
return value

def save_pretrained(
Expand Down Expand Up @@ -531,18 +534,9 @@ def from_pretrained(

# Check if `config` argument was passed at init
if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
# Check if `config` argument is a dataclass
# Decode `config` argument if it was passed
config_annotation = cls._hub_mixin_init_parameters["config"].annotation
if config_annotation is inspect.Parameter.empty:
pass # no annotation
elif is_dataclass(config_annotation):
config = _load_dataclass(config_annotation, config)
else:
# if Optional/Union annotation => check if a dataclass is in the Union
for _sub_annotation in get_args(config_annotation):
if is_dataclass(_sub_annotation):
config = _load_dataclass(_sub_annotation, config)
break
config = cls._decode_arg(config_annotation, config)

# Forward config to model initialization
model_kwargs["config"] = config
Expand Down
32 changes: 32 additions & 0 deletions tests/test_hub_mixin_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import struct
import unittest
from argparse import Namespace
from pathlib import Path
from typing import Any, Dict, Optional, TypeVar
from unittest.mock import Mock, patch
Expand Down Expand Up @@ -95,6 +96,21 @@ class DummyModelWithModelCardAndCustomKwargs(
def __init__(self, linear_layer: int = 4):
super().__init__()

class DummyModelWithEncodedConfig(
nn.Module,
PyTorchModelHubMixin,
coders={
Namespace: (
lambda x: vars(x),
lambda data: Namespace(**data),
)
},
):
# Regression test for https://github.com/huggingface/huggingface_hub/issues/2334
def __init__(self, config: Namespace):
super().__init__()
self.config = config

else:
DummyModel = None
DummyModelWithModelCard = None
Expand Down Expand Up @@ -419,3 +435,19 @@ def test_model_card_with_custom_kwargs(self):
model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs)
card_reloaded = ModelCard.load(self.cache_dir / "README.md")
assert str(card) == str(card_reloaded)

def test_config_with_custom_coders(self):
"""
Regression test for #2334. When `config` is encoded with custom coders, it should be decoded correctly.
See https://github.com/huggingface/huggingface_hub/issues/2334.
"""
model = DummyModelWithEncodedConfig(Namespace(a=1, b=2))
model.save_pretrained(self.cache_dir)
assert model._hub_mixin_config["a"] == 1
assert model._hub_mixin_config["b"] == 2

reloaded = DummyModelWithEncodedConfig.from_pretrained(self.cache_dir)
assert isinstance(reloaded.config, Namespace)
assert reloaded.config.a == 1
assert reloaded.config.b == 2

0 comments on commit c013756

Please sign in to comment.