diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py
index 675112fd99..28e347d434 100644
--- a/src/huggingface_hub/hub_mixin.py
+++ b/src/huggingface_hub/hub_mixin.py
@@ -4,7 +4,19 @@
 import warnings
 from dataclasses import asdict, dataclass, is_dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    get_args,
+)
 
 from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
 from .file_download import hf_hub_download
@@ -16,8 +28,10 @@
     SoftTemporaryDirectory,
     is_jsonable,
     is_safetensors_available,
+    is_simple_optional_type,
     is_torch_available,
     logging,
+    unwrap_simple_optional_type,
     validate_hf_hub_args,
 )
 
@@ -336,14 +350,22 @@ def _encode_arg(cls, arg: Any) -> Any:
         """Encode an argument into a JSON serializable format."""
+        if arg is None:
+            return None
         for type_, (encoder, _) in cls._hub_mixin_coders.items():
             if isinstance(arg, type_):
                 return encoder(arg)
         return arg
 
     @classmethod
-    def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> ARGS_T:
+    def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
         """Decode a JSON serializable value into an argument."""
+        # Optional[X]: None is passed through as is, any other value is decoded as X.
+        if is_simple_optional_type(expected_type):
+            if value is None:
+                return None
+            expected_type = unwrap_simple_optional_type(expected_type)
         for type_, (_, decoder) in cls._hub_mixin_coders.items():
-            if issubclass(expected_type, type_):
+            # `issubclass` raises on non-class annotations (e.g. Union[int, str]): skip them.
+            if inspect.isclass(expected_type) and issubclass(expected_type, type_):
                 return decoder(value)
         return value
 
diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py
index b7bd497342..cb90b5e6b1 100644
--- a/src/huggingface_hub/utils/__init__.py
+++ b/src/huggingface_hub/utils/__init__.py
@@ -113,7 +113,7 @@
 from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess
 from ._telemetry import send_telemetry
 from ._token import get_token
-from ._typing import is_jsonable
+from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type
 from ._validators import (
     smoothly_deprecate_use_auth_token,
     validate_hf_hub_args,
diff --git a/src/huggingface_hub/utils/_typing.py b/src/huggingface_hub/utils/_typing.py
index ae502b825b..b28a68ec12 100644
--- a/src/huggingface_hub/utils/_typing.py
+++ b/src/huggingface_hub/utils/_typing.py
@@ -14,7 +14,15 @@
 # limitations under the License.
 """Handle typing imports based on system compatibility."""
 
-from typing import Any, Callable, Literal, TypeVar
+import sys
+from typing import Any, Callable, List, Literal, Type, TypeVar, Union, get_args, get_origin
+
+
+UNION_TYPES: List[Any] = [Union]
+if sys.version_info >= (3, 10):
+    from types import UnionType
+
+    UNION_TYPES += [UnionType]
 
 
 HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
@@ -48,3 +56,29 @@ def is_jsonable(obj: Any) -> bool:
         return False
     except RecursionError:
         return False
+
+
+def is_simple_optional_type(type_: Type) -> bool:
+    """Check if a type is a simple optional, i.e. Optional[Type], Union[Type, None] or Type | None.
+
+    "Simple" means the union has exactly two members, one of which is None. Wider unions such as
+    Optional[Union[int, float]] (flattened by `typing` into three members) are rejected. The non-None
+    member itself is not inspected, so Optional[List[int]] also counts as a simple optional.
+    """
+    if get_origin(type_) not in UNION_TYPES:
+        return False
+    union_args = get_args(type_)
+    return len(union_args) == 2 and type(None) in union_args
+
+
+def unwrap_simple_optional_type(optional_type: Type) -> Type:
+    """Unwrap a simple optional type, i.e. return Type from Optional[Type].
+
+    Raises:
+        ValueError: If `optional_type` is not a simple optional type (see `is_simple_optional_type`).
+    """
+    if is_simple_optional_type(optional_type):
+        for arg in get_args(optional_type):
+            if arg is not type(None):
+                return arg
+    raise ValueError(f"'{optional_type}' is not an optional type")
diff --git a/tests/test_hub_mixin.py b/tests/test_hub_mixin.py
index 63da9adb17..9b711adb78 100644
--- a/tests/test_hub_mixin.py
+++ b/tests/test_hub_mixin.py
@@ -144,11 +144,22 @@ class DummyModelWithCustomTypes(
     },
 ):
     def __init__(
-        self, foo: int, bar: str, custom: CustomType, custom_default: CustomType = CustomType("default"), **kwargs
+        self,
+        foo: int,
+        bar: str,
+        baz: Union[int, str],
+        custom: CustomType,
+        optional_custom_1: Optional[CustomType],
+        optional_custom_2: Optional[CustomType],
+        custom_default: CustomType = CustomType("default"),
+        **kwargs,
     ):
         self.foo = foo
         self.bar = bar
+        self.baz = baz
         self.custom = custom
+        self.optional_custom_1 = optional_custom_1
+        self.optional_custom_2 = optional_custom_2
         self.custom_default = custom_default
 
     @classmethod
@@ -406,21 +417,34 @@ def test_from_pretrained_when_cls_is_a_dataclass(self):
         assert not hasattr(model, "other")
 
     def test_from_cls_with_custom_type(self):
-        model = DummyModelWithCustomTypes(1, bar="bar", custom=CustomType("custom"))
+        model = DummyModelWithCustomTypes(
+            1,
+            bar="bar",
+            baz=1.0,
+            custom=CustomType("custom"),
+            optional_custom_1=CustomType("optional"),
+            optional_custom_2=None,
+        )
         model.save_pretrained(self.cache_dir)
 
         config = json.loads((self.cache_dir / "config.json").read_text())
         assert config == {
             "foo": 1,
             "bar": "bar",
+            "baz": 1.0,
             "custom": {"value": "custom"},
+            "optional_custom_1": {"value": "optional"},
+            "optional_custom_2": None,
             "custom_default": {"value": "default"},
         }
 
         model_reloaded = DummyModelWithCustomTypes.from_pretrained(self.cache_dir)
         assert model_reloaded.foo == 1
         assert model_reloaded.bar == "bar"
+        assert model_reloaded.baz == 1.0
         assert model_reloaded.custom.value == "custom"
+        assert model_reloaded.optional_custom_1 is not None and model_reloaded.optional_custom_1.value == "optional"
+        assert model_reloaded.optional_custom_2 is None
         assert model_reloaded.custom_default.value == "default"
 
     def test_inherited_class(self):
diff --git a/tests/test_utils_typing.py b/tests/test_utils_typing.py
index afc148848c..9e4ebb43e2 100644
--- a/tests/test_utils_typing.py
+++ b/tests/test_utils_typing.py
@@ -1,14 +1,20 @@
 import json
+import sys
+from typing import Optional, Type, Union
 
 import pytest
 
-from huggingface_hub.utils._typing import is_jsonable
+from huggingface_hub.utils._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type
 
 
 class NotSerializableClass:
     pass
 
 
+class CustomType:
+    pass
+
+
 OBJ_WITH_CIRCULAR_REF = {"hello": "world"}
 OBJ_WITH_CIRCULAR_REF["recursive"] = OBJ_WITH_CIRCULAR_REF
 
@@ -47,3 +53,76 @@ def test_is_jsonable_failure(data):
     assert not is_jsonable(data)
     with pytest.raises((TypeError, ValueError)):
         json.dumps(data)
+
+
+@pytest.mark.parametrize(
+    "type_, is_optional",
+    [
+        (Optional[int], True),
+        (Union[None, int], True),
+        (Union[int, None], True),
+        (Optional[CustomType], True),
+        (Union[None, CustomType], True),
+        (Union[CustomType, None], True),
+        (int, False),
+        (None, False),
+        (Union[int, float, None], False),
+        (Union[Union[int, float], None], False),
+        (Optional[Union[int, float]], False),
+    ],
+)
+def test_is_simple_optional_type(type_: Type, is_optional: bool):
+    assert is_simple_optional_type(type_) is is_optional
+
+
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
+@pytest.mark.parametrize(
+    "type_, is_optional",
+    [
+        ("int | None", True),
+        ("None | int", True),
+        ("CustomType | None", True),
+        ("None | CustomType", True),
+        ("int | float", False),
+        ("int | float | None", False),
+        ("(int | float) | None", False),
+        ("Union[int, float] | None", False),
+    ],
+)
+def test_is_simple_optional_type_pipe(type_: str, is_optional: bool):
+    assert is_simple_optional_type(eval(type_)) is is_optional
+
+
+@pytest.mark.parametrize(
+    "optional_type, inner_type",
+    [
+        (Optional[int], int),
+        (Union[int, None], int),
+        (Union[None, int], int),
+        (Optional[CustomType], CustomType),
+        (Union[CustomType, None], CustomType),
+        (Union[None, CustomType], CustomType),
+    ],
+)
+def test_unwrap_simple_optional_type(optional_type: Type, inner_type: Type):
+    assert unwrap_simple_optional_type(optional_type) is inner_type
+
+
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
+@pytest.mark.parametrize(
+    "optional_type, inner_type",
+    [
+        ("None | int", int),
+        ("int | None", int),
+        ("None | CustomType", CustomType),
+        ("CustomType | None", CustomType),
+    ],
+)
+def test_unwrap_simple_optional_type_pipe(optional_type: str, inner_type: Type):
+    assert unwrap_simple_optional_type(eval(optional_type)) is inner_type
+
+
+@pytest.mark.parametrize("non_optional_type", [int, None, CustomType, Union[int, str], Union[int, float, None]])
+def test_unwrap_simple_optional_type_fail(non_optional_type: Type):
+    with pytest.raises(ValueError):
+        unwrap_simple_optional_type(non_optional_type)