From b5181c414e53ce5a4fce88558eaa23d1900cf9c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Silvestr=20L=C3=A1n=C3=ADk?= Date: Sun, 14 Aug 2022 18:43:39 +0200 Subject: [PATCH 1/2] fix: add import Union to stub file --- mypy/stubgen.py | 19 +++++++++++++++++++ test-data/unit/stubgen.test | 25 +++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index fc4a7e0fcd9d..e75d126bf557 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -125,6 +125,7 @@ TypeList, TypeStrVisitor, UnboundType, + UnionType, get_proper_type, ) from mypy.visitor import NodeVisitor @@ -461,6 +462,17 @@ def import_lines(self) -> List[str]: module_map: Mapping[str, List[str]] = defaultdict(list) for name in sorted(self.required_names): + # We don't want to ignore Union even if it's not listed in the import statement because for PEP 604 style of + # Union we're still generating explicit Union e.g. + # + # def foo(a: int | str): + # print(a) + # ==> + # def foo(a: Union[int | str]): ... + + if name == "Union": + self.module_for[name] = "typing" + # If we haven't seen this name in an import statement, ignore it if name not in self.module_for: continue @@ -693,6 +705,8 @@ def visit_func_def( # Luckily, an argument explicitly annotated with "Any" has # type "UnboundType" and will not match. if not isinstance(get_proper_type(annotated_type), AnyType): + if isinstance(get_proper_type(annotated_type), UnionType): + self.add_typing_import("Union") annotation = f": {self.print_annotation(annotated_type)}" if kind.is_named() and not any(arg.startswith("*") for arg in args): @@ -722,6 +736,8 @@ def visit_func_def( # type "UnboundType" and will enter the else branch. retname = None # implicit Any else: + if isinstance(get_proper_type(o.unanalyzed_type.ret_type), UnionType): + self.add_typing_import("Union") retname = self.print_annotation(o.unanalyzed_type.ret_type) elif isinstance(o, FuncDef) and ( o.abstract_status == IS_ABSTRACT or o.name in METHODS_WITH_RETURN_VALUE @@ -1200,6 +1216,9 @@ def get_init( return None self._vars[-1].append(lvalue) if annotation is not None: + if isinstance(get_proper_type(annotation), UnionType): + self.add_typing_import("Union") + typename = self.print_annotation(annotation) if ( isinstance(annotation, UnboundType) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 408f116443d2..8890fab03685 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -2705,3 +2705,28 @@ def f(): return 0 [out] def f(): ... + +[case testAddUnionImportForArgument] +def foo(a: str | None): + pass +[out] +from typing import Union + +def foo(a: Union[str, None]): ... + +[case testAddUnionImportForReturn] +def foo() -> str | None: + pass +[out] +from typing import Union + +def foo() -> Union[str, None]: ... + +[case testAddUnionImportForClassAttribute] +class A: + a: int | str +[out] +from typing import Union + +class A: + a: Union[int, str] From 9e88224ce1ba2a15791632a1cb83c29a859b2c85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Silvestr=20L=C3=A1n=C3=ADk?= Date: Tue, 16 Aug 2022 22:03:34 +0200 Subject: [PATCH 2/2] test: fixing failing test --- test-data/unit/stubgen.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 8890fab03685..45d3aabb45e2 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -909,13 +909,13 @@ alias = Dict[str, List[str]] alias = Dict[str, List[str]] [case testDeepGenericTypeAliasPreserved] -from typing import TypeVar +from typing import List, TypeVar, Union T = TypeVar('T') alias = Union[T, List[T]] [out] -from typing import TypeVar +from typing import List, TypeVar, Union T = TypeVar('T') alias = Union[T, List[T]]