Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Utils] Handled Callable in tir.schedule._type_checker (apache#12633)
Browse files Browse the repository at this point in the history
Previously, `Callable` was handled as an atomic type.  This worked
when it was included as last element of a `Union[]` annotation with no
subtypes, but raised an error for other use cases, including
`Optional[Callable]`.

This commit adds explicit checks for `Callable` type annotations to
validate whether the argument is callable, but doesn't recursively
validate the signature of the callable object, because lambda
functions cannot have type
annotations. (https://peps.python.org/pep-3107/#lambda)
  • Loading branch information
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent fc54bf1 commit 4e20021
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 14 deletions.
40 changes: 40 additions & 0 deletions python/tvm/tir/schedule/_type_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Type checking functionality"""
import collections
import collections.abc
import functools
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
Expand All @@ -26,6 +28,7 @@ def _is_none_type(type_: Any) -> bool:


if hasattr(typing, "_GenericAlias"):
# For python versions 3.7 onward, check the __origin__ attribute.

class _Subtype:
@staticmethod
Expand Down Expand Up @@ -71,7 +74,15 @@ def union(type_: Any) -> Optional[List[type]]:
return list(subtypes)
return None

@staticmethod
def callable(type_: Any) -> Optional[List[type]]:
if _Subtype._origin(type_) is collections.abc.Callable:
subtypes = type_.__args__
return subtypes
return None

elif hasattr(typing, "_Union"):
# For python 3.6 and below, check the __name__ attribute, or CallableMeta.

class _Subtype: # type: ignore
@staticmethod
Expand Down Expand Up @@ -114,6 +125,13 @@ def union(type_: Any) -> Optional[List[type]]:
return list(subtypes)
return None

@staticmethod
def callable(type_: Any) -> Optional[List[type]]:
if isinstance(type_, typing.CallableMeta): # type: ignore # pylint: disable=no-member,protected-access
subtypes = type_.__args__
return subtypes
return None


def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
if _is_none_type(type_):
Expand All @@ -139,12 +157,27 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
if subtype is not None:
return "union", subtype

subtype = _Subtype.callable(type_)
if subtype is not None:
return "callable", subtype

return "atomic", [type_]


def callable_str(subtypes):
if subtypes:
*arg_types, return_type = subtypes
arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types)
return_type_str = _type2str(return_type)
return f"Callable[[{arg_str}], {return_type_str}]"
else:
return "Callable"


_TYPE2STR: Dict[Any, Callable] = {
"none": lambda: "None",
"atomic": lambda t: str(t.__name__),
"callable": callable_str,
"list": lambda t: f"List[{_type2str(t)}]",
"dict": lambda k, v: f"Dict[{_type2str(k)}, {_type2str(v)}]",
"tuple": lambda *t: f"Tuple[{', '.join([_type2str(x) for x in t])}]",
Expand Down Expand Up @@ -188,6 +221,12 @@ def _type_check_none(v: Any, name: str) -> Optional[str]:
def _type_check_atomic(v: Any, name: str, type_: Any) -> Optional[str]:
return None if isinstance(v, type_) else _type_check_err(v, name, type_)

def _type_check_callable(v: Any, name: str, *_subtypes: Any) -> Optional[str]:
# Current implementation only validates that the argument is
# callable, and doesn't validate the arguments accepted by the
# callable, if any.
return None if callable(v) else _type_check_err(v, name, Callable)

def _type_check_list(v: List[Any], name: str, type_: Any) -> Optional[str]:
if not isinstance(v, (list, tuple)):
return _type_check_err(v, name, list)
Expand Down Expand Up @@ -234,6 +273,7 @@ def _type_check_union(v: Any, name: str, *types: Any) -> Optional[str]:
return {
"none": _type_check_none,
"atomic": _type_check_atomic,
"callable": _type_check_callable,
"list": _type_check_list,
"dict": _type_check_dict,
"tuple": _type_check_tuple,
Expand Down
77 changes: 63 additions & 14 deletions tests/python/unittest/test_type_annotation_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,22 @@
"""Test type checker based on python's type annotations"""

import sys
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple, Union, Callable

import pytest
import _pytest

from tvm.tir.schedule._type_checker import type_checked


def int_func(x: int) -> int:
return 2 * x


def str_func(x: str) -> str:
return 2 * x


test_cases = [
{
"type_annotation": int,
Expand Down Expand Up @@ -90,30 +99,71 @@
None,
],
},
{
"type_annotation": Callable,
"positive_cases": [str_func, int_func],
"negative_cases": [
None,
"x",
42,
],
},
{
"type_annotation": Callable[[int], int],
"positive_cases": [int_func],
"negative_cases": [
None,
"x",
42,
pytest.param(
str_func,
marks=pytest.mark.xfail(
reason="Signature of Callable arguments not currently checked"
),
),
],
},
]

positive_cases = [
(config["type_annotation"], case) for config in test_cases for case in config["positive_cases"]
]

negative_cases = [
(config["type_annotation"], case) for config in test_cases for case in config["negative_cases"]
]

def make_parametrization(type_annotation, case):
if isinstance(case, _pytest.mark.structures.ParameterSet):
marks = case.marks
(case,) = case.values
else:
marks = []

def format_name(type_annotation, case):
try:
name = type_annotation.__name__
annotation_name = type_annotation.__name__
except AttributeError:
name = str(type_annotation).replace("typing.", "")
annotation_name = str(type_annotation).replace("typing.", "")

if hasattr(case, "__name__"):
case_name = case.__name__
else:
case_name = str(case)

return f"{name}_{case}"
name = f"{annotation_name}, {case_name}"

return pytest.param(type_annotation, case, marks=marks, id=name)


positive_cases = [
make_parametrization(config["type_annotation"], case)
for config in test_cases
for case in config["positive_cases"]
]

negative_cases = [
make_parametrization(config["type_annotation"], case)
for config in test_cases
for case in config["negative_cases"]
]


@pytest.mark.parametrize(
["type_annotation", "case"],
positive_cases,
ids=[format_name(t, c) for t, c in positive_cases],
)
def test_matches_type(type_annotation, case):
@type_checked
Expand All @@ -126,7 +176,6 @@ def func(_: type_annotation):
@pytest.mark.parametrize(
["type_annotation", "case"],
negative_cases,
ids=[format_name(t, c) for t, c in negative_cases],
)
def test_not_matches(type_annotation, case):
@type_checked
Expand Down

0 comments on commit 4e20021

Please sign in to comment.