diff --git a/README.md b/README.md index c0f316e..e056de6 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,10 @@ pip install triad ## Release History +### 0.9.7 + +* Make FunctionWrapper compare annotation origins by default + ### 0.9.6 * Add `is_like` to Schema to compare similar schemas diff --git a/setup.cfg b/setup.cfg index dfb1af0..0a95f46 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ addopts = -vvv [flake8] -ignore = E24,E203,W503 +ignore = A005,E24,E203,W503 max-line-length = 88 format = pylint exclude = .svc,CVS,.bzr,.hg,.git,__pycache__,venv,tests/*,docs/* diff --git a/tests/collections/test_function_wrapper.py b/tests/collections/test_function_wrapper.py index 820d6c2..ae974b8 100644 --- a/tests/collections/test_function_wrapper.py +++ b/tests/collections/test_function_wrapper.py @@ -1,17 +1,17 @@ -from typing import Any, Callable, Dict, Iterable, List, Optional +from __future__ import annotations + from copy import deepcopy +from typing import Any, Callable, Dict, Iterable, List, Optional + import pandas as pd from pytest import raises +import sys -from triad.collections.function_wrapper import ( - AnnotatedParam, - FunctionWrapper, - NoneParam, - OtherParam, - function_wrapper, -) -from triad.exceptions import InvalidOperationError from triad import to_uuid +from triad.collections.function_wrapper import (AnnotatedParam, + FunctionWrapper, NoneParam, + OtherParam, function_wrapper) +from triad.exceptions import InvalidOperationError class _Dummy: @@ -38,6 +38,11 @@ class SeriesParam(AnnotatedParam): pass +@MockFunctionWrapper.annotated_param(List[List[int]], "l") +class ListParam(AnnotatedParam): + pass + + def test_registration(): with raises(InvalidOperationError): @@ -113,13 +118,15 @@ def _parse_function(f, params_re, return_re): _parse_function(f4, "^0x$", "d") _parse_function(f6, "^d$", "n") _parse_function(f7, "^yz$", "n") + if sys.version_info >= (3, 9): + _parse_function(f8, "^l$", "n") def f1(a: pd.DataFrame, b: pd.Series) -> None: pass -def f2(e: int, a, b: int, c): +def f2(e: "int", a, b: int, c): return e + a + b - c @@ -141,3 +148,7 @@ def f6(a: _Dummy) -> None: def f7(*args: Any, **kwargs: int): pass + + +def f8(a: list[list[int]]) -> None: + pass diff --git a/tests/utils/test_convert.py b/tests/utils/test_convert.py index 59a9a0d..459a278 100644 --- a/tests/utils/test_convert.py +++ b/tests/utils/test_convert.py @@ -1,31 +1,28 @@ +from __future__ import annotations + import builtins import urllib # must keep for testing purpose import urllib.request # must keep for testing purpose from datetime import date, datetime, timedelta +from typing import Any, Callable, Dict, List, Union, get_type_hints +import pytest +import sys import numpy as np import pandas as pd -import tests.utils.convert_examples as ex from pytest import raises + +import tests.utils.convert_examples as ex from tests.utils.convert_examples import BaseClass, Class2 from tests.utils.convert_examples import SubClass from tests.utils.convert_examples import SubClass as SubClassSame -from triad.utils.convert import ( - _parse_value_and_unit, - as_type, - get_caller_global_local_vars, - get_full_type_path, - str_to_instance, - str_to_object, - str_to_type, - to_bool, - to_datetime, - to_function, - to_instance, - to_size, - to_timedelta, - to_type, -) +from triad.utils.convert import (_parse_value_and_unit, as_type, + compare_annotations, + get_caller_global_local_vars, + get_full_type_path, str_to_instance, + str_to_object, str_to_type, to_bool, + to_datetime, to_function, to_instance, + to_size, to_timedelta, to_type) _GLOBAL_DUMMY = 1 @@ -348,6 +345,53 @@ def f4(): f1() +@pytest.mark.skipif(sys.version_info < (3, 9), reason="python<3.9") +def test_compare_annotations(): + def _assert(f, arg_a, arg_b, expected=True, **kwargs): + # get the argument type annoptation of name arg_a in function f + sig = get_type_hints(f) + a = sig.get(arg_a, Any) + b = sig.get(arg_b, Any) + assert compare_annotations(a, b, **kwargs) == expected + + def f1(a: int, b: str, c, d: None, e: Any): + pass + + _assert(f1, "a", "a") + _assert(f1, "a", "b", False) + _assert(f1, "a", "c", False) + _assert(f1, "c", "c") + _assert(f1, "c", "d", False) + _assert(f1, "c", "e") + _assert(f1, "e", "e") + + def f2(a: List, b: Dict, c: Union[int, str], d: Callable): + pass + + for o in [True, False]: + kwargs = dict(compare_origin=o) + _assert(f2, "a", "a", **kwargs) + _assert(f2, "a", "b", False, **kwargs) + _assert(f2, "c", "c", **kwargs) + _assert(f2, "c", "d", False, **kwargs) + + def f3(a: List[Dict[str, Any]], b: list[dict[str, Any]], c: List): + pass + + _assert(f3, "a", "a") + _assert(f3, "a", "b", True) + _assert(f3, "a", "b", False, compare_origin=False) + _assert(f3, "a", "c", False) + + def f4(a: Callable[..., Dict[str, Any]], b: Callable[..., dict[str, Any]], c: callable): + pass + + _assert(f4, "a", "a") + _assert(f4, "a", "b", True) + _assert(f4, "a", "b", False, compare_origin=False) + _assert(f3, "a", "c", False) + + # This is for test_obj_to_function def dummy_for_test(): pass diff --git a/triad/collections/function_wrapper.py b/triad/collections/function_wrapper.py index bfb41ee..27e9570 100644 --- a/triad/collections/function_wrapper.py +++ b/triad/collections/function_wrapper.py @@ -15,7 +15,7 @@ from ..exceptions import InvalidOperationError from ..utils.assertion import assert_or_throw -from ..utils.convert import get_full_type_path +from ..utils.convert import compare_annotations, get_full_type_path from ..utils.entry_points import load_entry_point from ..utils.hash import to_uuid from .dict import IndexedOrderedDict @@ -165,7 +165,7 @@ def _func(tp: Type["AnnotatedParam"]) -> Type["AnnotatedParam"]: anno = annotation def _m(a: Any) -> bool: - return a == anno + return compare_annotations(a, anno, compare_origin=True) _matcher = _m diff --git a/triad/utils/convert.py b/triad/utils/convert.py index dcc2b24..1ff5efe 100644 --- a/triad/utils/convert.py +++ b/triad/utils/convert.py @@ -2,7 +2,7 @@ import importlib import inspect from types import ModuleType -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, get_args, get_origin import numpy as np import pandas as pd @@ -16,7 +16,6 @@ _HAS_CISO8601 = False from triad.utils.assertion import assert_or_throw - EMPTY_ARGS: List[Any] = [] EMPTY_KWARGS: Dict[str, Any] = {} @@ -403,6 +402,7 @@ def to_timedelta(obj: Any) -> datetime.timedelta: :param obj: object :raises TypeError: if failed to convert + :return: timedelta value """ if obj is None: @@ -449,15 +449,9 @@ def to_size(exp: Any) -> int: default unit is byte if not provided. Unit can be `b`, `byte`, `k`, `kb`, `m`, `mb`, `g`, `gb`, `t`, `tb`. - Args: - exp (Any): expression string or numerical value - - Raises: - ValueError: for invalid expression - ValueError: for negative values - - Returns: - int: size in byte + :param exp: expression string or numerical value + :raises ValueError: for invalid expression and negative values + :return: size in byte """ n, u = _parse_value_and_unit(exp) assert n >= 0.0, "Size can't be negative" @@ -474,6 +468,27 @@ def to_size(exp: Any) -> int: raise ValueError(f"Invalid size expression {exp}") +def compare_annotations(a: Any, b: Any, compare_origin: bool = True) -> bool: + """Compare two type annotations + + :param a: first type annotation + :param b: second type annotation + :param compare_origin: whether to compare the origin of the type annotation + :return: whether the two type annotations are equal + """ + if compare_origin: + ta = get_origin(a) or a + tb = get_origin(b) or b + if ta != tb: + return False + aa = get_args(a) + ba = get_args(b) + if len(aa) != len(ba): + return False + return all(compare_annotations(x, y, compare_origin) for x, y in zip(aa, ba)) + return a == b + + def _parse_value_and_unit(exp: Any) -> Tuple[float, str]: try: assert exp is not None diff --git a/triad_version/__init__.py b/triad_version/__init__.py index e942428..281de95 100644 --- a/triad_version/__init__.py +++ b/triad_version/__init__.py @@ -1,2 +1,2 @@ # flake8: noqa -__version__ = "0.9.6" +__version__ = "0.9.7"