diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py index 5c51b1b09fed..148016fb2d34 100644 --- a/python/tvm/tir/schedule/_type_checker.py +++ b/python/tvm/tir/schedule/_type_checker.py @@ -47,6 +47,10 @@ def _get_subtypes(type_: Any) -> Any: class _Subtype: @staticmethod def _origin(type_: Any) -> Any: + # In Python 3.14+, check if the type has __origin__ attribute directly + if hasattr(type_, "__origin__"): + return type_.__origin__ + if hasattr(typing, "_SpecialGenericAlias"): if isinstance(type_, typing._SpecialGenericAlias): # type: ignore # pylint: disable=protected-access return type_.__origin__ diff --git a/tests/python/testing/test_type_annotation_checker.py b/tests/python/testing/test_type_annotation_checker.py index 42ce1e103903..71bc9ba98bc9 100644 --- a/tests/python/testing/test_type_annotation_checker.py +++ b/tests/python/testing/test_type_annotation_checker.py @@ -187,5 +187,39 @@ def func(_: type_annotation): func(case) +@pytest.mark.parametrize( + ["type_annotation", "expected_key", "expected_subtypes"], + [ + pytest.param(Union[str, int], "union", [str, int], id="Union[str, int]"), + pytest.param(List[str], "list", [str], id="List[str]"), + pytest.param(Dict[str, int], "dict", [str, int], id="Dict[str, int]"), + pytest.param(Tuple[str, int], "tuple", (str, int), id="Tuple[str, int]"), + pytest.param( + Union[List[str], Dict[str, int]], + "union", + [List[str], Dict[str, int]], + id="Union[List[str], Dict[str, int]]", + ), + ], +) +def test_subscripted_generics(type_annotation, expected_key, expected_subtypes): + """Test that _dispatcher correctly handles subscripted generics in Python 3.14+. + + In Python 3.14, Union and other generic types have a different internal representation. + This test ensures that the dispatcher correctly identifies these types. + """ + from tvm.tir.schedule._type_checker import _dispatcher + + key, subtypes = _dispatcher(type_annotation) + assert key == expected_key, f"Expected '{expected_key}' but got '{key}'" + + if isinstance(expected_subtypes, tuple): + assert ( + tuple(subtypes) == expected_subtypes + ), f"Expected {expected_subtypes} but got {subtypes}" + else: + assert subtypes == expected_subtypes, f"Expected {expected_subtypes} but got {subtypes}" + + if __name__ == "__main__": tvm.testing.main()