Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve FunctionWrapper #130

Merged
merged 3 commits into from
Jun 11, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ pip install triad

## Release History

### 0.9.7

* Make FunctionWrapper compare annotation origins by default

### 0.9.6

* Add `is_like` to Schema to compare similar schemas
Expand Down
27 changes: 18 additions & 9 deletions tests/collections/test_function_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Any, Callable, Dict, Iterable, List, Optional
from __future__ import annotations

from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Optional

import pandas as pd
from pytest import raises

from triad.collections.function_wrapper import (
AnnotatedParam,
FunctionWrapper,
NoneParam,
OtherParam,
function_wrapper,
)
from triad.exceptions import InvalidOperationError
from triad import to_uuid
from triad.collections.function_wrapper import (AnnotatedParam,
FunctionWrapper, NoneParam,
OtherParam, function_wrapper)
from triad.exceptions import InvalidOperationError


class _Dummy:
Expand All @@ -38,6 +37,11 @@ class SeriesParam(AnnotatedParam):
pass


@MockFunctionWrapper.annotated_param(List[List[int]], "l")
class ListParam(AnnotatedParam):
pass


def test_registration():
with raises(InvalidOperationError):

Expand Down Expand Up @@ -113,6 +117,7 @@ def _parse_function(f, params_re, return_re):
_parse_function(f4, "^0x$", "d")
_parse_function(f6, "^d$", "n")
_parse_function(f7, "^yz$", "n")
_parse_function(f8, "^l$", "n")


def f1(a: pd.DataFrame, b: pd.Series) -> None:
Expand Down Expand Up @@ -141,3 +146,7 @@ def f6(a: _Dummy) -> None:

def f7(*args: Any, **kwargs: int):
pass


def f8(a: list[list[int]]) -> None:
pass
76 changes: 59 additions & 17 deletions tests/utils/test_convert.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
from __future__ import annotations

import builtins
import urllib # must keep for testing purpose
import urllib.request # must keep for testing purpose
from datetime import date, datetime, timedelta
from inspect import signature
from typing import Any, Callable, Dict, List, Union, get_type_hints

import numpy as np
import pandas as pd
import tests.utils.convert_examples as ex
from pytest import raises

import tests.utils.convert_examples as ex
from tests.utils.convert_examples import BaseClass, Class2
from tests.utils.convert_examples import SubClass
from tests.utils.convert_examples import SubClass as SubClassSame
from triad.utils.convert import (
_parse_value_and_unit,
as_type,
get_caller_global_local_vars,
get_full_type_path,
str_to_instance,
str_to_object,
str_to_type,
to_bool,
to_datetime,
to_function,
to_instance,
to_size,
to_timedelta,
to_type,
)
from triad.utils.convert import (_parse_value_and_unit, as_type,
compare_annotations,
get_caller_global_local_vars,
get_full_type_path, str_to_instance,
str_to_object, str_to_type, to_bool,
to_datetime, to_function, to_instance,
to_size, to_timedelta, to_type)

_GLOBAL_DUMMY = 1

Expand Down Expand Up @@ -348,6 +344,52 @@ def f4():
f1()


def test_compare_annotations():
def _assert(f, arg_a, arg_b, expected=True, **kwargs):
# get the argument type annoptation of name arg_a in function f
sig = get_type_hints(f)
a = sig.get(arg_a, Any)
b = sig.get(arg_b, Any)
assert compare_annotations(a, b, **kwargs) == expected

def f1(a: int, b: str, c, d: None, e: Any):
pass

_assert(f1, "a", "a")
_assert(f1, "a", "b", False)
_assert(f1, "a", "c", False)
_assert(f1, "c", "c")
_assert(f1, "c", "d", False)
_assert(f1, "c", "e")
_assert(f1, "e", "e")

def f2(a: List, b: Dict, c: Union[int, str], d: Callable):
pass

for o in [True, False]:
kwargs = dict(compare_origin=o)
_assert(f2, "a", "a", **kwargs)
_assert(f2, "a", "b", False, **kwargs)
_assert(f2, "c", "c", **kwargs)
_assert(f2, "c", "d", False, **kwargs)

def f3(a: List[Dict[str, Any]], b: list[dict[str, Any]], c: List):
pass

_assert(f3, "a", "a")
_assert(f3, "a", "b", True)
_assert(f3, "a", "b", False, compare_origin=False)
_assert(f3, "a", "c", False)

def f4(a: Callable[..., Dict[str, Any]], b: Callable[..., dict[str, Any]], c: callable):
pass

_assert(f4, "a", "a")
_assert(f4, "a", "b", True)
_assert(f4, "a", "b", False, compare_origin=False)
_assert(f3, "a", "c", False)


# This is for test_obj_to_function
def dummy_for_test():
pass
Expand Down
4 changes: 2 additions & 2 deletions triad/collections/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from ..exceptions import InvalidOperationError
from ..utils.assertion import assert_or_throw
from ..utils.convert import get_full_type_path
from ..utils.convert import compare_annotations, get_full_type_path
from ..utils.entry_points import load_entry_point
from ..utils.hash import to_uuid
from .dict import IndexedOrderedDict
Expand Down Expand Up @@ -165,7 +165,7 @@ def _func(tp: Type["AnnotatedParam"]) -> Type["AnnotatedParam"]:
anno = annotation

def _m(a: Any) -> bool:
return a == anno
return compare_annotations(a, anno, compare_origin=True)

_matcher = _m

Expand Down
37 changes: 26 additions & 11 deletions triad/utils/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import importlib
import inspect
from types import ModuleType
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, get_args, get_origin

import numpy as np
import pandas as pd
Expand All @@ -16,7 +16,6 @@
_HAS_CISO8601 = False
from triad.utils.assertion import assert_or_throw


EMPTY_ARGS: List[Any] = []
EMPTY_KWARGS: Dict[str, Any] = {}

Expand Down Expand Up @@ -403,6 +402,7 @@ def to_timedelta(obj: Any) -> datetime.timedelta:

:param obj: object
:raises TypeError: if failed to convert

:return: timedelta value
"""
if obj is None:
Expand Down Expand Up @@ -449,15 +449,9 @@ def to_size(exp: Any) -> int:
default unit is byte if not provided. Unit can be `b`, `byte`,
`k`, `kb`, `m`, `mb`, `g`, `gb`, `t`, `tb`.

Args:
exp (Any): expression string or numerical value

Raises:
ValueError: for invalid expression
ValueError: for negative values

Returns:
int: size in byte
:param exp: expression string or numerical value
:raises ValueError: for invalid expression and negative values
:return: size in byte
"""
n, u = _parse_value_and_unit(exp)
assert n >= 0.0, "Size can't be negative"
Expand All @@ -474,6 +468,27 @@ def to_size(exp: Any) -> int:
raise ValueError(f"Invalid size expression {exp}")


def compare_annotations(a: Any, b: Any, compare_origin: bool = True) -> bool:
"""Compare two type annotations

:param a: first type annotation
:param b: second type annotation
:param compare_origin: whether to compare the origin of the type annotation
:return: whether the two type annotations are equal
"""
if compare_origin:
ta = get_origin(a) or a
tb = get_origin(b) or b
if ta != tb:
return False
aa = get_args(a)
ba = get_args(b)
if len(aa) != len(ba):
return False
return all(compare_annotations(x, y, compare_origin) for x, y in zip(aa, ba))
return a == b


def _parse_value_and_unit(exp: Any) -> Tuple[float, str]:
try:
assert exp is not None
Expand Down
2 changes: 1 addition & 1 deletion triad_version/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# flake8: noqa
__version__ = "0.9.6"
__version__ = "0.9.7"
Loading