Skip to content

Add import Union to stub file (#12929) #13428

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
TypeList,
TypeStrVisitor,
UnboundType,
UnionType,
get_proper_type,
)
from mypy.visitor import NodeVisitor
Expand Down Expand Up @@ -461,6 +462,17 @@ def import_lines(self) -> List[str]:
module_map: Mapping[str, List[str]] = defaultdict(list)

for name in sorted(self.required_names):
# We don't want to ignore Union even if it's not listed in the import statement because for PEP 604 style of
# Union we're still generating explicit Union e.g.
#
# def foo(a: int | str):
# print(a)
# ==>
# def foo(a: Union[int | str]): ...

if name == "Union":
self.module_for[name] = "typing"

# If we haven't seen this name in an import statement, ignore it
if name not in self.module_for:
continue
Expand Down Expand Up @@ -693,6 +705,8 @@ def visit_func_def(
# Luckily, an argument explicitly annotated with "Any" has
# type "UnboundType" and will not match.
if not isinstance(get_proper_type(annotated_type), AnyType):
if isinstance(get_proper_type(annotated_type), UnionType):
self.add_typing_import("Union")
annotation = f": {self.print_annotation(annotated_type)}"

if kind.is_named() and not any(arg.startswith("*") for arg in args):
Expand Down Expand Up @@ -722,6 +736,8 @@ def visit_func_def(
# type "UnboundType" and will enter the else branch.
retname = None # implicit Any
else:
if isinstance(get_proper_type(o.unanalyzed_type.ret_type), UnionType):
self.add_typing_import("Union")
retname = self.print_annotation(o.unanalyzed_type.ret_type)
elif isinstance(o, FuncDef) and (
o.abstract_status == IS_ABSTRACT or o.name in METHODS_WITH_RETURN_VALUE
Expand Down Expand Up @@ -1200,6 +1216,9 @@ def get_init(
return None
self._vars[-1].append(lvalue)
if annotation is not None:
if isinstance(get_proper_type(annotation), UnionType):
self.add_typing_import("Union")

typename = self.print_annotation(annotation)
if (
isinstance(annotation, UnboundType)
Expand Down
29 changes: 27 additions & 2 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -909,13 +909,13 @@ alias = Dict[str, List[str]]
alias = Dict[str, List[str]]

[case testDeepGenericTypeAliasPreserved]
from typing import TypeVar
from typing import List, TypeVar, Union

T = TypeVar('T')
alias = Union[T, List[T]]

[out]
from typing import TypeVar
from typing import List, TypeVar, Union

T = TypeVar('T')
alias = Union[T, List[T]]
Expand Down Expand Up @@ -2705,3 +2705,28 @@ def f():
return 0
[out]
def f(): ...

[case testAddUnionImportForArgument]
def foo(a: str | None):
pass
[out]
from typing import Union
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please also add Optional to the test? It is highly related to Union as well.


def foo(a: Union[str, None]): ...

[case testAddUnionImportForReturn]
def foo() -> str | None:
pass
[out]
from typing import Union

def foo() -> Union[str, None]: ...

[case testAddUnionImportForClassAttribute]
class A:
a: int | str
[out]
from typing import Union

class A:
a: Union[int, str]