Skip to content

Commit

Permalink
Fix ModelHubMixin coders (#2291)
Browse files Browse the repository at this point in the history
* update HubMixinTest with union and optional custom type

* enable ModelHubMixin to handle union and optional custom type

* add docstring for _is_optional_type helper function

* Refactor helper functions and add independent unit tests. Refactor decode_arg

* Restrict UnionType check to python3.10 and above. Minor style updates.

* Only branch to pipe tests when version >= python3.10

* Change pipe operator tests to str + eval

---------

Co-authored-by: Lucain <lucain@huggingface.co>
  • Loading branch information
gorold and Wauplin committed Jun 14, 2024
1 parent 30e5192 commit 6f91049
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 8 deletions.
26 changes: 23 additions & 3 deletions src/huggingface_hub/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@
import warnings
from dataclasses import asdict, dataclass, is_dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, get_args
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_args,
)

from .constants import CONFIG_NAME, PYTORCH_WEIGHTS_NAME, SAFETENSORS_SINGLE_FILE
from .file_download import hf_hub_download
Expand All @@ -16,8 +28,10 @@
SoftTemporaryDirectory,
is_jsonable,
is_safetensors_available,
is_simple_optional_type,
is_torch_available,
logging,
unwrap_simple_optional_type,
validate_hf_hub_args,
)

Expand Down Expand Up @@ -336,14 +350,20 @@ def _encode_arg(cls, arg: Any) -> Any:
"""Encode an argument into a JSON serializable format."""
for type_, (encoder, _) in cls._hub_mixin_coders.items():
if isinstance(arg, type_):
if arg is None:
return None
return encoder(arg)
return arg

@classmethod
def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> ARGS_T:
def _decode_arg(cls, expected_type: Type[ARGS_T], value: Any) -> Optional[ARGS_T]:
"""Decode a JSON serializable value into an argument."""
if is_simple_optional_type(expected_type):
if value is None:
return None
expected_type = unwrap_simple_optional_type(expected_type)
for type_, (_, decoder) in cls._hub_mixin_coders.items():
if issubclass(expected_type, type_):
if inspect.isclass(expected_type) and issubclass(expected_type, type_):
return decoder(value)
return value

Expand Down
2 changes: 1 addition & 1 deletion src/huggingface_hub/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
from ._subprocess import capture_output, run_interactive_subprocess, run_subprocess
from ._telemetry import send_telemetry
from ._token import get_token
from ._typing import is_jsonable
from ._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type
from ._validators import (
smoothly_deprecate_use_auth_token,
validate_hf_hub_args,
Expand Down
27 changes: 26 additions & 1 deletion src/huggingface_hub/utils/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@
# limitations under the License.
"""Handle typing imports based on system compatibility."""

from typing import Any, Callable, Literal, TypeVar
import sys
from typing import Any, Callable, List, Literal, Type, TypeVar, Union, get_args, get_origin


# Origins that `typing.get_origin` may report for a union annotation.
# `typing.Union` is always available; `types.UnionType` (the `X | Y` syntax)
# only exists on Python 3.10 and above, hence the version guard.
UNION_TYPES: List[Any] = [Union]
if sys.version_info >= (3, 10):
    from types import UnionType

    UNION_TYPES.append(UnionType)


HTTP_METHOD_T = Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"]
Expand Down Expand Up @@ -48,3 +56,20 @@ def is_jsonable(obj: Any) -> bool:
return False
except RecursionError:
return False


def is_simple_optional_type(type_: Type) -> bool:
    """Tell whether `type_` is a union of exactly two members, one of which is `None`.

    This matches `Optional[Type]`, `Union[Type, None]` and `Type | None`.
    Non-union types and unions with more than two members (for example
    `Union[int, float, None]`) are not considered simple optional types.

    Args:
        type_: The type annotation to inspect.

    Returns:
        `True` if `type_` is a two-member union containing `None`, `False` otherwise.
    """
    # Anything whose origin is not a known union flavour cannot be optional.
    if get_origin(type_) not in UNION_TYPES:
        return False
    members = get_args(type_)
    return len(members) == 2 and type(None) in members


def unwrap_simple_optional_type(optional_type: Type) -> Type:
    """Unwrap a simple optional type, i.e. return `Type` from `Optional[Type]`.

    Args:
        optional_type: A type of the form `Optional[Type]`, `Union[Type, None]`
            or `Type | None`.

    Returns:
        The single member of `optional_type` that is not `None`.

    Raises:
        ValueError: If `optional_type` does not have exactly two type arguments
            with `None` being one of them (e.g. `int`, `None`, `Union[int, str]`
            or `Union[int, float, None]`).
    """
    none_type = type(None)
    members = get_args(optional_type)
    # Check the shape before unwrapping: without it, any parametrised type
    # (e.g. `Union[int, str]`) would silently yield its first member instead
    # of being rejected as "not an optional type".
    if len(members) == 2 and none_type in members:
        for member in members:
            if member is not none_type:
                return member
    raise ValueError(f"'{optional_type}' is not an optional type")
28 changes: 26 additions & 2 deletions tests/test_hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,22 @@ class DummyModelWithCustomTypes(
},
):
def __init__(
self, foo: int, bar: str, custom: CustomType, custom_default: CustomType = CustomType("default"), **kwargs
self,
foo: int,
bar: str,
baz: Union[int, str],
custom: CustomType,
optional_custom_1: Optional[CustomType],
optional_custom_2: Optional[CustomType],
custom_default: CustomType = CustomType("default"),
**kwargs,
):
self.foo = foo
self.bar = bar
self.baz = baz
self.custom = custom
self.optional_custom_1 = optional_custom_1
self.optional_custom_2 = optional_custom_2
self.custom_default = custom_default

@classmethod
Expand Down Expand Up @@ -406,21 +417,34 @@ def test_from_pretrained_when_cls_is_a_dataclass(self):
assert not hasattr(model, "other")

def test_from_cls_with_custom_type(self):
model = DummyModelWithCustomTypes(1, bar="bar", custom=CustomType("custom"))
model = DummyModelWithCustomTypes(
1,
bar="bar",
baz=1.0,
custom=CustomType("custom"),
optional_custom_1=CustomType("optional"),
optional_custom_2=None,
)
model.save_pretrained(self.cache_dir)

config = json.loads((self.cache_dir / "config.json").read_text())
assert config == {
"foo": 1,
"bar": "bar",
"baz": 1.0,
"custom": {"value": "custom"},
"optional_custom_1": {"value": "optional"},
"optional_custom_2": None,
"custom_default": {"value": "default"},
}

model_reloaded = DummyModelWithCustomTypes.from_pretrained(self.cache_dir)
assert model_reloaded.foo == 1
assert model_reloaded.bar == "bar"
assert model_reloaded.baz == 1.0
assert model_reloaded.custom.value == "custom"
assert model_reloaded.optional_custom_1 is not None and model_reloaded.optional_custom_1.value == "optional"
assert model_reloaded.optional_custom_2 is None
assert model_reloaded.custom_default.value == "default"

def test_inherited_class(self):
Expand Down
81 changes: 80 additions & 1 deletion tests/test_utils_typing.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import json
import sys
from typing import Optional, Type, Union

import pytest

from huggingface_hub.utils._typing import is_jsonable
from huggingface_hub.utils._typing import is_jsonable, is_simple_optional_type, unwrap_simple_optional_type


class NotSerializableClass:
    """Empty class; its instances cannot be encoded by the default `json` encoder."""


class CustomType:
    """Empty user-defined class used as the wrapped type in the optional-type test cases."""


# Dictionary that contains itself under the "recursive" key.
OBJ_WITH_CIRCULAR_REF = {"hello": "world"}
OBJ_WITH_CIRCULAR_REF.update(recursive=OBJ_WITH_CIRCULAR_REF)

Expand Down Expand Up @@ -47,3 +53,76 @@ def test_is_jsonable_failure(data):
assert not is_jsonable(data)
with pytest.raises((TypeError, ValueError)):
json.dumps(data)


@pytest.mark.parametrize(
    "type_, is_optional",
    [
        # Exactly one type paired with None, in either order.
        (Optional[int], True),
        (Union[None, int], True),
        (Union[int, None], True),
        (Optional[CustomType], True),
        (Union[None, CustomType], True),
        (Union[CustomType, None], True),
        # Not a union at all.
        (int, False),
        (None, False),
        # More than one non-None member, written flat or nested.
        (Union[int, float, None], False),
        (Union[Union[int, float], None], False),
        (Optional[Union[int, float]], False),
    ],
)
def test_is_simple_optional_type(type_: Type, is_optional: bool):
    """Check `is_simple_optional_type` against `typing.Optional` / `typing.Union` annotations."""
    outcome = is_simple_optional_type(type_)
    assert outcome is is_optional


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
@pytest.mark.parametrize(
    "type_, is_optional",
    [
        # The annotations are kept as strings: `X | Y` on types raises a TypeError
        # before Python 3.10, and parametrize arguments are evaluated at import
        # time even when the test itself is skipped.
        ("int | None", True),
        ("None | int", True),
        ("CustomType | None", True),
        ("None | CustomType", True),
        ("int | float", False),
        ("int | float | None", False),
        ("(int | float) | None", False),
        ("Union[int, float] | None", False),
    ],
)
def test_is_simple_optional_type_pipe(type_: str, is_optional: bool):
    """Check `is_simple_optional_type` against annotations written with the `|` operator."""
    # `eval` only ever sees the literal expressions listed above.
    annotation = eval(type_)
    assert is_simple_optional_type(annotation) is is_optional


@pytest.mark.parametrize(
    "optional_type, inner_type",
    [
        # Builtin inner type, with None in either position.
        (Optional[int], int),
        (Union[int, None], int),
        (Union[None, int], int),
        # User-defined inner type, with None in either position.
        (Optional[CustomType], CustomType),
        (Union[CustomType, None], CustomType),
        (Union[None, CustomType], CustomType),
    ],
)
def test_unwrap_simple_optional_type(optional_type: Type, inner_type: Type):
    """Check that unwrapping an optional type returns its non-None member."""
    unwrapped = unwrap_simple_optional_type(optional_type)
    assert unwrapped is inner_type


@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
@pytest.mark.parametrize(
    "optional_type, inner_type",
    [
        # Kept as strings so that `X | Y` is only evaluated inside the test body,
        # which does not run on Python versions below 3.10.
        ("None | int", int),
        ("int | None", int),
        ("None | CustomType", CustomType),
        ("CustomType | None", CustomType),
    ],
)
def test_unwrap_simple_optional_type_pipe(optional_type: str, inner_type: Type):
    """Check unwrapping of optional types written with the `|` operator."""
    # `eval` only ever sees the literal expressions listed above.
    annotation = eval(optional_type)
    assert unwrap_simple_optional_type(annotation) is inner_type


@pytest.mark.parametrize(
    "non_optional_type",
    [int, None, CustomType],
)
def test_unwrap_simple_optional_type_fail(non_optional_type: Type):
    """Check that unwrapping something that is not an optional type raises `ValueError`."""
    expected_error = pytest.raises(ValueError)
    with expected_error:
        unwrap_simple_optional_type(non_optional_type)

0 comments on commit 6f91049

Please sign in to comment.