From c013756e90d7f46b4c7dbdba6c01138d914e5070 Mon Sep 17 00:00:00 2001
From: Lucain
Date: Fri, 14 Jun 2024 16:04:41 +0200
Subject: [PATCH] Fix: correctly encode/decode config in ModelHubMixin if custom coders (#2337)

* Fix: correctly encode/decode config in ModelHubMixin if custom coders

* make style

* make quality

* Update tests/test_hub_mixin_pytorch.py

Co-authored-by: Hafedh <70411813+not-lain@users.noreply.github.com>

---------

Co-authored-by: Hafedh <70411813+not-lain@users.noreply.github.com>
---
 docs/source/ko/guides/integrations.md |  6 ++---
 src/huggingface_hub/hub_mixin.py      | 26 +++++++++-------------
 tests/test_hub_mixin_pytorch.py       | 32 +++++++++++++++++++++++++++
 3 files changed, 45 insertions(+), 19 deletions(-)

diff --git a/docs/source/ko/guides/integrations.md b/docs/source/ko/guides/integrations.md
index 729b3179b7..f251785a07 100644
--- a/docs/source/ko/guides/integrations.md
+++ b/docs/source/ko/guides/integrations.md
@@ -365,9 +365,9 @@ from argparse import Namespace
 
 class VoiceCraft(
     nn.Module,
-    PytorchModelHubMixin, # 믹스인을 상속합니다.
-    coders: {
-        Namespace = (
+    PyTorchModelHubMixin, # 믹스인을 상속합니다.
+    coders={
+        Namespace: (
             lambda x: vars(x), # Encoder: `Namespace`를 유효한 JSON 형태로 변환하는 방법은 무엇인가요?
             lambda data: Namespace(**data), # Decoder: 딕셔너리에서 Namespace를 재구성하는 방법은 무엇인가요?
         )
diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py
index b9de32679e..2b9a895602 100644
--- a/src/huggingface_hub/hub_mixin.py
+++ b/src/huggingface_hub/hub_mixin.py
@@ -15,7 +15,6 @@
     Type,
     TypeVar,
     Union,
-    get_args,
 )
 
 from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
@@ -326,12 +325,11 @@ def __new__(cls, *args, **kwargs) -> "ModelHubMixin":
                 if instance._is_jsonable(value)  # Only if jsonable or we have a custom encoder
             },
         }
-        init_config.pop("config", {})
+        passed_config = init_config.pop("config", {})
 
         # Populate `init_config` with provided config
-        provided_config = passed_values.get("config")
-        if isinstance(provided_config, dict):
-            init_config.update(provided_config)
+        if isinstance(passed_config, dict):
+            init_config.update(passed_config)
 
         # Set `config` attribute and return
         if init_config != {}:
@@ -362,9 +360,14 @@ def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T
             if value is None:
                 return None
             expected_type = unwrap_simple_optional_type(expected_type)
+        # Dataclass => handle it
+        if is_dataclass(expected_type):
+            return _load_dataclass(expected_type, value)  # type: ignore[return-value]
+        # Otherwise => check custom decoders
         for type_, (_, decoder) in cls._hub_mixin_coders.items():
             if inspect.isclass(expected_type) and issubclass(expected_type, type_):
                 return decoder(value)
+        # Otherwise => don't decode
         return value
 
     def save_pretrained(
@@ -531,18 +534,9 @@
 
         # Check if `config` argument was passed at init
         if "config" in cls._hub_mixin_init_parameters and "config" not in model_kwargs:
-            # Check if `config` argument is a dataclass
+            # Decode `config` argument if it was passed
             config_annotation = cls._hub_mixin_init_parameters["config"].annotation
-            if config_annotation is inspect.Parameter.empty:
-                pass  # no annotation
-            elif is_dataclass(config_annotation):
-                config = _load_dataclass(config_annotation, config)
-            else:
-                # if Optional/Union annotation => check if a dataclass is in the Union
-                for _sub_annotation in get_args(config_annotation):
-                    if is_dataclass(_sub_annotation):
-                        config = _load_dataclass(_sub_annotation, config)
-                        break
+            config = cls._decode_arg(config_annotation, config)
 
             # Forward config to model initialization
             model_kwargs["config"] = config
diff --git a/tests/test_hub_mixin_pytorch.py b/tests/test_hub_mixin_pytorch.py
index cb4d5b4fd8..09f4a67b47 100644
--- a/tests/test_hub_mixin_pytorch.py
+++ b/tests/test_hub_mixin_pytorch.py
@@ -2,6 +2,7 @@
 import os
 import struct
 import unittest
+from argparse import Namespace
 from pathlib import Path
 from typing import Any, Dict, Optional, TypeVar
 from unittest.mock import Mock, patch
@@ -95,6 +96,21 @@ class DummyModelWithModelCardAndCustomKwargs(
         def __init__(self, linear_layer: int = 4):
             super().__init__()
 
+    class DummyModelWithEncodedConfig(
+        nn.Module,
+        PyTorchModelHubMixin,
+        coders={
+            Namespace: (
+                lambda x: vars(x),
+                lambda data: Namespace(**data),
+            )
+        },
+    ):
+        # Regression test for https://github.com/huggingface/huggingface_hub/issues/2334
+        def __init__(self, config: Namespace):
+            super().__init__()
+            self.config = config
+
 else:
     DummyModel = None
     DummyModelWithModelCard = None
@@ -419,3 +435,19 @@ def test_model_card_with_custom_kwargs(self):
         model.save_pretrained(self.cache_dir, model_card_kwargs=model_card_kwargs)
         card_reloaded = ModelCard.load(self.cache_dir / "README.md")
         assert str(card) == str(card_reloaded)
+
+    def test_config_with_custom_coders(self):
+        """
+        Regression test for #2334. When `config` is encoded with custom coders, it should be decoded correctly.
+
+        See https://github.com/huggingface/huggingface_hub/issues/2334.
+        """
+        model = DummyModelWithEncodedConfig(Namespace(a=1, b=2))
+        model.save_pretrained(self.cache_dir)
+        assert model._hub_mixin_config["a"] == 1
+        assert model._hub_mixin_config["b"] == 2
+
+        reloaded = DummyModelWithEncodedConfig.from_pretrained(self.cache_dir)
+        assert isinstance(reloaded.config, Namespace)
+        assert reloaded.config.a == 1
+        assert reloaded.config.b == 2
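
Note (not part of the patch): for reviewers, a minimal usage sketch of the behavior this PR fixes, assembled from the Korean docs example and the new regression test above. The class name `MyModel`, the `hidden_size` field, and the local path `./my-model` are illustrative placeholders; `PyTorchModelHubMixin`, the `coders` class argument, and `save_pretrained`/`from_pretrained` are the huggingface_hub API exercised by the patch.

    # Usage sketch: a config carried as an argparse.Namespace, serialized via custom coders.
    from argparse import Namespace

    import torch.nn as nn

    from huggingface_hub import PyTorchModelHubMixin


    class MyModel(  # illustrative class name
        nn.Module,
        PyTorchModelHubMixin,
        coders={
            Namespace: (
                lambda x: vars(x),               # encoder: Namespace -> JSON-serializable dict
                lambda data: Namespace(**data),  # decoder: dict -> Namespace
            )
        },
    ):
        def __init__(self, config: Namespace):
            super().__init__()
            self.config = config
            self.linear = nn.Linear(config.hidden_size, config.hidden_size)


    model = MyModel(Namespace(hidden_size=4))         # `hidden_size` is an illustrative field
    model.save_pretrained("./my-model")               # config is encoded into config.json
    reloaded = MyModel.from_pretrained("./my-model")
    assert isinstance(reloaded.config, Namespace)     # decoded back to a Namespace after this fix

Before this change, the reloaded `config` stayed a plain dict because `from_pretrained` only handled dataclass annotations; routing it through `_decode_arg` applies the registered custom decoder as well.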