diff --git a/mypy/stubgen.py b/mypy/stubgen.py index f1835aa9acd9..7028b8da04b6 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -85,7 +85,7 @@ from mypy.options import Options as MypyOptions from mypy.types import ( Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance, - AnyType + AnyType, get_proper_type ) from mypy.visitor import NodeVisitor from mypy.find_sources import create_source_list, InvalidSourceList @@ -624,26 +624,24 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False, # name their 0th argument other than self/cls is_self_arg = i == 0 and name == 'self' is_cls_arg = i == 0 and name == 'cls' - if (annotated_type is None - and not arg_.initializer - and not is_self_arg - and not is_cls_arg): - self.add_typing_import("Any") - annotation = ": {}".format(self.typing_name("Any")) - elif annotated_type and not is_self_arg and not is_cls_arg: - annotation = ": {}".format(self.print_annotation(annotated_type)) - else: - annotation = "" + annotation = "" + if annotated_type and not is_self_arg and not is_cls_arg: + # Luckily, an argument explicitly annotated with "Any" has + # type "UnboundType" and will not match. + if not isinstance(get_proper_type(annotated_type), AnyType): + annotation = ": {}".format(self.print_annotation(annotated_type)) if arg_.initializer: - initializer = '...' if kind in (ARG_NAMED, ARG_NAMED_OPT) and not any(arg.startswith('*') for arg in args): args.append('*') if not annotation: - typename = self.get_str_type_of_node(arg_.initializer, True) - annotation = ': {} = ...'.format(typename) + typename = self.get_str_type_of_node(arg_.initializer, True, False) + if typename == '': + annotation = '=...' + else: + annotation = ': {} = ...'.format(typename) else: - annotation += '={}'.format(initializer) + annotation += ' = ...' arg = name + annotation elif kind == ARG_STAR: arg = '*%s%s' % (name, annotation) @@ -654,12 +652,16 @@ def visit_func_def(self, o: FuncDef, is_abstract: bool = False, args.append(arg) retname = None if o.name != '__init__' and isinstance(o.unanalyzed_type, CallableType): - retname = self.print_annotation(o.unanalyzed_type.ret_type) + if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType): + # Luckily, a return type explicitly annotated with "Any" has + # type "UnboundType" and will enter the else branch. + retname = None # implicit Any + else: + retname = self.print_annotation(o.unanalyzed_type.ret_type) elif isinstance(o, FuncDef) and (o.is_abstract or o.name in METHODS_WITH_RETURN_VALUE): # Always assume abstract methods return Any unless explicitly annotated. Also # some dunder methods should not have a None return type. - retname = self.typing_name('Any') - self.add_typing_import("Any") + retname = None # implicit Any elif not has_return_statement(o) and not is_abstract: retname = 'None' retfield = '' @@ -1148,7 +1150,8 @@ def is_private_member(self, fullname: str) -> bool: return False def get_str_type_of_node(self, rvalue: Expression, - can_infer_optional: bool = False) -> str: + can_infer_optional: bool = False, + can_be_any: bool = True) -> str: if isinstance(rvalue, IntExpr): return 'int' if isinstance(rvalue, StrExpr): @@ -1165,8 +1168,11 @@ def get_str_type_of_node(self, rvalue: Expression, isinstance(rvalue, NameExpr) and rvalue.name == 'None': self.add_typing_import('Any') return '{} | None'.format(self.typing_name('Any')) - self.add_typing_import('Any') - return self.typing_name('Any') + if can_be_any: + self.add_typing_import('Any') + return self.typing_name('Any') + else: + return '' def print_annotation(self, t: Type) -> str: printer = AnnotationPrinter(self) diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index f2e102a5dcfe..daee45a84082 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -15,18 +15,14 @@ def f(a, b): def g(arg): pass [out] -from typing import Any - -def f(a: Any, b: Any) -> None: ... -def g(arg: Any) -> None: ... +def f(a, b) -> None: ... +def g(arg) -> None: ... [case testDefaultArgInt] def f(a, b=2): ... def g(b=-1, c=0): ... [out] -from typing import Any - -def f(a: Any, b: int = ...) -> None: ... +def f(a, b: int = ...) -> None: ... def g(b: int = ..., c: int = ...) -> None: ... [case testDefaultArgNone] @@ -59,14 +55,14 @@ def f(x: float = ...) -> None: ... [case testDefaultArgOther] def f(x=ord): ... [out] -from typing import Any - -def f(x: Any = ...) -> None: ... +def f(x=...) -> None: ... [case testPreserveFunctionAnnotation] def f(x: Foo) -> Bar: ... +def g(x: Foo = Foo()) -> Bar: ... [out] def f(x: Foo) -> Bar: ... +def g(x: Foo = ...) -> Bar: ... [case testPreserveVarAnnotation] x: Foo @@ -81,16 +77,12 @@ x: Foo [case testVarArgs] def f(x, *y): ... [out] -from typing import Any - -def f(x: Any, *y: Any) -> None: ... +def f(x, *y) -> None: ... [case testKwVarArgs] def f(x, **y): ... [out] -from typing import Any - -def f(x: Any, **y: Any) -> None: ... +def f(x, **y) -> None: ... [case testVarArgsWithKwVarArgs] def f(a, *b, **c): ... @@ -99,13 +91,11 @@ def h(a, *b, c=1, **d): ... def i(a, *, b=1): ... def j(a, *, b=1, **c): ... [out] -from typing import Any - -def f(a: Any, *b: Any, **c: Any) -> None: ... -def g(a: Any, *b: Any, c: int = ...) -> None: ... -def h(a: Any, *b: Any, c: int = ..., **d: Any) -> None: ... -def i(a: Any, *, b: int = ...) -> None: ... -def j(a: Any, *, b: int = ..., **c: Any) -> None: ... +def f(a, *b, **c) -> None: ... +def g(a, *b, c: int = ...) -> None: ... +def h(a, *b, c: int = ..., **d) -> None: ... +def i(a, *, b: int = ...) -> None: ... +def j(a, *, b: int = ..., **c) -> None: ... [case testClass] class A: @@ -113,10 +103,8 @@ class A: x = 1 def g(): ... [out] -from typing import Any - class A: - def f(self, x: Any) -> None: ... + def f(self, x) -> None: ... def g() -> None: ... @@ -263,9 +251,7 @@ class B(A): ... @decorator def foo(x): ... [out] -from typing import Any - -def foo(x: Any) -> None: ... +def foo(x) -> None: ... [case testMultipleAssignment] x, y = 1, 2 @@ -293,10 +279,8 @@ y: Any def f(x, *, y=1): ... def g(x, *, y=1, z=2): ... [out] -from typing import Any - -def f(x: Any, *, y: int = ...) -> None: ... -def g(x: Any, *, y: int = ..., z: int = ...) -> None: ... +def f(x, *, y: int = ...) -> None: ... +def g(x, *, y: int = ..., z: int = ...) -> None: ... [case testProperty] class A: @@ -309,13 +293,11 @@ class A: def h(self): self.f = 1 [out] -from typing import Any - class A: @property def f(self): ... @f.setter - def f(self, x: Any) -> None: ... + def f(self, x) -> None: ... def h(self) -> None: ... [case testStaticMethod] @@ -323,11 +305,9 @@ class A: @staticmethod def f(x): ... [out] -from typing import Any - class A: @staticmethod - def f(x: Any) -> None: ... + def f(x) -> None: ... [case testClassMethod] class A: @@ -390,10 +370,8 @@ class A: def __getstate__(self): ... def __setstate__(self, state): ... [out] -from typing import Any - class A: - def __eq__(self) -> Any: ... + def __eq__(self): ... -- Tests that will perform runtime imports of modules. -- Don't use `_import` suffix if there are unquoted forward references. @@ -774,17 +752,13 @@ class A(X): ... def syslog(a): pass def syslog(a): pass [out] -from typing import Any - -def syslog(a: Any) -> None: ... +def syslog(a) -> None: ... [case testAsyncAwait_fast_parser] async def f(a): x = await y [out] -from typing import Any - -async def f(a: Any) -> None: ... +async def f(a) -> None: ... [case testInferOptionalOnlyFunc] class A: @@ -1527,11 +1501,10 @@ class Base(metaclass=ABCMeta): import abc from abc import abstractmethod from base import Base -from typing import Any class C(Base, metaclass=abc.ABCMeta): @abstractmethod - def other(self) -> Any: ... + def other(self): ... [case testInvalidNumberOfArgsInAnnotation] def f(x): @@ -1539,9 +1512,7 @@ def f(x): return '' [out] -from typing import Any - -def f(x: Any): ... +def f(x): ... [case testFunctionPartiallyAnnotated] def f(x) -> None: @@ -1555,13 +1526,45 @@ class A: pass [out] +def f(x) -> None: ... +def g(x, y: str): ... + +class A: + def f(self, x) -> None: ... + +[case testExplicitAnyArg] from typing import Any -def f(x: Any) -> None: ... -def g(x: Any, y: str) -> Any: ... +def f(x: Any): + pass +def g(x, y: Any) -> str: + pass +def h(x: Any) -> str: + pass -class A: - def f(self, x: Any) -> None: ... +[out] +from typing import Any + +def f(x: Any): ... +def g(x, y: Any) -> str: ... +def h(x: Any) -> str: ... + +[case testExplicitReturnedAny] +from typing import Any + +def f(x: str) -> Any: + pass +def g(x, y: str) -> Any: + pass +def h(x) -> Any: + pass + +[out] +from typing import Any + +def f(x: str) -> Any: ... +def g(x, y: str) -> Any: ... +def h(x) -> Any: ... [case testPlacementOfDecorators] class A: @@ -1580,8 +1583,6 @@ class B: self.y = 'y' [out] -from typing import Any - class A: y: str @property @@ -1592,7 +1593,7 @@ class B: def x(self): ... y: str @x.setter - def x(self, value: Any) -> None: ... + def x(self, value) -> None: ... [case testMisplacedTypeComment] def f(): @@ -1666,12 +1667,11 @@ class A: [out] import abc -from typing import Any class A(metaclass=abc.ABCMeta): @property @abc.abstractmethod - def x(self) -> Any: ... + def x(self): ... [case testAbstractProperty2_semanal] import other @@ -1683,12 +1683,11 @@ class A: [out] import abc -from typing import Any class A(metaclass=abc.ABCMeta): @property @abc.abstractmethod - def x(self) -> Any: ... + def x(self): ... [case testAbstractProperty3_semanal] import other @@ -1700,16 +1699,14 @@ class A: [out] import abc -from typing import Any class A(metaclass=abc.ABCMeta): @property @abc.abstractmethod - def x(self) -> Any: ... + def x(self): ... [case testClassWithNameAnyOrOptional] -def f(x=object()): - return 1 +Y = object() def g(x=None): pass @@ -1724,7 +1721,8 @@ def Optional(): [out] from typing import Any as _Any -def f(x: _Any = ...): ... +Y: _Any + def g(x: _Any | None = ...) -> None: ... x: _Any @@ -1858,9 +1856,7 @@ def g() -> None: ... def f(x, y): pass [out] -from typing import Any - -def f(x: Any, y: Any) -> None: ... +def f(x, y) -> None: ... [case testImportedModuleExits_import] # modules: a b c @@ -2197,11 +2193,11 @@ from typing import Any class C: x: Any - def __init__(self, x: Any) -> None: ... - def __lt__(self, other: Any) -> Any: ... - def __le__(self, other: Any) -> Any: ... - def __gt__(self, other: Any) -> Any: ... - def __ge__(self, other: Any) -> Any: ... + def __init__(self, x) -> None: ... + def __lt__(self, other): ... + def __le__(self, other): ... + def __gt__(self, other): ... + def __ge__(self, other): ... [case testNamedTupleInClass] from collections import namedtuple