diff --git a/libcst/matchers/_visitors.py b/libcst/matchers/_visitors.py index 301e675aa..be50edfd3 100644 --- a/libcst/matchers/_visitors.py +++ b/libcst/matchers/_visitors.py @@ -79,8 +79,18 @@ def _get_possible_match_classes(matcher: BaseMatcherNode) -> List[Type[cst.CSTNo return [getattr(cst, matcher.__class__.__name__)] -def _get_possible_annotated_classes(annotation: object) -> List[Type[object]]: +def _annotation_looks_like_union(annotation: object) -> bool: if getattr(annotation, "__origin__", None) is Union: + return True + # support PEP-604 style unions introduced in Python 3.10 + return ( + annotation.__class__.__name__ == "Union" + and annotation.__class__.__module__ == "types" + ) + + +def _get_possible_annotated_classes(annotation: object) -> List[Type[object]]: + if _annotation_looks_like_union(annotation): return getattr(annotation, "__args__", []) else: return [cast(Type[object], annotation)] diff --git a/libcst/matchers/tests/test_decorators.py b/libcst/matchers/tests/test_decorators.py index c102f2ab5..b1ff3d054 100644 --- a/libcst/matchers/tests/test_decorators.py +++ b/libcst/matchers/tests/test_decorators.py @@ -6,6 +6,7 @@ from ast import literal_eval from textwrap import dedent from typing import List, Set +from unittest.mock import Mock import libcst as cst import libcst.matchers as m @@ -993,3 +994,25 @@ def bar() -> None: # We should have only visited a select number of nodes. self.assertEqual(visitor.visits, ['"baz"']) + + +# This is meant to simulate `cst.ImportFrom | cst.RemovalSentinel` in py3.10 +FakeUnionClass: Mock = Mock() +setattr(FakeUnionClass, "__name__", "Union") +setattr(FakeUnionClass, "__module__", "types") +FakeUnion: Mock = Mock() +FakeUnion.__class__ = FakeUnionClass +FakeUnion.__args__ = [cst.ImportFrom, cst.RemovalSentinel] + + +class MatchersUnionDecoratorsTest(UnitTest): + def test_init_with_new_union_annotation(self) -> None: + class TransformerWithUnionReturnAnnotation(m.MatcherDecoratableTransformer): + @m.leave(m.ImportFrom(module=m.Name(value="typing"))) + def test( + self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> FakeUnion: + pass + + # assert that init (specifically _check_types on return annotation) passes + TransformerWithUnionReturnAnnotation()