diff --git a/docs/source/stubgen.rst b/docs/source/stubgen.rst index 2de0743572e7..c9e52956379a 100644 --- a/docs/source/stubgen.rst +++ b/docs/source/stubgen.rst @@ -127,12 +127,22 @@ alter the default behavior: unwanted side effects, such as the running of tests. Stubgen tries to skip test modules even without this option, but this does not always work. -.. option:: --parse-only +.. option:: --no-analysis Don't perform semantic analysis of source files. This may generate worse stubs -- in particular, some module, class, and function aliases may be represented as variables with the ``Any`` type. This is generally only - useful if semantic analysis causes a critical mypy error. + useful if semantic analysis causes a critical mypy error. Does not apply to + C extension modules. Incompatible with :option:`--inspect-mode`. + +.. option:: --inspect-mode + + Import and inspect modules instead of parsing source code. This is the default + behavior for C modules and pyc-only packages. The flag is useful to force + inspection for pure Python modules that make use of dynamically generated + members that would otherwise be omitted when using the default behavior of + code parsing. Implies :option:`--no-analysis` as analysis requires source + code. .. option:: --doc-dir PATH diff --git a/mypy/moduleinspect.py b/mypy/moduleinspect.py index b383fc9dc145..580b31fb4107 100644 --- a/mypy/moduleinspect.py +++ b/mypy/moduleinspect.py @@ -39,6 +39,10 @@ def is_c_module(module: ModuleType) -> bool: return os.path.splitext(module.__dict__["__file__"])[-1] in [".so", ".pyd", ".dll"] +def is_pyc_only(file: str | None) -> bool: + return bool(file and file.endswith(".pyc") and not os.path.exists(file[:-1])) + + class InspectError(Exception): pass diff --git a/mypy/stubdoc.py b/mypy/stubdoc.py index 145f57fd7751..c277573f0b59 100644 --- a/mypy/stubdoc.py +++ b/mypy/stubdoc.py @@ -8,11 +8,14 @@ import contextlib import io +import keyword import re import tokenize from typing import Any, Final, MutableMapping, MutableSequence, NamedTuple, Sequence, Tuple from typing_extensions import TypeAlias as _TypeAlias +import mypy.util + # Type alias for signatures strings in format ('func_name', '(arg, opt_arg=False)'). Sig: _TypeAlias = Tuple[str, str] @@ -35,12 +38,16 @@ class ArgSig: def __init__(self, name: str, type: str | None = None, default: bool = False): self.name = name - if type and not is_valid_type(type): - raise ValueError("Invalid type: " + type) self.type = type # Does this argument have a default value? self.default = default + def is_star_arg(self) -> bool: + return self.name.startswith("*") and not self.name.startswith("**") + + def is_star_kwarg(self) -> bool: + return self.name.startswith("**") + def __repr__(self) -> str: return "ArgSig(name={}, type={}, default={})".format( repr(self.name), repr(self.type), repr(self.default) @@ -59,7 +66,80 @@ def __eq__(self, other: Any) -> bool: class FunctionSig(NamedTuple): name: str args: list[ArgSig] - ret_type: str + ret_type: str | None + + def is_special_method(self) -> bool: + return bool( + self.name.startswith("__") + and self.name.endswith("__") + and self.args + and self.args[0].name in ("self", "cls") + ) + + def has_catchall_args(self) -> bool: + """Return if this signature has catchall args: (*args, **kwargs)""" + if self.args and self.args[0].name in ("self", "cls"): + args = self.args[1:] + else: + args = self.args + return ( + len(args) == 2 + and all(a.type in (None, "object", "Any", "typing.Any") for a in args) + and args[0].is_star_arg() + and args[1].is_star_kwarg() + ) + + def is_catchall_signature(self) -> bool: + """Return if this signature is the catchall identity: (*args, **kwargs) -> Any""" + return self.has_catchall_args() and self.ret_type in (None, "Any", "typing.Any") + + def format_sig( + self, + indent: str = "", + is_async: bool = False, + any_val: str | None = None, + docstring: str | None = None, + ) -> str: + args: list[str] = [] + for arg in self.args: + arg_def = arg.name + + if arg_def in keyword.kwlist: + arg_def = "_" + arg_def + + if ( + arg.type is None + and any_val is not None + and arg.name not in ("self", "cls") + and not arg.name.startswith("*") + ): + arg_type: str | None = any_val + else: + arg_type = arg.type + if arg_type: + arg_def += ": " + arg_type + if arg.default: + arg_def += " = ..." + + elif arg.default: + arg_def += "=..." + + args.append(arg_def) + + retfield = "" + ret_type = self.ret_type if self.ret_type else any_val + if ret_type is not None: + retfield = " -> " + ret_type + + prefix = "async " if is_async else "" + sig = "{indent}{prefix}def {name}({args}){ret}:".format( + indent=indent, prefix=prefix, name=self.name, args=", ".join(args), ret=retfield + ) + if docstring: + suffix = f"\n{indent} {mypy.util.quote_docstring(docstring)}" + else: + suffix = " ..." + return f"{sig}{suffix}" # States of the docstring parser. @@ -176,17 +256,17 @@ def add_token(self, token: tokenize.TokenInfo) -> None: # arg_name is empty when there are no args. e.g. func() if self.arg_name: - try: + if self.arg_type and not is_valid_type(self.arg_type): + # wrong type, use Any + self.args.append( + ArgSig(name=self.arg_name, type=None, default=bool(self.arg_default)) + ) + else: self.args.append( ArgSig( name=self.arg_name, type=self.arg_type, default=bool(self.arg_default) ) ) - except ValueError: - # wrong type, use Any - self.args.append( - ArgSig(name=self.arg_name, type=None, default=bool(self.arg_default)) - ) self.arg_name = "" self.arg_type = None self.arg_default = None @@ -240,7 +320,7 @@ def args_kwargs(signature: FunctionSig) -> bool: def infer_sig_from_docstring(docstr: str | None, name: str) -> list[FunctionSig] | None: - """Convert function signature to list of TypedFunctionSig + """Convert function signature to list of FunctionSig Look for function signatures of function in docstring. Signature is a string of the format () -> or perhaps without diff --git a/mypy/stubgen.py b/mypy/stubgen.py index e8c12ee4d99b..395a49fa4e08 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -7,7 +7,7 @@ - or use mypy's mechanisms, if importing is prohibited * (optionally) semantically analysing the sources using mypy (as a single set) * emitting the stubs text: - - for Python modules: from ASTs using StubGenerator + - for Python modules: from ASTs using ASTStubGenerator - for C modules using runtime introspection and (optionally) Sphinx docs During first and third steps some problematic files can be skipped, but any @@ -42,14 +42,12 @@ from __future__ import annotations import argparse -import glob import keyword import os import os.path import sys import traceback -from collections import defaultdict -from typing import Final, Iterable, Mapping +from typing import Final, Iterable import mypy.build import mypy.mixedtraverser @@ -66,7 +64,7 @@ SearchPaths, default_lib_path, ) -from mypy.moduleinspect import ModuleInspect +from mypy.moduleinspect import ModuleInspect, is_pyc_only from mypy.nodes import ( ARG_NAMED, ARG_POS, @@ -85,6 +83,7 @@ DictExpr, EllipsisExpr, Expression, + ExpressionStmt, FloatExpr, FuncBase, FuncDef, @@ -109,20 +108,19 @@ Var, ) from mypy.options import Options as MypyOptions -from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures -from mypy.stubgenc import ( - DocstringSignatureGenerator, - ExternalSignatureGenerator, - FallbackSignatureGenerator, - SignatureGenerator, - generate_stub_for_c_module, -) +from mypy.stubdoc import ArgSig, FunctionSig +from mypy.stubgenc import InspectionStubGenerator, generate_stub_for_c_module from mypy.stubutil import ( + BaseStubGenerator, CantImport, + ClassInfo, + FunctionContext, common_dir_prefix, fail_missing, find_module_path_and_all_py3, generate_guarded, + infer_method_arg_types, + infer_method_ret_type, remove_misplaced_type_comments, report_missing, walk_packages, @@ -140,19 +138,13 @@ AnyType, CallableType, Instance, - NoneType, TupleType, Type, - TypeList, - TypeStrVisitor, UnboundType, - UnionType, get_proper_type, ) from mypy.visitor import NodeVisitor -TYPING_MODULE_NAMES: Final = ("typing", "typing_extensions") - # Common ways of naming package containing vendored modules. VENDOR_PACKAGES: Final = ["packages", "vendor", "vendored", "_vendor", "_vendored_packages"] @@ -165,32 +157,6 @@ "/_vendored_packages/", ] -# Special-cased names that are implicitly exported from the stub (from m import y as y). -EXTRA_EXPORTED: Final = { - "pyasn1_modules.rfc2437.univ", - "pyasn1_modules.rfc2459.char", - "pyasn1_modules.rfc2459.univ", -} - -# These names should be omitted from generated stubs. -IGNORED_DUNDERS: Final = { - "__all__", - "__author__", - "__version__", - "__about__", - "__copyright__", - "__email__", - "__license__", - "__summary__", - "__title__", - "__uri__", - "__str__", - "__repr__", - "__getstate__", - "__setstate__", - "__slots__", -} - # These methods are expected to always return a non-trivial value. METHODS_WITH_RETURN_VALUE: Final = { "__ne__", @@ -203,22 +169,6 @@ "__iter__", } -# These magic methods always return the same type. -KNOWN_MAGIC_METHODS_RETURN_TYPES: Final = { - "__len__": "int", - "__length_hint__": "int", - "__init__": "None", - "__del__": "None", - "__bool__": "bool", - "__bytes__": "bytes", - "__format__": "str", - "__contains__": "bool", - "__complex__": "complex", - "__int__": "int", - "__float__": "float", - "__index__": "int", -} - class Options: """Represents stubgen options. @@ -230,6 +180,7 @@ def __init__( self, pyversion: tuple[int, int], no_import: bool, + inspect: bool, doc_dir: str, search_path: list[str], interpreter: str, @@ -248,6 +199,7 @@ def __init__( # See parse_options for descriptions of the flags. self.pyversion = pyversion self.no_import = no_import + self.inspect = inspect self.doc_dir = doc_dir self.search_path = search_path self.interpreter = interpreter @@ -279,6 +231,9 @@ def __init__( self.runtime_all = runtime_all self.ast: MypyFile | None = None + def __repr__(self) -> str: + return f"StubSource({self.source})" + @property def module(self) -> str: return self.source.module @@ -303,71 +258,13 @@ def path(self) -> str | None: ERROR_MARKER: Final = "" -class AnnotationPrinter(TypeStrVisitor): - """Visitor used to print existing annotations in a file. - - The main difference from TypeStrVisitor is a better treatment of - unbound types. - - Notes: - * This visitor doesn't add imports necessary for annotations, this is done separately - by ImportTracker. - * It can print all kinds of types, but the generated strings may not be valid (notably - callable types) since it prints the same string that reveal_type() does. - * For Instance types it prints the fully qualified names. - """ - - # TODO: Generate valid string representation for callable types. - # TODO: Use short names for Instances. - def __init__(self, stubgen: StubGenerator) -> None: - super().__init__(options=mypy.options.Options()) - self.stubgen = stubgen - - def visit_any(self, t: AnyType) -> str: - s = super().visit_any(t) - self.stubgen.import_tracker.require_name(s) - return s - - def visit_unbound_type(self, t: UnboundType) -> str: - s = t.name - self.stubgen.import_tracker.require_name(s) - if t.args: - s += f"[{self.args_str(t.args)}]" - return s - - def visit_none_type(self, t: NoneType) -> str: - return "None" - - def visit_type_list(self, t: TypeList) -> str: - return f"[{self.list_str(t.items)}]" - - def visit_union_type(self, t: UnionType) -> str: - return " | ".join([item.accept(self) for item in t.items]) - - def args_str(self, args: Iterable[Type]) -> str: - """Convert an array of arguments to strings and join the results with commas. - - The main difference from list_str is the preservation of quotes for string - arguments - """ - types = ["builtins.bytes", "builtins.str"] - res = [] - for arg in args: - arg_str = arg.accept(self) - if isinstance(arg, UnboundType) and arg.original_str_fallback in types: - res.append(f"'{arg_str}'") - else: - res.append(arg_str) - return ", ".join(res) - - class AliasPrinter(NodeVisitor[str]): """Visitor used to collect type aliases _and_ type variable definitions. Visit r.h.s of the definition to get the string representation of type alias. """ - def __init__(self, stubgen: StubGenerator) -> None: + def __init__(self, stubgen: ASTStubGenerator) -> None: self.stubgen = stubgen super().__init__() @@ -435,124 +332,6 @@ def visit_op_expr(self, o: OpExpr) -> str: return f"{o.left.accept(self)} {o.op} {o.right.accept(self)}" -class ImportTracker: - """Record necessary imports during stub generation.""" - - def __init__(self) -> None: - # module_for['foo'] has the module name where 'foo' was imported from, or None if - # 'foo' is a module imported directly; examples - # 'from pkg.m import f as foo' ==> module_for['foo'] == 'pkg.m' - # 'from m import f' ==> module_for['f'] == 'm' - # 'import m' ==> module_for['m'] == None - # 'import pkg.m' ==> module_for['pkg.m'] == None - # ==> module_for['pkg'] == None - self.module_for: dict[str, str | None] = {} - - # direct_imports['foo'] is the module path used when the name 'foo' was added to the - # namespace. - # import foo.bar.baz ==> direct_imports['foo'] == 'foo.bar.baz' - # ==> direct_imports['foo.bar'] == 'foo.bar.baz' - # ==> direct_imports['foo.bar.baz'] == 'foo.bar.baz' - self.direct_imports: dict[str, str] = {} - - # reverse_alias['foo'] is the name that 'foo' had originally when imported with an - # alias; examples - # 'import numpy as np' ==> reverse_alias['np'] == 'numpy' - # 'import foo.bar as bar' ==> reverse_alias['bar'] == 'foo.bar' - # 'from decimal import Decimal as D' ==> reverse_alias['D'] == 'Decimal' - self.reverse_alias: dict[str, str] = {} - - # required_names is the set of names that are actually used in a type annotation - self.required_names: set[str] = set() - - # Names that should be reexported if they come from another module - self.reexports: set[str] = set() - - def add_import_from(self, module: str, names: list[tuple[str, str | None]]) -> None: - for name, alias in names: - if alias: - # 'from {module} import {name} as {alias}' - self.module_for[alias] = module - self.reverse_alias[alias] = name - else: - # 'from {module} import {name}' - self.module_for[name] = module - self.reverse_alias.pop(name, None) - self.direct_imports.pop(alias or name, None) - - def add_import(self, module: str, alias: str | None = None) -> None: - if alias: - # 'import {module} as {alias}' - self.module_for[alias] = None - self.reverse_alias[alias] = module - else: - # 'import {module}' - name = module - # add module and its parent packages - while name: - self.module_for[name] = None - self.direct_imports[name] = module - self.reverse_alias.pop(name, None) - name = name.rpartition(".")[0] - - def require_name(self, name: str) -> None: - while name not in self.direct_imports and "." in name: - name = name.rsplit(".", 1)[0] - self.required_names.add(name) - - def reexport(self, name: str) -> None: - """Mark a given non qualified name as needed in __all__. - - This means that in case it comes from a module, it should be - imported with an alias even is the alias is the same as the name. - """ - self.require_name(name) - self.reexports.add(name) - - def import_lines(self) -> list[str]: - """The list of required import lines (as strings with python code).""" - result = [] - - # To summarize multiple names imported from a same module, we collect those - # in the `module_map` dictionary, mapping a module path to the list of names that should - # be imported from it. the names can also be alias in the form 'original as alias' - module_map: Mapping[str, list[str]] = defaultdict(list) - - for name in sorted( - self.required_names, - key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""), - ): - # If we haven't seen this name in an import statement, ignore it - if name not in self.module_for: - continue - - m = self.module_for[name] - if m is not None: - # This name was found in a from ... import ... - # Collect the name in the module_map - if name in self.reverse_alias: - name = f"{self.reverse_alias[name]} as {name}" - elif name in self.reexports: - name = f"{name} as {name}" - module_map[m].append(name) - else: - # This name was found in an import ... - # We can already generate the import line - if name in self.reverse_alias: - source = self.reverse_alias[name] - result.append(f"import {source} as {name}\n") - elif name in self.reexports: - assert "." not in name # Because reexports only has nonqualified names - result.append(f"import {name} as {name}\n") - else: - result.append(f"import {name}\n") - - # Now generate all the from ... import ... lines collected in module_map - for module, names in sorted(module_map.items()): - result.append(f"from {module} import {', '.join(sorted(names))}\n") - return result - - def find_defined_names(file: MypyFile) -> set[str]: finder = DefinitionFinder() file.accept(finder) @@ -583,6 +362,10 @@ def find_referenced_names(file: MypyFile) -> set[str]: return finder.refs +def is_none_expr(expr: Expression) -> bool: + return isinstance(expr, NameExpr) and expr.name == "None" + + class ReferenceFinder(mypy.mixedtraverser.MixedTraverserVisitor): """Find all name references (both local and global).""" @@ -625,74 +408,37 @@ def add_ref(self, fullname: str) -> None: self.refs.add(fullname) -class StubGenerator(mypy.traverser.TraverserVisitor): +class ASTStubGenerator(BaseStubGenerator, mypy.traverser.TraverserVisitor): """Generate stub text from a mypy AST.""" def __init__( self, - _all_: list[str] | None, + _all_: list[str] | None = None, include_private: bool = False, analyzed: bool = False, export_less: bool = False, include_docstrings: bool = False, ) -> None: - # Best known value of __all__. - self._all_ = _all_ - self._output: list[str] = [] + super().__init__(_all_, include_private, export_less, include_docstrings) self._decorators: list[str] = [] - self._import_lines: list[str] = [] - # Current indent level (indent is hardcoded to 4 spaces). - self._indent = "" # Stack of defined variables (per scope). self._vars: list[list[str]] = [[]] # What was generated previously in the stub file. self._state = EMPTY - self._toplevel_names: list[str] = [] - self._include_private = include_private - self._include_docstrings = include_docstrings self._current_class: ClassDef | None = None - self.import_tracker = ImportTracker() # Was the tree semantically analysed before? self.analyzed = analyzed - # Disable implicit exports of package-internal imports? - self.export_less = export_less - # Add imports that could be implicitly generated - self.import_tracker.add_import_from("typing", [("NamedTuple", None)]) - # Names in __all__ are required - for name in _all_ or (): - if name not in IGNORED_DUNDERS: - self.import_tracker.reexport(name) - self.defined_names: set[str] = set() # Short names of methods defined in the body of the current class self.method_names: set[str] = set() self.processing_dataclass = False def visit_mypy_file(self, o: MypyFile) -> None: - self.module = o.fullname # Current module being processed + self.module_name = o.fullname # Current module being processed self.path = o.path - self.defined_names = find_defined_names(o) + self.set_defined_names(find_defined_names(o)) self.referenced_names = find_referenced_names(o) - known_imports = { - "_typeshed": ["Incomplete"], - "typing": ["Any", "TypeVar", "NamedTuple"], - "collections.abc": ["Generator"], - "typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"], - } - for pkg, imports in known_imports.items(): - for t in imports: - if t not in self.defined_names: - alias = None - else: - alias = "_" + t - self.import_tracker.add_import_from(pkg, [(t, alias)]) super().visit_mypy_file(o) - undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names] - if undefined_names: - if self._state != EMPTY: - self.add("\n") - self.add("# Names in __all__ with no definition:\n") - for name in sorted(undefined_names): - self.add(f"# {name}\n") + self.check_undefined_names() def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: """@property with setters and getters, @overload chain and some others.""" @@ -714,38 +460,14 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: # skip the overload implementation and clear the decorator we just processed self.clear_decorators() - def visit_func_def(self, o: FuncDef) -> None: - is_dataclass_generated = ( - self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated - ) - if is_dataclass_generated and o.name != "__init__": - # Skip methods generated by the @dataclass decorator (except for __init__) - return - if ( - self.is_private_name(o.name, o.fullname) - or self.is_not_in_all(o.name) - or (self.is_recorded_name(o.name) and not o.is_overload) - ): - self.clear_decorators() - return - if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine: - self.add("\n") - if not self.is_top_level(): - self_inits = find_self_initializers(o) - for init, value in self_inits: - if init in self.method_names: - # Can't have both an attribute and a method/property with the same name. - continue - init_code = self.get_init(init, value) - if init_code: - self.add(init_code) - # dump decorators, just before "def ..." - for s in self._decorators: - self.add(s) - self.clear_decorators() - self.add(f"{self._indent}{'async ' if o.is_coroutine else ''}def {o.name}(") - self.record_name(o.name) - args: list[str] = [] + def get_default_function_sig(self, func_def: FuncDef, ctx: FunctionContext) -> FunctionSig: + args = self._get_func_args(func_def, ctx) + retname = self._get_func_return(func_def, ctx) + return FunctionSig(func_def.name, args, retname) + + def _get_func_args(self, o: FuncDef, ctx: FunctionContext) -> list[ArgSig]: + args: list[ArgSig] = [] + for i, arg_ in enumerate(o.arguments): var = arg_.variable kind = arg_.kind @@ -759,87 +481,146 @@ def visit_func_def(self, o: FuncDef) -> None: # name their 0th argument other than self/cls is_self_arg = i == 0 and name == "self" is_cls_arg = i == 0 and name == "cls" - annotation = "" + typename: str | None = None if annotated_type and not is_self_arg and not is_cls_arg: # Luckily, an argument explicitly annotated with "Any" has # type "UnboundType" and will not match. if not isinstance(get_proper_type(annotated_type), AnyType): - annotation = f": {self.print_annotation(annotated_type)}" + typename = self.print_annotation(annotated_type) - if kind.is_named() and not any(arg.startswith("*") for arg in args): - args.append("*") + if kind.is_named() and not any(arg.name.startswith("*") for arg in args): + args.append(ArgSig("*")) if arg_.initializer: - if not annotation: + if not typename: typename = self.get_str_type_of_node(arg_.initializer, True, False) - if typename == "": - annotation = "=..." - else: - annotation = f": {typename} = ..." - else: - annotation += " = ..." - arg = name + annotation elif kind == ARG_STAR: - arg = f"*{name}{annotation}" + name = f"*{name}" elif kind == ARG_STAR2: - arg = f"**{name}{annotation}" - else: - arg = name + annotation - args.append(arg) - if o.name == "__init__" and is_dataclass_generated and "**" in args: - # The dataclass plugin generates invalid nameless "*" and "**" arguments - new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "") - args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique - args[args.index("**")] = f"**{new_name}__" # same here + name = f"**{name}" + + args.append(ArgSig(name, typename, default=bool(arg_.initializer))) + + if ctx.class_info is not None and all( + arg.type is None and arg.default is False for arg in args + ): + new_args = infer_method_arg_types( + ctx.name, ctx.class_info.self_var, [arg.name for arg in args] + ) + if new_args is not None: + args = new_args - retname = None + is_dataclass_generated = ( + self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated + ) + if o.name == "__init__" and is_dataclass_generated and "**" in [a.name for a in args]: + # The dataclass plugin generates invalid nameless "*" and "**" arguments + new_name = "".join(a.name.strip("*") for a in args) + for arg in args: + if arg.name == "*": + arg.name = f"*{new_name}_" # this name is guaranteed to be unique + elif arg.name == "**": + arg.name = f"**{new_name}__" # same here + return args + + def _get_func_return(self, o: FuncDef, ctx: FunctionContext) -> str | None: if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType): if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType): # Luckily, a return type explicitly annotated with "Any" has # type "UnboundType" and will enter the else branch. - retname = None # implicit Any + return None # implicit Any else: - retname = self.print_annotation(o.unanalyzed_type.ret_type) - elif o.abstract_status == IS_ABSTRACT or o.name in METHODS_WITH_RETURN_VALUE: + return self.print_annotation(o.unanalyzed_type.ret_type) + if o.abstract_status == IS_ABSTRACT or o.name in METHODS_WITH_RETURN_VALUE: # Always assume abstract methods return Any unless explicitly annotated. Also # some dunder methods should not have a None return type. - retname = None # implicit Any - elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES: - retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name] - elif has_yield_expression(o) or has_yield_from_expression(o): - generator_name = self.add_typing_import("Generator") + return None # implicit Any + retname = infer_method_ret_type(o.name) + if retname is not None: + return retname + if has_yield_expression(o) or has_yield_from_expression(o): + generator_name = self.add_name("collections.abc.Generator") yield_name = "None" send_name = "None" return_name = "None" if has_yield_from_expression(o): - yield_name = send_name = self.add_typing_import("Incomplete") + yield_name = send_name = self.add_name("_typeshed.Incomplete") else: for expr, in_assignment in all_yield_expressions(o): - if expr.expr is not None and not self.is_none_expr(expr.expr): - yield_name = self.add_typing_import("Incomplete") + if expr.expr is not None and not is_none_expr(expr.expr): + yield_name = self.add_name("_typeshed.Incomplete") if in_assignment: - send_name = self.add_typing_import("Incomplete") + send_name = self.add_name("_typeshed.Incomplete") if has_return_statement(o): - return_name = self.add_typing_import("Incomplete") - retname = f"{generator_name}[{yield_name}, {send_name}, {return_name}]" - elif not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: - retname = "None" - retfield = "" - if retname is not None: - retfield = " -> " + retname + return_name = self.add_name("_typeshed.Incomplete") + return f"{generator_name}[{yield_name}, {send_name}, {return_name}]" + if not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: + return "None" + return None + + def _get_func_docstring(self, node: FuncDef) -> str | None: + if not node.body.body: + return None + expr = node.body.body[0] + if isinstance(expr, ExpressionStmt) and isinstance(expr.expr, StrExpr): + return expr.expr.value + return None - self.add(", ".join(args)) - self.add(f"){retfield}:") - if self._include_docstrings and o.docstring: - docstring = mypy.util.quote_docstring(o.docstring) - self.add(f"\n{self._indent} {docstring}\n") + def visit_func_def(self, o: FuncDef) -> None: + is_dataclass_generated = ( + self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated + ) + if is_dataclass_generated and o.name != "__init__": + # Skip methods generated by the @dataclass decorator (except for __init__) + return + if ( + self.is_private_name(o.name, o.fullname) + or self.is_not_in_all(o.name) + or (self.is_recorded_name(o.name) and not o.is_overload) + ): + self.clear_decorators() + return + if self.is_top_level() and self._state not in (EMPTY, FUNC): + self.add("\n") + if not self.is_top_level(): + self_inits = find_self_initializers(o) + for init, value in self_inits: + if init in self.method_names: + # Can't have both an attribute and a method/property with the same name. + continue + init_code = self.get_init(init, value) + if init_code: + self.add(init_code) + + if self._current_class is not None: + if len(o.arguments): + self_var = o.arguments[0].variable.name + else: + self_var = "self" + class_info = ClassInfo(self._current_class.name, self_var) else: - self.add(" ...\n") + class_info = None + + ctx = FunctionContext( + module_name=self.module_name, + name=o.name, + docstring=self._get_func_docstring(o), + is_abstract=o.abstract_status != NOT_ABSTRACT, + class_info=class_info, + ) - self._state = FUNC + self.record_name(o.name) - def is_none_expr(self, expr: Expression) -> bool: - return isinstance(expr, NameExpr) and expr.name == "None" + default_sig = self.get_default_function_sig(o, ctx) + sigs = self.get_signatures(default_sig, self.sig_generators, ctx) + + for output in self.format_func_def( + sigs, is_coroutine=o.is_coroutine, decorators=self._decorators, docstring=ctx.docstring + ): + self.add(output + "\n") + + self.clear_decorators() + self._state = FUNC def visit_decorator(self, o: Decorator) -> None: if self.is_private_name(o.func.name, o.func.fullname): @@ -917,13 +698,12 @@ def visit_class_def(self, o: ClassDef) -> None: self._current_class = o self.method_names = find_method_names(o.defs.body) sep: int | None = None - if not self._indent and self._state != EMPTY: + if self.is_top_level() and self._state != EMPTY: sep = len(self._output) self.add("\n") decorators = self.get_class_decorators(o) for d in decorators: self.add(f"{self._indent}@{d}\n") - self.add(f"{self._indent}class {o.name}") self.record_name(o.name) base_types = self.get_base_types(o) if base_types: @@ -936,17 +716,16 @@ def visit_class_def(self, o: ClassDef) -> None: base_types.append("metaclass=abc.ABCMeta") self.import_tracker.add_import("abc") self.import_tracker.require_name("abc") - if base_types: - self.add(f"({', '.join(base_types)})") - self.add(":\n") - self._indent += " " + bases = f"({', '.join(base_types)})" if base_types else "" + self.add(f"{self._indent}class {o.name}{bases}:\n") + self.indent() if self._include_docstrings and o.docstring: docstring = mypy.util.quote_docstring(o.docstring) self.add(f"{self._indent}{docstring}\n") n = len(self._output) self._vars.append([]) super().visit_class_def(o) - self._indent = self._indent[:-4] + self.dedent() self._vars.pop() self._vars[-1].append(o.name) if len(self._output) == n: @@ -987,17 +766,17 @@ def get_base_types(self, cdef: ClassDef) -> list[str]: typename = base.args[0].value if nt_fields is None: # Invalid namedtuple() call, cannot determine fields - base_types.append(self.add_typing_import("Incomplete")) + base_types.append(self.add_name("_typeshed.Incomplete")) continue fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields) - namedtuple_name = self.add_typing_import("NamedTuple") + namedtuple_name = self.add_name("typing.NamedTuple") base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])") elif self.is_typed_namedtuple(base): base_types.append(base.accept(p)) else: # At this point, we don't know what the base class is, so we # just use Incomplete as the base class. - base_types.append(self.add_typing_import("Incomplete")) + base_types.append(self.add_name("_typeshed.Incomplete")) for name, value in cdef.keywords.items(): if name == "metaclass": continue # handled separately @@ -1063,7 +842,7 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: init = self.get_init(item.name, o.rvalue, annotation) if init: found = True - if not sep and not self._indent and self._state not in (EMPTY, VAR): + if not sep and self.is_top_level() and self._state not in (EMPTY, VAR): init = "\n" + init sep = True self.add(init) @@ -1092,10 +871,12 @@ def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None field_names.append(field.value) else: return None # Invalid namedtuple fields type - if not field_names: + if field_names: + incomplete = self.add_name("_typeshed.Incomplete") + return [(field_name, incomplete) for field_name in field_names] + else: return [] - incomplete = self.add_typing_import("Incomplete") - return [(field_name, incomplete) for field_name in field_names] + elif self.is_typed_namedtuple(call): fields_arg = call.args[1] if not isinstance(fields_arg, (ListExpr, TupleExpr)): @@ -1125,7 +906,7 @@ def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: if fields is None: self.annotate_as_incomplete(lvalue) return - bases = self.add_typing_import("NamedTuple") + bases = self.add_name("typing.NamedTuple") # TODO: Add support for generic NamedTuples. Requires `Generic` as base class. class_def = f"{self._indent}class {lvalue.name}({bases}):" if len(fields) == 0: @@ -1175,13 +956,13 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None: total = arg else: items.append((arg_name, arg)) - bases = self.add_typing_import("TypedDict") p = AliasPrinter(self) if any(not key.isidentifier() or keyword.iskeyword(key) for key, _ in items): # Keep the call syntax if there are non-identifier or reserved keyword keys. self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n") self._state = VAR else: + bases = self.add_name("typing_extensions.TypedDict") # TODO: Add support for generic TypedDicts. Requires `Generic` as base class. if total is not None: bases += f", total={total.accept(p)}" @@ -1198,7 +979,8 @@ def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None: self._state = CLASS def annotate_as_incomplete(self, lvalue: NameExpr) -> None: - self.add(f"{self._indent}{lvalue.name}: {self.add_typing_import('Incomplete')}\n") + incomplete = self.add_name("_typeshed.Incomplete") + self.add(f"{self._indent}{lvalue.name}: {incomplete}\n") self._state = VAR def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: @@ -1280,9 +1062,9 @@ def visit_import_from(self, o: ImportFrom) -> None: exported_names: set[str] = set() import_names = [] module, relative = translate_module_name(o.id, o.relative) - if self.module: + if self.module_name: full_module, ok = mypy.util.correct_relative_import( - self.module, relative, module, self.path.endswith(".__init__.py") + self.module_name, relative, module, self.path.endswith(".__init__.py") ) if not ok: full_module = module @@ -1295,37 +1077,7 @@ def visit_import_from(self, o: ImportFrom) -> None: # Vendored six -- translate into plain 'import six'. self.visit_import(Import([("six", None)])) continue - exported = False - if as_name is None and self.module and (self.module + "." + name) in EXTRA_EXPORTED: - # Special case certain names that should be exported, against our general rules. - exported = True - is_private = self.is_private_name(name, full_module + "." + name) - if ( - as_name is None - and name not in self.referenced_names - and not any(n.startswith(name + ".") for n in self.referenced_names) - and (not self._all_ or name in IGNORED_DUNDERS) - and not is_private - and module not in ("abc", "asyncio") + TYPING_MODULE_NAMES - ): - # An imported name that is never referenced in the module is assumed to be - # exported, unless there is an explicit __all__. Note that we need to special - # case 'abc' since some references are deleted during semantic analysis. - exported = True - top_level = full_module.split(".", 1)[0] - self_top_level = self.module.split(".", 1)[0] - if ( - as_name is None - and not self.export_less - and (not self._all_ or name in IGNORED_DUNDERS) - and self.module - and not is_private - and top_level in (self_top_level, "_" + self_top_level) - ): - # Export imports from the same package, since we can't reliably tell whether they - # are part of the public API. - exported = True - if exported: + if self.should_reexport(name, full_module, as_name is not None): self.import_tracker.reexport(name) as_name = name import_names.append((name, as_name)) @@ -1339,7 +1091,7 @@ def visit_import_from(self, o: ImportFrom) -> None: names = [ name for name, alias in o.names - if name in self._all_ and alias is None and name not in IGNORED_DUNDERS + if name in self._all_ and alias is None and name not in self.IGNORED_DUNDERS ] exported_names.update(names) @@ -1373,7 +1125,7 @@ def get_init( isinstance(annotation, UnboundType) and not annotation.args and annotation.name == "Final" - and self.import_tracker.module_for.get("Final") in TYPING_MODULE_NAMES + and self.import_tracker.module_for.get("Final") in self.TYPING_MODULE_NAMES ): # Final without type argument is invalid in stubs. final_arg = self.get_str_type_of_node(rvalue) @@ -1406,67 +1158,14 @@ def get_assign_initializer(self, rvalue: Expression) -> str: # By default, no initializer is required: return "" - def add(self, string: str) -> None: - """Add text to generated stub.""" - self._output.append(string) - def add_decorator(self, name: str, require_name: bool = False) -> None: if require_name: self.import_tracker.require_name(name) - if not self._indent and self._state not in (EMPTY, FUNC): - self._decorators.append("\n") - self._decorators.append(f"{self._indent}@{name}\n") + self._decorators.append(f"@{name}") def clear_decorators(self) -> None: self._decorators.clear() - def typing_name(self, name: str) -> str: - if name in self.defined_names: - # Avoid name clash between name from typing and a name defined in stub. - return "_" + name - else: - return name - - def add_typing_import(self, name: str) -> str: - """Add a name to be imported for typing, unless it's imported already. - - The import will be internal to the stub. - """ - name = self.typing_name(name) - self.import_tracker.require_name(name) - return name - - def add_import_line(self, line: str) -> None: - """Add a line of text to the import section, unless it's already there.""" - if line not in self._import_lines: - self._import_lines.append(line) - - def output(self) -> str: - """Return the text for the stub.""" - imports = "" - if self._import_lines: - imports += "".join(self._import_lines) - imports += "".join(self.import_tracker.import_lines()) - if imports and self._output: - imports += "\n" - return imports + "".join(self._output) - - def is_not_in_all(self, name: str) -> bool: - if self.is_private_name(name): - return False - if self._all_: - return self.is_top_level() and name not in self._all_ - return False - - def is_private_name(self, name: str, fullname: str | None = None) -> bool: - if self._include_private: - return False - if fullname in EXTRA_EXPORTED: - return False - if name == "_": - return False - return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS) - def is_private_member(self, fullname: str) -> bool: parts = fullname.split(".") return any(self.is_private_name(part) for part in parts) @@ -1494,9 +1193,9 @@ def get_str_type_of_node( if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"): return "bool" if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None": - return f"{self.add_typing_import('Incomplete')} | None" + return f"{self.add_name('_typeshed.Incomplete')} | None" if can_be_any: - return self.add_typing_import("Incomplete") + return self.add_name("_typeshed.Incomplete") else: return "" @@ -1534,25 +1233,20 @@ def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression: # This is some other unary expr, we cannot do anything with it (yet?). return expr - def print_annotation(self, t: Type) -> str: - printer = AnnotationPrinter(self) - return t.accept(printer) - - def is_top_level(self) -> bool: - """Are we processing the top level of a file?""" - return self._indent == "" - - def record_name(self, name: str) -> None: - """Mark a name as defined. - - This only does anything if at the top level of a module. - """ - if self.is_top_level(): - self._toplevel_names.append(name) - - def is_recorded_name(self, name: str) -> bool: - """Has this name been recorded previously?""" - return self.is_top_level() and name in self._toplevel_names + def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool: + is_private = self.is_private_name(name, full_module + "." + name) + if ( + not name_is_alias + and name not in self.referenced_names + and (not self._all_ or name in self.IGNORED_DUNDERS) + and not is_private + and full_module not in ("abc", "asyncio") + self.TYPING_MODULE_NAMES + ): + # An imported name that is never referenced in the module is assumed to be + # exported, unless there is an explicit __all__. Note that we need to special + # case 'abc' since some references are deleted during semantic analysis. + return True + return super().should_reexport(name, full_module, name_is_alias) def find_method_names(defs: list[Statement]) -> set[str]: @@ -1608,6 +1302,17 @@ def remove_blacklisted_modules(modules: list[StubSource]) -> list[StubSource]: ] +def split_pyc_from_py(modules: list[StubSource]) -> tuple[list[StubSource], list[StubSource]]: + py_modules = [] + pyc_modules = [] + for mod in modules: + if is_pyc_only(mod.path): + pyc_modules.append(mod) + else: + py_modules.append(mod) + return pyc_modules, py_modules + + def is_blacklisted_path(path: str) -> bool: return any(substr in (normalize_path_separators(path) + "\n") for substr in BLACKLIST) @@ -1620,10 +1325,10 @@ def normalize_path_separators(path: str) -> str: def collect_build_targets( options: Options, mypy_opts: MypyOptions -) -> tuple[list[StubSource], list[StubSource]]: +) -> tuple[list[StubSource], list[StubSource], list[StubSource]]: """Collect files for which we need to generate stubs. - Return list of Python modules and C modules. + Return list of py modules, pyc modules, and C modules. """ if options.packages or options.modules: if options.no_import: @@ -1646,8 +1351,8 @@ def collect_build_targets( c_modules = [] py_modules = remove_blacklisted_modules(py_modules) - - return py_modules, c_modules + pyc_mod, py_mod = split_pyc_from_py(py_modules) + return py_mod, pyc_mod, c_modules def find_module_paths_using_imports( @@ -1826,98 +1531,90 @@ def generate_asts_for_modules( mod.runtime_all = res.manager.semantic_analyzer.export_map[mod.module] -def generate_stub_from_ast( +def generate_stub_for_py_module( mod: StubSource, target: str, + *, parse_only: bool = False, + inspect: bool = False, include_private: bool = False, export_less: bool = False, include_docstrings: bool = False, + doc_dir: str = "", + all_modules: list[str], ) -> None: """Use analysed (or just parsed) AST to generate type stub for single file. If directory for target doesn't exist it will created. Existing stub will be overwritten. """ - gen = StubGenerator( - mod.runtime_all, - include_private=include_private, - analyzed=not parse_only, - export_less=export_less, - include_docstrings=include_docstrings, - ) - assert mod.ast is not None, "This function must be used only with analyzed modules" - mod.ast.accept(gen) + if inspect: + ngen = InspectionStubGenerator( + module_name=mod.module, + known_modules=all_modules, + _all_=mod.runtime_all, + doc_dir=doc_dir, + include_private=include_private, + export_less=export_less, + include_docstrings=include_docstrings, + ) + ngen.generate_module() + output = ngen.output() + + else: + gen = ASTStubGenerator( + mod.runtime_all, + include_private=include_private, + analyzed=not parse_only, + export_less=export_less, + include_docstrings=include_docstrings, + ) + assert mod.ast is not None, "This function must be used only with analyzed modules" + mod.ast.accept(gen) + output = gen.output() # Write output to file. subdir = os.path.dirname(target) if subdir and not os.path.isdir(subdir): os.makedirs(subdir) with open(target, "w") as file: - file.write("".join(gen.output())) - - -def get_sig_generators(options: Options) -> list[SignatureGenerator]: - sig_generators: list[SignatureGenerator] = [ - DocstringSignatureGenerator(), - FallbackSignatureGenerator(), - ] - if options.doc_dir: - # Collect info from docs (if given). Always check these first. - sigs, class_sigs = collect_docs_signatures(options.doc_dir) - sig_generators.insert(0, ExternalSignatureGenerator(sigs, class_sigs)) - return sig_generators - - -def collect_docs_signatures(doc_dir: str) -> tuple[dict[str, str], dict[str, str]]: - """Gather all function and class signatures in the docs. - - Return a tuple (function signatures, class signatures). - Currently only used for C modules. - """ - all_sigs: list[Sig] = [] - all_class_sigs: list[Sig] = [] - for path in glob.glob(f"{doc_dir}/*.rst"): - with open(path) as f: - loc_sigs, loc_class_sigs = parse_all_signatures(f.readlines()) - all_sigs += loc_sigs - all_class_sigs += loc_class_sigs - sigs = dict(find_unique_signatures(all_sigs)) - class_sigs = dict(find_unique_signatures(all_class_sigs)) - return sigs, class_sigs + file.write(output) def generate_stubs(options: Options) -> None: """Main entry point for the program.""" mypy_opts = mypy_options(options) - py_modules, c_modules = collect_build_targets(options, mypy_opts) - sig_generators = get_sig_generators(options) + py_modules, pyc_modules, c_modules = collect_build_targets(options, mypy_opts) + all_modules = py_modules + pyc_modules + c_modules + all_module_names = sorted(m.module for m in all_modules) # Use parsed sources to generate stubs for Python modules. generate_asts_for_modules(py_modules, options.parse_only, mypy_opts, options.verbose) files = [] - for mod in py_modules: + for mod in py_modules + pyc_modules: assert mod.path is not None, "Not found module was not skipped" target = mod.module.replace(".", "/") - if os.path.basename(mod.path) == "__init__.py": + if os.path.basename(mod.path) in ["__init__.py", "__init__.pyc"]: target += "/__init__.pyi" else: target += ".pyi" target = os.path.join(options.output_dir, target) files.append(target) with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): - generate_stub_from_ast( + generate_stub_for_py_module( mod, target, - options.parse_only, - options.include_private, - options.export_less, + parse_only=options.parse_only, + inspect=options.inspect or mod in pyc_modules, + include_private=options.include_private, + export_less=options.export_less, include_docstrings=options.include_docstrings, + doc_dir=options.doc_dir, + all_modules=all_module_names, ) # Separately analyse C modules using different logic. - all_modules = sorted(m.module for m in (py_modules + c_modules)) for mod in c_modules: - if any(py_mod.module.startswith(mod.module + ".") for py_mod in py_modules + c_modules): + if any(py_mod.module.startswith(mod.module + ".") for py_mod in all_modules): target = mod.module.replace(".", "/") + "/__init__.pyi" else: target = mod.module.replace(".", "/") + ".pyi" @@ -1927,11 +1624,12 @@ def generate_stubs(options: Options) -> None: generate_stub_for_c_module( mod.module, target, - known_modules=all_modules, - sig_generators=sig_generators, - include_docstrings=options.include_docstrings, + known_modules=all_module_names, + doc_dir=options.doc_dir, + include_private=options.include_private, + export_less=options.export_less, ) - num_modules = len(py_modules) + len(c_modules) + num_modules = len(all_modules) if not options.quiet and num_modules > 0: print("Processed %d modules" % num_modules) if len(files) == 1: @@ -1967,10 +1665,21 @@ def parse_options(args: list[str]) -> Options: "respect __all__)", ) parser.add_argument( + "--no-analysis", "--parse-only", + dest="parse_only", action="store_true", help="don't perform semantic analysis of sources, just parse them " - "(only applies to Python modules, might affect quality of stubs)", + "(only applies to Python modules, might affect quality of stubs. " + "Not compatible with --inspect)", + ) + parser.add_argument( + "--inspect-mode", + dest="inspect", + action="store_true", + help="import and inspect modules instead of parsing source code." + "This is the default behavior for c modules and pyc-only packages, but " + "it is also useful for pure python modules with dynamically generated members.", ) parser.add_argument( "--include-private", @@ -2047,6 +1756,8 @@ def parse_options(args: list[str]) -> Options: parser.error("May only specify one of: modules/packages or files.") if ns.quiet and ns.verbose: parser.error("Cannot specify both quiet and verbose messages") + if ns.inspect and ns.parse_only: + parser.error("Cannot specify both --parse-only/--no-analysis and --inspect-mode") # Create the output folder if it doesn't already exist. if not os.path.exists(ns.output_dir): @@ -2055,6 +1766,7 @@ def parse_options(args: list[str]) -> Options: return Options( pyversion=pyversion, no_import=ns.no_import, + inspect=ns.inspect, doc_dir=ns.doc_dir, search_path=ns.search_path.split(":"), interpreter=ns.interpreter, diff --git a/mypy/stubgenc.py b/mypy/stubgenc.py index 31487f9d0dcf..0ad79a4265b3 100755 --- a/mypy/stubgenc.py +++ b/mypy/stubgenc.py @@ -6,68 +6,38 @@ from __future__ import annotations +import glob import importlib import inspect +import keyword import os.path -import re -from abc import abstractmethod -from types import ModuleType -from typing import Any, Final, Iterable, Mapping +from types import FunctionType, ModuleType +from typing import Any, Mapping -import mypy.util +from mypy.fastparse import parse_type_comment from mypy.moduleinspect import is_c_module from mypy.stubdoc import ( ArgSig, FunctionSig, + Sig, + find_unique_signatures, infer_arg_sig_from_anon_docstring, infer_prop_type_from_docstring, infer_ret_type_sig_from_anon_docstring, infer_ret_type_sig_from_docstring, infer_sig_from_docstring, + parse_all_signatures, ) - -# Members of the typing module to consider for importing by default. -_DEFAULT_TYPING_IMPORTS: Final = ( - "Any", - "Callable", - "ClassVar", - "Dict", - "Iterable", - "Iterator", - "List", - "Optional", - "Tuple", - "Union", +from mypy.stubutil import ( + BaseStubGenerator, + ClassInfo, + FunctionContext, + SignatureGenerator, + infer_method_arg_types, + infer_method_ret_type, ) -class SignatureGenerator: - """Abstract base class for extracting a list of FunctionSigs for each function.""" - - def remove_self_type( - self, inferred: list[FunctionSig] | None, self_var: str - ) -> list[FunctionSig] | None: - """Remove type annotation from self/cls argument""" - if inferred: - for signature in inferred: - if signature.args: - if signature.args[0].name == self_var: - signature.args[0].type = None - return inferred - - @abstractmethod - def get_function_sig( - self, func: object, module_name: str, name: str - ) -> list[FunctionSig] | None: - pass - - @abstractmethod - def get_method_sig( - self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str - ) -> list[FunctionSig] | None: - pass - - class ExternalSignatureGenerator(SignatureGenerator): def __init__( self, func_sigs: dict[str, str] | None = None, class_sigs: dict[str, str] | None = None @@ -79,97 +49,104 @@ class signatures (usually corresponds to __init__). self.func_sigs = func_sigs or {} self.class_sigs = class_sigs or {} - def get_function_sig( - self, func: object, module_name: str, name: str - ) -> list[FunctionSig] | None: - if name in self.func_sigs: - return [ - FunctionSig( - name=name, - args=infer_arg_sig_from_anon_docstring(self.func_sigs[name]), - ret_type="Any", - ) - ] - else: - return None + @classmethod + def from_doc_dir(cls, doc_dir: str) -> ExternalSignatureGenerator: + """Instantiate from a directory of .rst files.""" + all_sigs: list[Sig] = [] + all_class_sigs: list[Sig] = [] + for path in glob.glob(f"{doc_dir}/*.rst"): + with open(path) as f: + loc_sigs, loc_class_sigs = parse_all_signatures(f.readlines()) + all_sigs += loc_sigs + all_class_sigs += loc_class_sigs + sigs = dict(find_unique_signatures(all_sigs)) + class_sigs = dict(find_unique_signatures(all_class_sigs)) + return ExternalSignatureGenerator(sigs, class_sigs) - def get_method_sig( - self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str + def get_function_sig( + self, default_sig: FunctionSig, ctx: FunctionContext ) -> list[FunctionSig] | None: + # method: if ( - name in ("__new__", "__init__") - and name not in self.func_sigs - and class_name in self.class_sigs + ctx.class_info + and ctx.name in ("__new__", "__init__") + and ctx.name not in self.func_sigs + and ctx.class_info.name in self.class_sigs ): return [ FunctionSig( - name=name, - args=infer_arg_sig_from_anon_docstring(self.class_sigs[class_name]), - ret_type=infer_method_ret_type(name), + name=ctx.name, + args=infer_arg_sig_from_anon_docstring(self.class_sigs[ctx.class_info.name]), + ret_type=infer_method_ret_type(ctx.name), ) ] - inferred = self.get_function_sig(func, module_name, name) - return self.remove_self_type(inferred, self_var) + + # function: + if ctx.name not in self.func_sigs: + return None + + inferred = [ + FunctionSig( + name=ctx.name, + args=infer_arg_sig_from_anon_docstring(self.func_sigs[ctx.name]), + ret_type=None, + ) + ] + if ctx.class_info: + return self.remove_self_type(inferred, ctx.class_info.self_var) + else: + return inferred + + def get_property_type(self, default_type: str | None, ctx: FunctionContext) -> str | None: + return None class DocstringSignatureGenerator(SignatureGenerator): def get_function_sig( - self, func: object, module_name: str, name: str + self, default_sig: FunctionSig, ctx: FunctionContext ) -> list[FunctionSig] | None: - docstr = getattr(func, "__doc__", None) - inferred = infer_sig_from_docstring(docstr, name) + inferred = infer_sig_from_docstring(ctx.docstring, ctx.name) if inferred: - assert docstr is not None - if is_pybind11_overloaded_function_docstring(docstr, name): + assert ctx.docstring is not None + if is_pybind11_overloaded_function_docstring(ctx.docstring, ctx.name): # Remove pybind11 umbrella (*args, **kwargs) for overloaded functions del inferred[-1] - return inferred - def get_method_sig( - self, - cls: type, - func: object, - module_name: str, - class_name: str, - func_name: str, - self_var: str, - ) -> list[FunctionSig] | None: - inferred = self.get_function_sig(func, module_name, func_name) - if not inferred and func_name == "__init__": - # look for class-level constructor signatures of the form () - inferred = self.get_function_sig(cls, module_name, class_name) - return self.remove_self_type(inferred, self_var) + if ctx.class_info: + if not inferred and ctx.name == "__init__": + # look for class-level constructor signatures of the form () + inferred = infer_sig_from_docstring(ctx.class_info.docstring, ctx.class_info.name) + if inferred: + inferred = [sig._replace(name="__init__") for sig in inferred] + return self.remove_self_type(inferred, ctx.class_info.self_var) + else: + return inferred + def get_property_type(self, default_type: str | None, ctx: FunctionContext) -> str | None: + """Infer property type from docstring or docstring signature.""" + if ctx.docstring is not None: + inferred = infer_ret_type_sig_from_anon_docstring(ctx.docstring) + if not inferred: + inferred = infer_ret_type_sig_from_docstring(ctx.docstring, ctx.name) + if not inferred: + inferred = infer_prop_type_from_docstring(ctx.docstring) + return inferred + else: + return None -class FallbackSignatureGenerator(SignatureGenerator): - def get_function_sig( - self, func: object, module_name: str, name: str - ) -> list[FunctionSig] | None: - return [ - FunctionSig( - name=name, - args=infer_arg_sig_from_anon_docstring("(*args, **kwargs)"), - ret_type="Any", - ) - ] - def get_method_sig( - self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str - ) -> list[FunctionSig] | None: - return [ - FunctionSig( - name=name, - args=infer_method_args(name, self_var), - ret_type=infer_method_ret_type(name), - ) - ] +def is_pybind11_overloaded_function_docstring(docstring: str, name: str) -> bool: + return docstring.startswith(f"{name}(*args, **kwargs)\nOverloaded function.\n\n") def generate_stub_for_c_module( module_name: str, target: str, known_modules: list[str], - sig_generators: Iterable[SignatureGenerator], + doc_dir: str = "", + *, + include_private: bool = False, + export_less: bool = False, include_docstrings: bool = False, ) -> None: """Generate stub for C module. @@ -184,452 +161,664 @@ def generate_stub_for_c_module( If directory for target doesn't exist it will be created. Existing stub will be overwritten. """ - module = importlib.import_module(module_name) - assert is_c_module(module), f"{module_name} is not a C module" subdir = os.path.dirname(target) if subdir and not os.path.isdir(subdir): os.makedirs(subdir) - imports: list[str] = [] - functions: list[str] = [] - done = set() - items = sorted(get_members(module), key=lambda x: x[0]) - for name, obj in items: - if is_c_function(obj): - generate_c_function_stub( - module, - name, - obj, - output=functions, - known_modules=known_modules, - imports=imports, - sig_generators=sig_generators, - include_docstrings=include_docstrings, - ) - done.add(name) - types: list[str] = [] - for name, obj in items: - if name.startswith("__") and name.endswith("__"): - continue - if is_c_type(obj): - generate_c_type_stub( - module, - name, - obj, - output=types, - known_modules=known_modules, - imports=imports, - sig_generators=sig_generators, - include_docstrings=include_docstrings, - ) - done.add(name) - variables = [] - for name, obj in items: - if name.startswith("__") and name.endswith("__"): - continue - if name not in done and not inspect.ismodule(obj): - type_str = strip_or_import( - get_type_fullname(type(obj)), module, known_modules, imports - ) - variables.append(f"{name}: {type_str}") - output = sorted(set(imports)) - for line in variables: - output.append(line) - for line in types: - if line.startswith("class") and output and output[-1]: - output.append("") - output.append(line) - if output and functions: - output.append("") - for line in functions: - output.append(line) - output = add_typing_import(output) + + gen = InspectionStubGenerator( + module_name, + known_modules, + doc_dir, + include_private=include_private, + export_less=export_less, + include_docstrings=include_docstrings, + ) + gen.generate_module() + output = gen.output() + with open(target, "w") as file: - for line in output: - file.write(f"{line}\n") - - -def add_typing_import(output: list[str]) -> list[str]: - """Add typing imports for collections/types that occur in the generated stub.""" - names = [] - for name in _DEFAULT_TYPING_IMPORTS: - if any(re.search(r"\b%s\b" % name, line) for line in output): - names.append(name) - if names: - return [f"from typing import {', '.join(names)}", ""] + output - else: - return output.copy() - - -def get_members(obj: object) -> list[tuple[str, Any]]: - obj_dict: Mapping[str, Any] = getattr(obj, "__dict__") # noqa: B009 - results = [] - for name in obj_dict: - if is_skipped_attribute(name): - continue - # Try to get the value via getattr - try: - value = getattr(obj, name) - except AttributeError: - continue - else: - results.append((name, value)) - return results + file.write(output) -def is_c_function(obj: object) -> bool: - return inspect.isbuiltin(obj) or type(obj) is type(ord) +class CFunctionStub: + """ + Class that mimics a C function in order to provide parseable docstrings. + """ + def __init__(self, name: str, doc: str, is_abstract: bool = False): + self.__name__ = name + self.__doc__ = doc + self.__abstractmethod__ = is_abstract -def is_c_method(obj: object) -> bool: - return inspect.ismethoddescriptor(obj) or type(obj) in ( - type(str.index), - type(str.__add__), - type(str.__new__), - ) + @classmethod + def _from_sig(cls, sig: FunctionSig, is_abstract: bool = False) -> CFunctionStub: + return CFunctionStub(sig.name, sig.format_sig()[:-4], is_abstract) + @classmethod + def _from_sigs(cls, sigs: list[FunctionSig], is_abstract: bool = False) -> CFunctionStub: + return CFunctionStub( + sigs[0].name, "\n".join(sig.format_sig()[:-4] for sig in sigs), is_abstract + ) -def is_c_classmethod(obj: object) -> bool: - return inspect.isbuiltin(obj) or type(obj).__name__ in ( - "classmethod", - "classmethod_descriptor", - ) + def __get__(self) -> None: + """ + This exists to make this object look like a method descriptor and thus + return true for CStubGenerator.ismethod() + """ + pass -def is_c_property(obj: object) -> bool: - return inspect.isdatadescriptor(obj) or hasattr(obj, "fget") +class InspectionStubGenerator(BaseStubGenerator): + """Stub generator that does not parse code. + Generation is performed by inspecting the module's contents, and thus works + for highly dynamic modules, pyc files, and C modules (via the CStubGenerator + subclass). + """ -def is_c_property_readonly(prop: Any) -> bool: - return hasattr(prop, "fset") and prop.fset is None + def __init__( + self, + module_name: str, + known_modules: list[str], + doc_dir: str = "", + _all_: list[str] | None = None, + include_private: bool = False, + export_less: bool = False, + include_docstrings: bool = False, + module: ModuleType | None = None, + ) -> None: + self.doc_dir = doc_dir + if module is None: + self.module = importlib.import_module(module_name) + else: + self.module = module + self.is_c_module = is_c_module(self.module) + self.known_modules = known_modules + self.resort_members = self.is_c_module + super().__init__(_all_, include_private, export_less, include_docstrings) + self.module_name = module_name + + def get_default_function_sig(self, func: object, ctx: FunctionContext) -> FunctionSig: + argspec = None + if not self.is_c_module: + # Get the full argument specification of the function + try: + argspec = inspect.getfullargspec(func) + except TypeError: + # some callables cannot be inspected, e.g. functools.partial + pass + if argspec is None: + if ctx.class_info is not None: + # method: + return FunctionSig( + name=ctx.name, + args=infer_c_method_args(ctx.name, ctx.class_info.self_var), + ret_type=infer_method_ret_type(ctx.name), + ) + else: + # function: + return FunctionSig( + name=ctx.name, + args=[ArgSig(name="*args"), ArgSig(name="**kwargs")], + ret_type=None, + ) + # Extract the function arguments, defaults, and varargs + args = argspec.args + defaults = argspec.defaults + varargs = argspec.varargs + kwargs = argspec.varkw + annotations = argspec.annotations + + def get_annotation(key: str) -> str | None: + if key not in annotations: + return None + argtype = annotations[key] + if argtype is None: + return "None" + if not isinstance(argtype, str): + return self.get_type_fullname(argtype) + return argtype + + arglist: list[ArgSig] = [] + # Add the arguments to the signature + for i, arg in enumerate(args): + # Check if the argument has a default value + if defaults and i >= len(args) - len(defaults): + default_value = defaults[i - (len(args) - len(defaults))] + if arg in annotations: + argtype = annotations[arg] + else: + argtype = self.get_type_annotation(default_value) + if argtype == "None": + # None is not a useful annotation, but we can infer that the arg + # is optional + incomplete = self.add_name("_typeshed.Incomplete") + argtype = f"{incomplete} | None" + arglist.append(ArgSig(arg, argtype, default=True)) + else: + arglist.append(ArgSig(arg, get_annotation(arg), default=False)) -def is_c_type(obj: object) -> bool: - return inspect.isclass(obj) or type(obj) is type(int) + # Add *args if present + if varargs: + arglist.append(ArgSig(f"*{varargs}", get_annotation(varargs))) + # Add **kwargs if present + if kwargs: + arglist.append(ArgSig(f"**{kwargs}", get_annotation(kwargs))) -def is_pybind11_overloaded_function_docstring(docstr: str, name: str) -> bool: - return docstr.startswith(f"{name}(*args, **kwargs)\n" + "Overloaded function.\n\n") + # add types for known special methods + if ctx.class_info is not None and all( + arg.type is None and arg.default is False for arg in arglist + ): + new_args = infer_method_arg_types( + ctx.name, ctx.class_info.self_var, [arg.name for arg in arglist if arg.name] + ) + if new_args is not None: + arglist = new_args + ret_type = get_annotation("return") or infer_method_ret_type(ctx.name) + return FunctionSig(ctx.name, arglist, ret_type) -def generate_c_function_stub( - module: ModuleType, - name: str, - obj: object, - *, - known_modules: list[str], - sig_generators: Iterable[SignatureGenerator], - output: list[str], - imports: list[str], - self_var: str | None = None, - cls: type | None = None, - class_name: str | None = None, - include_docstrings: bool = False, -) -> None: - """Generate stub for a single function or method. + def get_sig_generators(self) -> list[SignatureGenerator]: + if not self.is_c_module: + return [] + else: + sig_generators: list[SignatureGenerator] = [DocstringSignatureGenerator()] + if self.doc_dir: + # Collect info from docs (if given). Always check these first. + sig_generators.insert(0, ExternalSignatureGenerator.from_doc_dir(self.doc_dir)) + return sig_generators - The result will be appended to 'output'. - If necessary, any required names will be added to 'imports'. - The 'class_name' is used to find signature of __init__ or __new__ in - 'class_sigs'. - """ - inferred: list[FunctionSig] | None = None - docstr: str | None = None - if class_name: - # method: - assert cls is not None, "cls should be provided for methods" - assert self_var is not None, "self_var should be provided for methods" - for sig_gen in sig_generators: - inferred = sig_gen.get_method_sig( - cls, obj, module.__name__, class_name, name, self_var + def strip_or_import(self, type_name: str) -> str: + """Strips unnecessary module names from typ. + + If typ represents a type that is inside module or is a type coming from builtins, remove + module declaration from it. Return stripped name of the type. + + Arguments: + typ: name of the type + """ + local_modules = ["builtins", self.module_name] + parsed_type = parse_type_comment(type_name, 0, 0, None)[1] + assert parsed_type is not None, type_name + return self.print_annotation(parsed_type, self.known_modules, local_modules) + + def get_obj_module(self, obj: object) -> str | None: + """Return module name of the object.""" + return getattr(obj, "__module__", None) + + def is_defined_in_module(self, obj: object) -> bool: + """Check if object is considered defined in the current module.""" + module = self.get_obj_module(obj) + return module is None or module == self.module_name + + def generate_module(self) -> None: + all_items = self.get_members(self.module) + if self.resort_members: + all_items = sorted(all_items, key=lambda x: x[0]) + items = [] + for name, obj in all_items: + if inspect.ismodule(obj) and obj.__name__ in self.known_modules: + module_name = obj.__name__ + if module_name.startswith(self.module_name + "."): + # from {.rel_name} import {mod_name} as {name} + pkg_name, mod_name = module_name.rsplit(".", 1) + rel_module = pkg_name[len(self.module_name) :] or "." + self.import_tracker.add_import_from(rel_module, [(mod_name, name)]) + self.import_tracker.reexport(name) + else: + # import {module_name} as {name} + self.import_tracker.add_import(module_name, name) + self.import_tracker.reexport(name) + elif self.is_defined_in_module(obj) and not inspect.ismodule(obj): + # process this below + items.append((name, obj)) + else: + # from {obj_module} import {obj_name} + obj_module_name = self.get_obj_module(obj) + if obj_module_name: + self.import_tracker.add_import_from(obj_module_name, [(name, None)]) + if self.should_reexport(name, obj_module_name, name_is_alias=False): + self.import_tracker.reexport(name) + + self.set_defined_names(set([name for name, obj in all_items if not inspect.ismodule(obj)])) + + if self.resort_members: + functions: list[str] = [] + types: list[str] = [] + variables: list[str] = [] + else: + output: list[str] = [] + functions = types = variables = output + + for name, obj in items: + if self.is_function(obj): + self.generate_function_stub(name, obj, output=functions) + elif inspect.isclass(obj): + self.generate_class_stub(name, obj, output=types) + else: + self.generate_variable_stub(name, obj, output=variables) + + self._output = [] + + if self.resort_members: + for line in variables: + self._output.append(line + "\n") + for line in types: + if line.startswith("class") and self._output and self._output[-1]: + self._output.append("\n") + self._output.append(line + "\n") + if self._output and functions: + self._output.append("\n") + for line in functions: + self._output.append(line + "\n") + else: + for i, line in enumerate(output): + if ( + self._output + and line.startswith("class") + and ( + not self._output[-1].startswith("class") + or (len(output) > i + 1 and output[i + 1].startswith(" ")) + ) + ) or ( + self._output + and self._output[-1].startswith("def") + and not line.startswith("def") + ): + self._output.append("\n") + self._output.append(line + "\n") + self.check_undefined_names() + + def is_skipped_attribute(self, attr: str) -> bool: + return ( + attr + in ( + "__class__", + "__getattribute__", + "__str__", + "__repr__", + "__doc__", + "__dict__", + "__module__", + "__weakref__", + "__annotations__", ) - if inferred: - # add self/cls var, if not present - for sig in inferred: - if not sig.args or sig.args[0].name not in ("self", "cls"): - sig.args.insert(0, ArgSig(name=self_var)) - break - else: - # function: - for sig_gen in sig_generators: - inferred = sig_gen.get_function_sig(obj, module.__name__, name) - if inferred: - break - - if not inferred: - raise ValueError( - "No signature was found. This should never happen " - "if FallbackSignatureGenerator is provided" + or attr in self.IGNORED_DUNDERS + or is_pybind_skipped_attribute(attr) # For pickling + or keyword.iskeyword(attr) ) - is_overloaded = len(inferred) > 1 if inferred else False - if is_overloaded: - imports.append("from typing import overload") - if inferred: - for signature in inferred: - args: list[str] = [] - for arg in signature.args: - arg_def = arg.name - if arg_def == "None": - arg_def = "_none" # None is not a valid argument name - - if arg.type: - arg_def += ": " + strip_or_import(arg.type, module, known_modules, imports) - - if arg.default: - arg_def += " = ..." - - args.append(arg_def) - - if is_overloaded: - output.append("@overload") - # a sig generator indicates @classmethod by specifying the cls arg - if class_name and signature.args and signature.args[0].name == "cls": - output.append("@classmethod") - output_signature = "def {function}({args}) -> {ret}:".format( - function=name, - args=", ".join(args), - ret=strip_or_import(signature.ret_type, module, known_modules, imports), - ) - if include_docstrings and docstr: - docstr_quoted = mypy.util.quote_docstring(docstr.strip()) - docstr_indented = "\n ".join(docstr_quoted.split("\n")) - output.append(output_signature) - output.extend(f" {docstr_indented}".split("\n")) + def get_members(self, obj: object) -> list[tuple[str, Any]]: + obj_dict: Mapping[str, Any] = getattr(obj, "__dict__") # noqa: B009 + results = [] + for name in obj_dict: + if self.is_skipped_attribute(name): + continue + # Try to get the value via getattr + try: + value = getattr(obj, name) + except AttributeError: + continue else: - output_signature += " ..." - output.append(output_signature) - + results.append((name, value)) + return results -def strip_or_import( - typ: str, module: ModuleType, known_modules: list[str], imports: list[str] -) -> str: - """Strips unnecessary module names from typ. + def get_type_annotation(self, obj: object) -> str: + """ + Given an instance, return a string representation of its type that is valid + to use as a type annotation. + """ + if obj is None or obj is type(None): + return "None" + elif inspect.isclass(obj): + return "type[{}]".format(self.get_type_fullname(obj)) + elif isinstance(obj, FunctionType): + return self.add_name("typing.Callable") + elif isinstance(obj, ModuleType): + return self.add_name("types.ModuleType", require=False) + else: + return self.get_type_fullname(type(obj)) - If typ represents a type that is inside module or is a type coming from builtins, remove - module declaration from it. Return stripped name of the type. + def is_function(self, obj: object) -> bool: + if self.is_c_module: + return inspect.isbuiltin(obj) + else: + return inspect.isfunction(obj) + + def is_method(self, class_info: ClassInfo, name: str, obj: object) -> bool: + if self.is_c_module: + return inspect.ismethoddescriptor(obj) or type(obj) in ( + type(str.index), + type(str.__add__), + type(str.__new__), + ) + else: + # this is valid because it is only called on members of a class + return inspect.isfunction(obj) + + def is_classmethod(self, class_info: ClassInfo, name: str, obj: object) -> bool: + if self.is_c_module: + return inspect.isbuiltin(obj) or type(obj).__name__ in ( + "classmethod", + "classmethod_descriptor", + ) + else: + return inspect.ismethod(obj) - Arguments: - typ: name of the type - module: in which this type is used - known_modules: other modules being processed - imports: list of import statements (may be modified during the call) - """ - local_modules = ["builtins"] - if module: - local_modules.append(module.__name__) - - stripped_type = typ - if any(c in typ for c in "[,"): - for subtyp in re.split(r"[\[,\]]", typ): - stripped_subtyp = strip_or_import(subtyp.strip(), module, known_modules, imports) - if stripped_subtyp != subtyp: - stripped_type = re.sub( - r"(^|[\[, ]+)" + re.escape(subtyp) + r"($|[\], ]+)", - r"\1" + stripped_subtyp + r"\2", - stripped_type, - ) - elif "." in typ: - for module_name in local_modules + list(reversed(known_modules)): - if typ.startswith(module_name + "."): - if module_name in local_modules: - stripped_type = typ[len(module_name) + 1 :] - arg_module = module_name - break + def is_staticmethod(self, class_info: ClassInfo | None, name: str, obj: object) -> bool: + if self.is_c_module: + return False else: - arg_module = typ[: typ.rindex(".")] - if arg_module not in local_modules: - imports.append(f"import {arg_module}") - if stripped_type == "NoneType": - stripped_type = "None" - return stripped_type - - -def is_static_property(obj: object) -> bool: - return type(obj).__name__ == "pybind11_static_property" - - -def generate_c_property_stub( - name: str, - obj: object, - static_properties: list[str], - rw_properties: list[str], - ro_properties: list[str], - readonly: bool, - module: ModuleType | None = None, - known_modules: list[str] | None = None, - imports: list[str] | None = None, -) -> None: - """Generate property stub using introspection of 'obj'. + return class_info is not None and isinstance( + inspect.getattr_static(class_info.cls, name), staticmethod + ) - Try to infer type from docstring, append resulting lines to 'output'. - """ + @staticmethod + def is_abstract_method(obj: object) -> bool: + return getattr(obj, "__abstractmethod__", False) - def infer_prop_type(docstr: str | None) -> str | None: - """Infer property type from docstring or docstring signature.""" - if docstr is not None: - inferred = infer_ret_type_sig_from_anon_docstring(docstr) - if not inferred: - inferred = infer_ret_type_sig_from_docstring(docstr, name) - if not inferred: - inferred = infer_prop_type_from_docstring(docstr) - return inferred - else: - return None + @staticmethod + def is_property(class_info: ClassInfo, name: str, obj: object) -> bool: + return inspect.isdatadescriptor(obj) or hasattr(obj, "fget") - inferred = infer_prop_type(getattr(obj, "__doc__", None)) - if not inferred: - fget = getattr(obj, "fget", None) - inferred = infer_prop_type(getattr(fget, "__doc__", None)) - if not inferred: - inferred = "Any" - - if module is not None and imports is not None and known_modules is not None: - inferred = strip_or_import(inferred, module, known_modules, imports) - - if is_static_property(obj): - trailing_comment = " # read-only" if readonly else "" - static_properties.append(f"{name}: ClassVar[{inferred}] = ...{trailing_comment}") - else: # regular property - if readonly: - ro_properties.append("@property") - ro_properties.append(f"def {name}(self) -> {inferred}: ...") + @staticmethod + def is_property_readonly(prop: Any) -> bool: + return hasattr(prop, "fset") and prop.fset is None + + def is_static_property(self, obj: object) -> bool: + """For c-modules, whether the property behaves like an attribute""" + if self.is_c_module: + # StaticProperty is from boost-python + return type(obj).__name__ in ("pybind11_static_property", "StaticProperty") else: - rw_properties.append(f"{name}: {inferred}") + return False + + def process_inferred_sigs(self, inferred: list[FunctionSig]) -> None: + for i, sig in enumerate(inferred): + for arg in sig.args: + if arg.type is not None: + arg.type = self.strip_or_import(arg.type) + if sig.ret_type is not None: + inferred[i] = sig._replace(ret_type=self.strip_or_import(sig.ret_type)) + + def generate_function_stub( + self, name: str, obj: object, *, output: list[str], class_info: ClassInfo | None = None + ) -> None: + """Generate stub for a single function or method. + + The result (always a single line) will be appended to 'output'. + If necessary, any required names will be added to 'imports'. + The 'class_name' is used to find signature of __init__ or __new__ in + 'class_sigs'. + """ + docstring: Any = getattr(obj, "__doc__", None) + if not isinstance(docstring, str): + docstring = None + + ctx = FunctionContext( + self.module_name, + name, + docstring=docstring, + is_abstract=self.is_abstract_method(obj), + class_info=class_info, + ) + if self.is_private_name(name, ctx.fullname) or self.is_not_in_all(name): + return + self.record_name(ctx.name) + default_sig = self.get_default_function_sig(obj, ctx) + inferred = self.get_signatures(default_sig, self.sig_generators, ctx) + self.process_inferred_sigs(inferred) -def generate_c_type_stub( - module: ModuleType, - class_name: str, - obj: type, - output: list[str], - known_modules: list[str], - imports: list[str], - sig_generators: Iterable[SignatureGenerator], - include_docstrings: bool = False, -) -> None: - """Generate stub for a single class using runtime introspection. + decorators = [] + if len(inferred) > 1: + decorators.append("@{}".format(self.add_name("typing.overload"))) - The result lines will be appended to 'output'. If necessary, any - required names will be added to 'imports'. - """ - raw_lookup = getattr(obj, "__dict__") # noqa: B009 - items = sorted(get_members(obj), key=lambda x: method_name_sort_key(x[0])) - names = {x[0] for x in items} - methods: list[str] = [] - types: list[str] = [] - static_properties: list[str] = [] - rw_properties: list[str] = [] - ro_properties: list[str] = [] - attrs: list[tuple[str, Any]] = [] - for attr, value in items: - # use unevaluated descriptors when dealing with property inspection - raw_value = raw_lookup.get(attr, value) - if is_c_method(value) or is_c_classmethod(value): - if attr == "__new__": - # TODO: We should support __new__. - if "__init__" in names: - # Avoid duplicate functions if both are present. - # But is there any case where .__new__() has a - # better signature than __init__() ? - continue - attr = "__init__" - if is_c_classmethod(value): - self_var = "cls" + if ctx.is_abstract: + decorators.append("@{}".format(self.add_name("abc.abstractmethod"))) + + if class_info is not None: + if self.is_staticmethod(class_info, name, obj): + decorators.append("@staticmethod") else: - self_var = "self" - generate_c_function_stub( - module, - attr, - value, - output=methods, - known_modules=known_modules, - imports=imports, - self_var=self_var, - cls=obj, - class_name=class_name, - sig_generators=sig_generators, - include_docstrings=include_docstrings, - ) - elif is_c_property(raw_value): - generate_c_property_stub( - attr, - raw_value, - static_properties, - rw_properties, - ro_properties, - is_c_property_readonly(raw_value), - module=module, - known_modules=known_modules, - imports=imports, - ) - elif is_c_type(value): - generate_c_type_stub( - module, - attr, - value, - types, - imports=imports, - known_modules=known_modules, - sig_generators=sig_generators, - include_docstrings=include_docstrings, + for sig in inferred: + if not sig.args or sig.args[0].name not in ("self", "cls"): + sig.args.insert(0, ArgSig(name=class_info.self_var)) + # a sig generator indicates @classmethod by specifying the cls arg. + if inferred[0].args and inferred[0].args[0].name == "cls": + decorators.append("@classmethod") + + output.extend(self.format_func_def(inferred, decorators=decorators, docstring=docstring)) + self._fix_iter(ctx, inferred, output) + + def _fix_iter( + self, ctx: FunctionContext, inferred: list[FunctionSig], output: list[str] + ) -> None: + """Ensure that objects which implement old-style iteration via __getitem__ + are considered iterable. + """ + if ( + ctx.class_info + and ctx.class_info.cls is not None + and ctx.name == "__getitem__" + and "__iter__" not in ctx.class_info.cls.__dict__ + ): + item_type: str | None = None + for sig in inferred: + if sig.args and sig.args[-1].type == "int": + item_type = sig.ret_type + break + if item_type is None: + return + obj = CFunctionStub( + "__iter__", f"def __iter__(self) -> typing.Iterator[{item_type}]\n" ) + self.generate_function_stub("__iter__", obj, output=output, class_info=ctx.class_info) + + def generate_property_stub( + self, + name: str, + raw_obj: object, + obj: object, + static_properties: list[str], + rw_properties: list[str], + ro_properties: list[str], + class_info: ClassInfo | None = None, + ) -> None: + """Generate property stub using introspection of 'obj'. + + Try to infer type from docstring, append resulting lines to 'output'. + + raw_obj : object before evaluation of descriptor (if any) + obj : object after evaluation of descriptor + """ + + docstring = getattr(raw_obj, "__doc__", None) + fget = getattr(raw_obj, "fget", None) + if fget: + alt_docstr = getattr(fget, "__doc__", None) + if alt_docstr and docstring: + docstring += alt_docstr + elif alt_docstr: + docstring = alt_docstr + + ctx = FunctionContext( + self.module_name, name, docstring=docstring, is_abstract=False, class_info=class_info + ) + + if self.is_private_name(name, ctx.fullname) or self.is_not_in_all(name): + return + + self.record_name(ctx.name) + static = self.is_static_property(raw_obj) + readonly = self.is_property_readonly(raw_obj) + if static: + ret_type: str | None = self.strip_or_import(self.get_type_annotation(obj)) else: - attrs.append((attr, value)) + default_sig = self.get_default_function_sig(raw_obj, ctx) + ret_type = default_sig.ret_type + + inferred_type = self.get_property_type(ret_type, self.sig_generators, ctx) + if inferred_type is not None: + inferred_type = self.strip_or_import(inferred_type) - for attr, value in attrs: - static_properties.append( - "{}: ClassVar[{}] = ...".format( - attr, - strip_or_import(get_type_fullname(type(value)), module, known_modules, imports), + if static: + classvar = self.add_name("typing.ClassVar") + trailing_comment = " # read-only" if readonly else "" + if inferred_type is None: + inferred_type = self.add_name("_typeshed.Incomplete") + + static_properties.append( + f"{self._indent}{name}: {classvar}[{inferred_type}] = ...{trailing_comment}" ) - ) - all_bases = type.mro(obj) - if all_bases[-1] is object: - # TODO: Is this always object? - del all_bases[-1] - # remove pybind11_object. All classes generated by pybind11 have pybind11_object in their MRO, - # which only overrides a few functions in object type - if all_bases and all_bases[-1].__name__ == "pybind11_object": - del all_bases[-1] - # remove the class itself - all_bases = all_bases[1:] - # Remove base classes of other bases as redundant. - bases: list[type] = [] - for base in all_bases: - if not any(issubclass(b, base) for b in bases): - bases.append(base) - if bases: - bases_str = "(%s)" % ", ".join( - strip_or_import(get_type_fullname(base), module, known_modules, imports) - for base in bases - ) - else: - bases_str = "" - if types or static_properties or rw_properties or methods or ro_properties: - output.append(f"class {class_name}{bases_str}:") - for line in types: - if ( - output - and output[-1] - and not output[-1].startswith("class") - and line.startswith("class") + else: # regular property + if readonly: + ro_properties.append(f"{self._indent}@property") + sig = FunctionSig(name, [ArgSig("self")], inferred_type) + ro_properties.append(sig.format_sig(indent=self._indent)) + else: + if inferred_type is None: + inferred_type = self.add_name("_typeshed.Incomplete") + + rw_properties.append(f"{self._indent}{name}: {inferred_type}") + + def get_type_fullname(self, typ: type) -> str: + """Given a type, return a string representation""" + if typ is Any: + return "Any" + typename = getattr(typ, "__qualname__", typ.__name__) + module_name = self.get_obj_module(typ) + assert module_name is not None, typ + if module_name != "builtins": + typename = f"{module_name}.{typename}" + return typename + + def get_base_types(self, obj: type) -> list[str]: + all_bases = type.mro(obj) + if all_bases[-1] is object: + # TODO: Is this always object? + del all_bases[-1] + # remove pybind11_object. All classes generated by pybind11 have pybind11_object in their MRO, + # which only overrides a few functions in object type + if all_bases and all_bases[-1].__name__ == "pybind11_object": + del all_bases[-1] + # remove the class itself + all_bases = all_bases[1:] + # Remove base classes of other bases as redundant. + bases: list[type] = [] + for base in all_bases: + if not any(issubclass(b, base) for b in bases): + bases.append(base) + return [self.strip_or_import(self.get_type_fullname(base)) for base in bases] + + def generate_class_stub(self, class_name: str, cls: type, output: list[str]) -> None: + """Generate stub for a single class using runtime introspection. + + The result lines will be appended to 'output'. If necessary, any + required names will be added to 'imports'. + """ + raw_lookup = getattr(cls, "__dict__") # noqa: B009 + items = self.get_members(cls) + if self.resort_members: + items = sorted(items, key=lambda x: method_name_sort_key(x[0])) + names = set(x[0] for x in items) + methods: list[str] = [] + types: list[str] = [] + static_properties: list[str] = [] + rw_properties: list[str] = [] + ro_properties: list[str] = [] + attrs: list[tuple[str, Any]] = [] + + self.record_name(class_name) + self.indent() + + class_info = ClassInfo(class_name, "", getattr(cls, "__doc__", None), cls) + + for attr, value in items: + # use unevaluated descriptors when dealing with property inspection + raw_value = raw_lookup.get(attr, value) + if self.is_method(class_info, attr, value) or self.is_classmethod( + class_info, attr, value ): - output.append("") - output.append(" " + line) - for line in static_properties: - output.append(f" {line}") - for line in rw_properties: - output.append(f" {line}") - for line in methods: - output.append(f" {line}") - for line in ro_properties: - output.append(f" {line}") - else: - output.append(f"class {class_name}{bases_str}: ...") + if attr == "__new__": + # TODO: We should support __new__. + if "__init__" in names: + # Avoid duplicate functions if both are present. + # But is there any case where .__new__() has a + # better signature than __init__() ? + continue + attr = "__init__" + # FIXME: make this nicer + if self.is_classmethod(class_info, attr, value): + class_info.self_var = "cls" + else: + class_info.self_var = "self" + self.generate_function_stub(attr, value, output=methods, class_info=class_info) + elif self.is_property(class_info, attr, raw_value): + self.generate_property_stub( + attr, + raw_value, + value, + static_properties, + rw_properties, + ro_properties, + class_info, + ) + elif inspect.isclass(value) and self.is_defined_in_module(value): + self.generate_class_stub(attr, value, types) + else: + attrs.append((attr, value)) + for attr, value in attrs: + if attr == "__hash__" and value is None: + # special case for __hash__ + continue + prop_type_name = self.strip_or_import(self.get_type_annotation(value)) + classvar = self.add_name("typing.ClassVar") + static_properties.append(f"{self._indent}{attr}: {classvar}[{prop_type_name}] = ...") -def get_type_fullname(typ: type) -> str: - return f"{typ.__module__}.{getattr(typ, '__qualname__', typ.__name__)}" + self.dedent() + + bases = self.get_base_types(cls) + if bases: + bases_str = "(%s)" % ", ".join(bases) + else: + bases_str = "" + if types or static_properties or rw_properties or methods or ro_properties: + output.append(f"{self._indent}class {class_name}{bases_str}:") + for line in types: + if ( + output + and output[-1] + and not output[-1].strip().startswith("class") + and line.strip().startswith("class") + ): + output.append("") + output.append(line) + for line in static_properties: + output.append(line) + for line in rw_properties: + output.append(line) + for line in methods: + output.append(line) + for line in ro_properties: + output.append(line) + else: + output.append(f"{self._indent}class {class_name}{bases_str}: ...") + + def generate_variable_stub(self, name: str, obj: object, output: list[str]) -> None: + """Generate stub for a single variable using runtime introspection. + + The result lines will be appended to 'output'. If necessary, any + required names will be added to 'imports'. + """ + if self.is_private_name(name, f"{self.module_name}.{name}") or self.is_not_in_all(name): + return + self.record_name(name) + type_str = self.strip_or_import(self.get_type_annotation(obj)) + output.append(f"{name}: {type_str}") def method_name_sort_key(name: str) -> tuple[int, str]: @@ -648,22 +837,9 @@ def is_pybind_skipped_attribute(attr: str) -> bool: return attr.startswith("__pybind11_module_local_") -def is_skipped_attribute(attr: str) -> bool: - return attr in ( - "__class__", - "__getattribute__", - "__str__", - "__repr__", - "__doc__", - "__dict__", - "__module__", - "__weakref__", - ) or is_pybind_skipped_attribute( # For pickling - attr - ) - - -def infer_method_args(name: str, self_var: str | None = None) -> list[ArgSig]: +def infer_c_method_args( + name: str, self_var: str = "self", arg_names: list[str] | None = None +) -> list[ArgSig]: args: list[ArgSig] | None = None if name.startswith("__") and name.endswith("__"): name = name[2:-2] @@ -703,13 +879,9 @@ def infer_method_args(name: str, self_var: str | None = None) -> list[ArgSig]: args = [] elif name == "setstate": args = [ArgSig(name="state")] + elif name in ("eq", "ne", "lt", "le", "gt", "ge"): + args = [ArgSig(name="other", type="object")] elif name in ( - "eq", - "ne", - "lt", - "le", - "gt", - "ge", "add", "radd", "sub", @@ -761,22 +933,15 @@ def infer_method_args(name: str, self_var: str | None = None) -> list[ArgSig]: elif name == "reduce_ex": args = [ArgSig(name="protocol")] elif name == "exit": - args = [ArgSig(name="type"), ArgSig(name="value"), ArgSig(name="traceback")] + args = [ + ArgSig(name="type", type="type[BaseException] | None"), + ArgSig(name="value", type="BaseException | None"), + ArgSig(name="traceback", type="types.TracebackType | None"), + ] + if args is None: + args = infer_method_arg_types(name, self_var, arg_names) + else: + args = [ArgSig(name=self_var)] + args if args is None: args = [ArgSig(name="*args"), ArgSig(name="**kwargs")] - return [ArgSig(name=self_var or "self")] + args - - -def infer_method_ret_type(name: str) -> str: - if name.startswith("__") and name.endswith("__"): - name = name[2:-2] - if name in ("float", "bool", "bytes", "int"): - return name - # Note: __eq__ and co may return arbitrary types, but bool is good enough for stubgen. - elif name in ("eq", "ne", "lt", "le", "gt", "ge", "contains"): - return "bool" - elif name in ("len", "hash", "sizeof", "trunc", "floor", "ceil"): - return "int" - elif name in ("init", "setitem"): - return "None" - return "Any" + return args diff --git a/mypy/stubutil.py b/mypy/stubutil.py index e15766b66cb3..22e525c14e7c 100644 --- a/mypy/stubutil.py +++ b/mypy/stubutil.py @@ -5,19 +5,26 @@ import os.path import re import sys +from abc import abstractmethod +from collections import defaultdict from contextlib import contextmanager -from typing import Iterator +from typing import Final, Iterable, Iterator, Mapping from typing_extensions import overload +from mypy_extensions import mypyc_attr + +import mypy.options from mypy.modulefinder import ModuleNotFoundReason from mypy.moduleinspect import InspectError, ModuleInspect +from mypy.stubdoc import ArgSig, FunctionSig +from mypy.types import AnyType, NoneType, Type, TypeList, TypeStrVisitor, UnboundType, UnionType # Modules that may fail when imported, or that may have side effects (fully qualified). NOT_IMPORTABLE_MODULES = () class CantImport(Exception): - def __init__(self, module: str, message: str): + def __init__(self, module: str, message: str) -> None: self.module = module self.message = message @@ -70,8 +77,9 @@ def find_module_path_and_all_py3( ) -> tuple[str | None, list[str] | None] | None: """Find module and determine __all__ for a Python 3 module. - Return None if the module is a C module. Return (module_path, __all__) if - it is a Python module. Raise CantImport if import failed. + Return None if the module is a C or pyc-only module. + Return (module_path, __all__) if it is a Python module. + Raise CantImport if import failed. """ if module in NOT_IMPORTABLE_MODULES: raise CantImport(module, "") @@ -182,3 +190,591 @@ def common_dir_prefix(paths: list[str]) -> str: cur = path break return cur or "." + + +class AnnotationPrinter(TypeStrVisitor): + """Visitor used to print existing annotations in a file. + + The main difference from TypeStrVisitor is a better treatment of + unbound types. + + Notes: + * This visitor doesn't add imports necessary for annotations, this is done separately + by ImportTracker. + * It can print all kinds of types, but the generated strings may not be valid (notably + callable types) since it prints the same string that reveal_type() does. + * For Instance types it prints the fully qualified names. + """ + + # TODO: Generate valid string representation for callable types. + # TODO: Use short names for Instances. + def __init__( + self, + stubgen: BaseStubGenerator, + known_modules: list[str] | None = None, + local_modules: list[str] | None = None, + ) -> None: + super().__init__(options=mypy.options.Options()) + self.stubgen = stubgen + self.known_modules = known_modules + self.local_modules = local_modules or ["builtins"] + + def visit_any(self, t: AnyType) -> str: + s = super().visit_any(t) + self.stubgen.import_tracker.require_name(s) + return s + + def visit_unbound_type(self, t: UnboundType) -> str: + s = t.name + if self.known_modules is not None and "." in s: + # see if this object is from any of the modules that we're currently processing. + # reverse sort so that subpackages come before parents: e.g. "foo.bar" before "foo". + for module_name in self.local_modules + sorted(self.known_modules, reverse=True): + if s.startswith(module_name + "."): + if module_name in self.local_modules: + s = s[len(module_name) + 1 :] + arg_module = module_name + break + else: + arg_module = s[: s.rindex(".")] + if arg_module not in self.local_modules: + self.stubgen.import_tracker.add_import(arg_module, require=True) + elif s == "NoneType": + # when called without analysis all types are unbound, so this won't hit + # visit_none_type(). + s = "None" + else: + self.stubgen.import_tracker.require_name(s) + if t.args: + s += f"[{self.args_str(t.args)}]" + return s + + def visit_none_type(self, t: NoneType) -> str: + return "None" + + def visit_type_list(self, t: TypeList) -> str: + return f"[{self.list_str(t.items)}]" + + def visit_union_type(self, t: UnionType) -> str: + return " | ".join([item.accept(self) for item in t.items]) + + def args_str(self, args: Iterable[Type]) -> str: + """Convert an array of arguments to strings and join the results with commas. + + The main difference from list_str is the preservation of quotes for string + arguments + """ + types = ["builtins.bytes", "builtins.str"] + res = [] + for arg in args: + arg_str = arg.accept(self) + if isinstance(arg, UnboundType) and arg.original_str_fallback in types: + res.append(f"'{arg_str}'") + else: + res.append(arg_str) + return ", ".join(res) + + +class ClassInfo: + def __init__( + self, name: str, self_var: str, docstring: str | None = None, cls: type | None = None + ) -> None: + self.name = name + self.self_var = self_var + self.docstring = docstring + self.cls = cls + + +class FunctionContext: + def __init__( + self, + module_name: str, + name: str, + docstring: str | None = None, + is_abstract: bool = False, + class_info: ClassInfo | None = None, + ) -> None: + self.module_name = module_name + self.name = name + self.docstring = docstring + self.is_abstract = is_abstract + self.class_info = class_info + self._fullname: str | None = None + + @property + def fullname(self) -> str: + if self._fullname is None: + if self.class_info: + self._fullname = f"{self.module_name}.{self.class_info.name}.{self.name}" + else: + self._fullname = f"{self.module_name}.{self.name}" + return self._fullname + + +def infer_method_ret_type(name: str) -> str | None: + """Infer return types for known special methods""" + if name.startswith("__") and name.endswith("__"): + name = name[2:-2] + if name in ("float", "bool", "bytes", "int", "complex", "str"): + return name + # Note: __eq__ and co may return arbitrary types, but bool is good enough for stubgen. + elif name in ("eq", "ne", "lt", "le", "gt", "ge", "contains"): + return "bool" + elif name in ("len", "length_hint", "index", "hash", "sizeof", "trunc", "floor", "ceil"): + return "int" + elif name in ("format", "repr"): + return "str" + elif name in ("init", "setitem", "del", "delitem"): + return "None" + return None + + +def infer_method_arg_types( + name: str, self_var: str = "self", arg_names: list[str] | None = None +) -> list[ArgSig] | None: + """Infer argument types for known special methods""" + args: list[ArgSig] | None = None + if name.startswith("__") and name.endswith("__"): + if arg_names and len(arg_names) >= 1 and arg_names[0] == "self": + arg_names = arg_names[1:] + + name = name[2:-2] + if name == "exit": + if arg_names is None: + arg_names = ["type", "value", "traceback"] + if len(arg_names) == 3: + arg_types = [ + "type[BaseException] | None", + "BaseException | None", + "types.TracebackType | None", + ] + args = [ + ArgSig(name=arg_name, type=arg_type) + for arg_name, arg_type in zip(arg_names, arg_types) + ] + if args is not None: + return [ArgSig(name=self_var)] + args + return None + + +@mypyc_attr(allow_interpreted_subclasses=True) +class SignatureGenerator: + """Abstract base class for extracting a list of FunctionSigs for each function.""" + + def remove_self_type( + self, inferred: list[FunctionSig] | None, self_var: str + ) -> list[FunctionSig] | None: + """Remove type annotation from self/cls argument""" + if inferred: + for signature in inferred: + if signature.args: + if signature.args[0].name == self_var: + signature.args[0].type = None + return inferred + + @abstractmethod + def get_function_sig( + self, default_sig: FunctionSig, ctx: FunctionContext + ) -> list[FunctionSig] | None: + """Return a list of signatures for the given function. + + If no signature can be found, return None. If all of the registered SignatureGenerators + for the stub generator return None, then the default_sig will be used. + """ + pass + + @abstractmethod + def get_property_type(self, default_type: str | None, ctx: FunctionContext) -> str | None: + """Return the type of the given property""" + pass + + +class ImportTracker: + """Record necessary imports during stub generation.""" + + def __init__(self) -> None: + # module_for['foo'] has the module name where 'foo' was imported from, or None if + # 'foo' is a module imported directly; + # direct_imports['foo'] is the module path used when the name 'foo' was added to the + # namespace. + # reverse_alias['foo'] is the name that 'foo' had originally when imported with an + # alias; examples + # 'from pkg import mod' ==> module_for['mod'] == 'pkg' + # 'from pkg import mod as m' ==> module_for['m'] == 'pkg' + # ==> reverse_alias['m'] == 'mod' + # 'import pkg.mod as m' ==> module_for['m'] == None + # ==> reverse_alias['m'] == 'pkg.mod' + # 'import pkg.mod' ==> module_for['pkg'] == None + # ==> module_for['pkg.mod'] == None + # ==> direct_imports['pkg'] == 'pkg.mod' + # ==> direct_imports['pkg.mod'] == 'pkg.mod' + self.module_for: dict[str, str | None] = {} + self.direct_imports: dict[str, str] = {} + self.reverse_alias: dict[str, str] = {} + + # required_names is the set of names that are actually used in a type annotation + self.required_names: set[str] = set() + + # Names that should be reexported if they come from another module + self.reexports: set[str] = set() + + def add_import_from( + self, module: str, names: list[tuple[str, str | None]], require: bool = False + ) -> None: + for name, alias in names: + if alias: + # 'from {module} import {name} as {alias}' + self.module_for[alias] = module + self.reverse_alias[alias] = name + else: + # 'from {module} import {name}' + self.module_for[name] = module + self.reverse_alias.pop(name, None) + if require: + self.require_name(alias or name) + self.direct_imports.pop(alias or name, None) + + def add_import(self, module: str, alias: str | None = None, require: bool = False) -> None: + if alias: + # 'import {module} as {alias}' + assert "." not in alias # invalid syntax + self.module_for[alias] = None + self.reverse_alias[alias] = module + if require: + self.required_names.add(alias) + else: + # 'import {module}' + name = module + if require: + self.required_names.add(name) + # add module and its parent packages + while name: + self.module_for[name] = None + self.direct_imports[name] = module + self.reverse_alias.pop(name, None) + name = name.rpartition(".")[0] + + def require_name(self, name: str) -> None: + while name not in self.direct_imports and "." in name: + name = name.rsplit(".", 1)[0] + self.required_names.add(name) + + def reexport(self, name: str) -> None: + """Mark a given non qualified name as needed in __all__. + + This means that in case it comes from a module, it should be + imported with an alias even if the alias is the same as the name. + """ + self.require_name(name) + self.reexports.add(name) + + def import_lines(self) -> list[str]: + """The list of required import lines (as strings with python code). + + In order for a module be included in this output, an indentifier must be both + 'required' via require_name() and 'imported' via add_import_from() + or add_import() + """ + result = [] + + # To summarize multiple names imported from a same module, we collect those + # in the `module_map` dictionary, mapping a module path to the list of names that should + # be imported from it. the names can also be alias in the form 'original as alias' + module_map: Mapping[str, list[str]] = defaultdict(list) + + for name in sorted( + self.required_names, + key=lambda n: (self.reverse_alias[n], n) if n in self.reverse_alias else (n, ""), + ): + # If we haven't seen this name in an import statement, ignore it + if name not in self.module_for: + continue + + m = self.module_for[name] + if m is not None: + # This name was found in a from ... import ... + # Collect the name in the module_map + if name in self.reverse_alias: + name = f"{self.reverse_alias[name]} as {name}" + elif name in self.reexports: + name = f"{name} as {name}" + module_map[m].append(name) + else: + # This name was found in an import ... + # We can already generate the import line + if name in self.reverse_alias: + source = self.reverse_alias[name] + result.append(f"import {source} as {name}\n") + elif name in self.reexports: + assert "." not in name # Because reexports only has nonqualified names + result.append(f"import {name} as {name}\n") + else: + result.append(f"import {name}\n") + + # Now generate all the from ... import ... lines collected in module_map + for module, names in sorted(module_map.items()): + result.append(f"from {module} import {', '.join(sorted(names))}\n") + return result + + +@mypyc_attr(allow_interpreted_subclasses=True) +class BaseStubGenerator: + # These names should be omitted from generated stubs. + IGNORED_DUNDERS: Final = { + "__all__", + "__author__", + "__about__", + "__copyright__", + "__email__", + "__license__", + "__summary__", + "__title__", + "__uri__", + "__str__", + "__repr__", + "__getstate__", + "__setstate__", + "__slots__", + "__builtins__", + "__cached__", + "__file__", + "__name__", + "__package__", + "__path__", + "__spec__", + "__loader__", + } + TYPING_MODULE_NAMES: Final = ("typing", "typing_extensions") + # Special-cased names that are implicitly exported from the stub (from m import y as y). + EXTRA_EXPORTED: Final = { + "pyasn1_modules.rfc2437.univ", + "pyasn1_modules.rfc2459.char", + "pyasn1_modules.rfc2459.univ", + } + + def __init__( + self, + _all_: list[str] | None = None, + include_private: bool = False, + export_less: bool = False, + include_docstrings: bool = False, + ): + # Best known value of __all__. + self._all_ = _all_ + self._include_private = include_private + self._include_docstrings = include_docstrings + # Disable implicit exports of package-internal imports? + self.export_less = export_less + self._import_lines: list[str] = [] + self._output: list[str] = [] + # Current indent level (indent is hardcoded to 4 spaces). + self._indent = "" + self._toplevel_names: list[str] = [] + self.import_tracker = ImportTracker() + # Top-level members + self.defined_names: set[str] = set() + self.sig_generators = self.get_sig_generators() + # populated by visit_mypy_file + self.module_name: str = "" + + def get_sig_generators(self) -> list[SignatureGenerator]: + return [] + + def refers_to_fullname(self, name: str, fullname: str | tuple[str, ...]) -> bool: + """Return True if the variable name identifies the same object as the given fullname(s).""" + if isinstance(fullname, tuple): + return any(self.refers_to_fullname(name, fname) for fname in fullname) + module, short = fullname.rsplit(".", 1) + return self.import_tracker.module_for.get(name) == module and ( + name == short or self.import_tracker.reverse_alias.get(name) == short + ) + + def add_name(self, fullname: str, require: bool = True) -> str: + """Add a name to be imported and return the name reference. + + The import will be internal to the stub (i.e don't reexport). + """ + module, name = fullname.rsplit(".", 1) + alias = "_" + name if name in self.defined_names else None + self.import_tracker.add_import_from(module, [(name, alias)], require=require) + return alias or name + + def add_import_line(self, line: str) -> None: + """Add a line of text to the import section, unless it's already there.""" + if line not in self._import_lines: + self._import_lines.append(line) + + def get_imports(self) -> str: + """Return the import statements for the stub.""" + imports = "" + if self._import_lines: + imports += "".join(self._import_lines) + imports += "".join(self.import_tracker.import_lines()) + return imports + + def output(self) -> str: + """Return the text for the stub.""" + imports = self.get_imports() + if imports and self._output: + imports += "\n" + return imports + "".join(self._output) + + def add(self, string: str) -> None: + """Add text to generated stub.""" + self._output.append(string) + + def is_top_level(self) -> bool: + """Are we processing the top level of a file?""" + return self._indent == "" + + def indent(self) -> None: + """Add one level of indentation.""" + self._indent += " " + + def dedent(self) -> None: + """Remove one level of indentation.""" + self._indent = self._indent[:-4] + + def record_name(self, name: str) -> None: + """Mark a name as defined. + + This only does anything if at the top level of a module. + """ + if self.is_top_level(): + self._toplevel_names.append(name) + + def is_recorded_name(self, name: str) -> bool: + """Has this name been recorded previously?""" + return self.is_top_level() and name in self._toplevel_names + + def set_defined_names(self, defined_names: set[str]) -> None: + self.defined_names = defined_names + # Names in __all__ are required + for name in self._all_ or (): + if name not in self.IGNORED_DUNDERS: + self.import_tracker.reexport(name) + + # These are "soft" imports for objects which might appear in annotations but not have + # a corresponding import statement. + known_imports = { + "_typeshed": ["Incomplete"], + "typing": ["Any", "TypeVar", "NamedTuple"], + "collections.abc": ["Generator"], + "typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"], + } + for pkg, imports in known_imports.items(): + for t in imports: + # require=False means that the import won't be added unless require_name() is called + # for the object during generation. + self.add_name(f"{pkg}.{t}", require=False) + + def check_undefined_names(self) -> None: + print(self._all_) + print(self._toplevel_names) + undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names] + if undefined_names: + if self._output: + self.add("\n") + self.add("# Names in __all__ with no definition:\n") + for name in sorted(undefined_names): + self.add(f"# {name}\n") + + def get_signatures( + self, + default_signature: FunctionSig, + sig_generators: list[SignatureGenerator], + func_ctx: FunctionContext, + ) -> list[FunctionSig]: + for sig_gen in sig_generators: + inferred = sig_gen.get_function_sig(default_signature, func_ctx) + if inferred: + return inferred + + return [default_signature] + + def get_property_type( + self, + default_type: str | None, + sig_generators: list[SignatureGenerator], + func_ctx: FunctionContext, + ) -> str | None: + for sig_gen in sig_generators: + inferred = sig_gen.get_property_type(default_type, func_ctx) + if inferred: + return inferred + + return default_type + + def format_func_def( + self, + sigs: list[FunctionSig], + is_coroutine: bool = False, + decorators: list[str] | None = None, + docstring: str | None = None, + ) -> list[str]: + lines: list[str] = [] + if decorators is None: + decorators = [] + + for signature in sigs: + # dump decorators, just before "def ..." + for deco in decorators: + lines.append(f"{self._indent}{deco}") + + lines.append( + signature.format_sig( + indent=self._indent, + is_async=is_coroutine, + docstring=docstring if self._include_docstrings else None, + ) + ) + return lines + + def print_annotation( + self, + t: Type, + known_modules: list[str] | None = None, + local_modules: list[str] | None = None, + ) -> str: + printer = AnnotationPrinter(self, known_modules, local_modules) + return t.accept(printer) + + def is_not_in_all(self, name: str) -> bool: + if self.is_private_name(name): + return False + if self._all_: + return self.is_top_level() and name not in self._all_ + return False + + def is_private_name(self, name: str, fullname: str | None = None) -> bool: + if self._include_private: + return False + if fullname in self.EXTRA_EXPORTED: + return False + if name == "_": + return False + return name.startswith("_") and (not name.endswith("__") or name in self.IGNORED_DUNDERS) + + def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool: + if ( + not name_is_alias + and self.module_name + and (self.module_name + "." + name) in self.EXTRA_EXPORTED + ): + # Special case certain names that should be exported, against our general rules. + return True + is_private = self.is_private_name(name, full_module + "." + name) + top_level = full_module.split(".")[0] + self_top_level = self.module_name.split(".", 1)[0] + if ( + not name_is_alias + and not self.export_less + and (not self._all_ or name in self.IGNORED_DUNDERS) + and self.module_name + and not is_private + and top_level in (self_top_level, "_" + self_top_level) + ): + # Export imports from the same package, since we can't reliably tell whether they + # are part of the public API. + return True + return False diff --git a/mypy/test/teststubgen.py b/mypy/test/teststubgen.py index 7e30515ac892..ace0b4d95573 100644 --- a/mypy/test/teststubgen.py +++ b/mypy/test/teststubgen.py @@ -28,21 +28,19 @@ Options, collect_build_targets, generate_stubs, - get_sig_generators, is_blacklisted_path, is_non_library_module, mypy_options, parse_options, ) -from mypy.stubgenc import ( - generate_c_function_stub, - generate_c_property_stub, - generate_c_type_stub, - infer_method_args, +from mypy.stubgenc import InspectionStubGenerator, infer_c_method_args +from mypy.stubutil import ( + ClassInfo, + common_dir_prefix, infer_method_ret_type, - is_c_property_readonly, + remove_misplaced_type_comments, + walk_packages, ) -from mypy.stubutil import common_dir_prefix, remove_misplaced_type_comments, walk_packages from mypy.test.data import DataDrivenTestCase, DataSuite from mypy.test.helpers import assert_equal, assert_string_arrays_equal, local_sys_path_set @@ -62,7 +60,8 @@ def test_files_found(self) -> None: os.mkdir(os.path.join("subdir", "pack")) self.make_file("subdir", "pack", "__init__.py") opts = parse_options(["subdir"]) - py_mods, c_mods = collect_build_targets(opts, mypy_options(opts)) + py_mods, pyi_mods, c_mods = collect_build_targets(opts, mypy_options(opts)) + assert_equal(pyi_mods, []) assert_equal(c_mods, []) files = {mod.path for mod in py_mods} assert_equal( @@ -87,7 +86,8 @@ def test_packages_found(self) -> None: self.make_file("pack", "a.py") self.make_file("pack", "b.py") opts = parse_options(["-p", "pack"]) - py_mods, c_mods = collect_build_targets(opts, mypy_options(opts)) + py_mods, pyi_mods, c_mods = collect_build_targets(opts, mypy_options(opts)) + assert_equal(pyi_mods, []) assert_equal(c_mods, []) files = {os.path.relpath(mod.path or "FAIL") for mod in py_mods} assert_equal( @@ -111,7 +111,7 @@ def test_module_not_found(self) -> None: os.chdir(tmp) self.make_file(tmp, "mymodule.py", content="import a") opts = parse_options(["-m", "mymodule"]) - py_mods, c_mods = collect_build_targets(opts, mypy_options(opts)) + collect_build_targets(opts, mypy_options(opts)) assert captured_output.getvalue() == "" finally: sys.stdout = sys.__stdout__ @@ -702,10 +702,14 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None: out_dir = "out" try: try: - if not testcase.name.endswith("_import"): - options.no_import = True - if not testcase.name.endswith("_semanal"): - options.parse_only = True + if testcase.name.endswith("_inspect"): + options.inspect = True + else: + if not testcase.name.endswith("_import"): + options.no_import = True + if not testcase.name.endswith("_semanal"): + options.parse_only = True + generate_stubs(options) a: list[str] = [] for module in modules: @@ -781,35 +785,28 @@ class StubgencSuite(unittest.TestCase): """ def test_infer_hash_sig(self) -> None: - assert_equal(infer_method_args("__hash__"), [self_arg]) + assert_equal(infer_c_method_args("__hash__"), [self_arg]) assert_equal(infer_method_ret_type("__hash__"), "int") def test_infer_getitem_sig(self) -> None: - assert_equal(infer_method_args("__getitem__"), [self_arg, ArgSig(name="index")]) + assert_equal(infer_c_method_args("__getitem__"), [self_arg, ArgSig(name="index")]) def test_infer_setitem_sig(self) -> None: assert_equal( - infer_method_args("__setitem__"), + infer_c_method_args("__setitem__"), [self_arg, ArgSig(name="index"), ArgSig(name="object")], ) assert_equal(infer_method_ret_type("__setitem__"), "None") + def test_infer_eq_op_sig(self) -> None: + for op in ("eq", "ne", "lt", "le", "gt", "ge"): + assert_equal( + infer_c_method_args(f"__{op}__"), [self_arg, ArgSig(name="other", type="object")] + ) + def test_infer_binary_op_sig(self) -> None: - for op in ( - "eq", - "ne", - "lt", - "le", - "gt", - "ge", - "add", - "radd", - "sub", - "rsub", - "mul", - "rmul", - ): - assert_equal(infer_method_args(f"__{op}__"), [self_arg, ArgSig(name="other")]) + for op in ("add", "radd", "sub", "rsub", "mul", "rmul"): + assert_equal(infer_c_method_args(f"__{op}__"), [self_arg, ArgSig(name="other")]) def test_infer_equality_op_sig(self) -> None: for op in ("eq", "ne", "lt", "le", "gt", "ge", "contains"): @@ -817,46 +814,31 @@ def test_infer_equality_op_sig(self) -> None: def test_infer_unary_op_sig(self) -> None: for op in ("neg", "pos"): - assert_equal(infer_method_args(f"__{op}__"), [self_arg]) + assert_equal(infer_c_method_args(f"__{op}__"), [self_arg]) def test_infer_cast_sig(self) -> None: for op in ("float", "bool", "bytes", "int"): assert_equal(infer_method_ret_type(f"__{op}__"), op) - def test_generate_c_type_stub_no_crash_for_object(self) -> None: + def test_generate_class_stub_no_crash_for_object(self) -> None: output: list[str] = [] mod = ModuleType("module", "") # any module is fine - imports: list[str] = [] - generate_c_type_stub( - mod, - "alias", - object, - output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) - assert_equal(imports, []) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + + gen.generate_class_stub("alias", object, output) + assert_equal(gen.get_imports().splitlines(), []) assert_equal(output[0], "class alias:") - def test_generate_c_type_stub_variable_type_annotation(self) -> None: + def test_generate_class_stub_variable_type_annotation(self) -> None: # This class mimics the stubgen unit test 'testClassVariable' class TestClassVariableCls: x = 1 output: list[str] = [] - imports: list[str] = [] mod = ModuleType("module", "") # any module is fine - generate_c_type_stub( - mod, - "C", - TestClassVariableCls, - output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) - assert_equal(imports, []) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_class_stub("C", TestClassVariableCls, output) + assert_equal(gen.get_imports().splitlines(), ["from typing import ClassVar"]) assert_equal(output, ["class C:", " x: ClassVar[int] = ..."]) def test_generate_c_type_inheritance(self) -> None: @@ -864,35 +846,19 @@ class TestClass(KeyError): pass output: list[str] = [] - imports: list[str] = [] mod = ModuleType("module, ") - generate_c_type_stub( - mod, - "C", - TestClass, - output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_class_stub("C", TestClass, output) assert_equal(output, ["class C(KeyError): ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_inheritance_same_module(self) -> None: output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestBaseClass.__module__, "") - generate_c_type_stub( - mod, - "C", - TestClass, - output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_class_stub("C", TestClass, output) assert_equal(output, ["class C(TestBaseClass): ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_inheritance_other_module(self) -> None: import argparse @@ -901,38 +867,22 @@ class TestClass(argparse.Action): pass output: list[str] = [] - imports: list[str] = [] mod = ModuleType("module", "") - generate_c_type_stub( - mod, - "C", - TestClass, - output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_class_stub("C", TestClass, output) assert_equal(output, ["class C(argparse.Action): ..."]) - assert_equal(imports, ["import argparse"]) + assert_equal(gen.get_imports().splitlines(), ["import argparse"]) def test_generate_c_type_inheritance_builtin_type(self) -> None: class TestClass(type): pass output: list[str] = [] - imports: list[str] = [] mod = ModuleType("module", "") - generate_c_type_stub( - mod, - "C", - TestClass, - output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_class_stub("C", TestClass, output) assert_equal(output, ["class C(type): ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_with_docstring(self) -> None: class TestClass: @@ -942,22 +892,16 @@ def test(self, arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) assert_equal(output, ["def test(self, arg0: int) -> Any: ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_with_docstring_no_self_arg(self) -> None: class TestClass: @@ -967,22 +911,16 @@ def test(self, arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) assert_equal(output, ["def test(self, arg0: int) -> Any: ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_classmethod(self) -> None: class TestClass: @@ -991,22 +929,16 @@ def test(cls, arg0: str) -> None: pass output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="cls", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="cls", cls=TestClass, name="TestClass"), ) - assert_equal(output, ["@classmethod", "def test(cls, *args, **kwargs) -> Any: ..."]) - assert_equal(imports, []) + assert_equal(output, ["@classmethod", "def test(cls, *args, **kwargs): ..."]) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_classmethod_with_overloads(self) -> None: class TestClass: @@ -1019,19 +951,13 @@ def test(self, arg0: str) -> None: pass output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="cls", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="cls", cls=TestClass, name="TestClass"), ) assert_equal( output, @@ -1044,7 +970,7 @@ def test(self, arg0: str) -> None: "def test(cls, arg0: int) -> Any: ...", ], ) - assert_equal(imports, ["from typing import overload"]) + assert_equal(gen.get_imports().splitlines(), ["from typing import overload"]) def test_generate_c_type_with_docstring_empty_default(self) -> None: class TestClass: @@ -1054,22 +980,16 @@ def test(self, arg0: str = "") -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) assert_equal(output, ["def test(self, arg0: str = ...) -> Any: ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_function_other_module_arg(self) -> None: """Test that if argument references type from other module, module will be imported.""" @@ -1082,19 +1002,11 @@ def test(arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(self.__module__, "") - generate_c_function_stub( - mod, - "test", - test, - output=output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub("test", test, output=output) assert_equal(output, ["def test(arg0: argparse.Action) -> Any: ..."]) - assert_equal(imports, ["import argparse"]) + assert_equal(gen.get_imports().splitlines(), ["import argparse"]) def test_generate_c_function_same_module(self) -> None: """Test that if annotation references type from same module but using full path, no module @@ -1109,19 +1021,11 @@ def test(arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType("argparse", "") - generate_c_function_stub( - mod, - "test", - test, - output=output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub("test", test, output=output) assert_equal(output, ["def test(arg0: Action) -> Action: ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_function_other_module(self) -> None: """Test that if annotation references type from other module, module will be imported.""" @@ -1132,19 +1036,11 @@ def test(arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(self.__module__, "") - generate_c_function_stub( - mod, - "test", - test, - output=output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub("test", test, output=output) assert_equal(output, ["def test(arg0: argparse.Action) -> argparse.Action: ..."]) - assert_equal(set(imports), {"import argparse"}) + assert_equal(gen.get_imports().splitlines(), ["import argparse"]) def test_generate_c_function_same_module_nested(self) -> None: """Test that if annotation references type from same module but using full path, no module @@ -1159,19 +1055,11 @@ def test(arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType("argparse", "") - generate_c_function_stub( - mod, - "test", - test, - output=output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub("test", test, output=output) assert_equal(output, ["def test(arg0: list[Action]) -> list[Action]: ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_function_same_module_compound(self) -> None: """Test that if annotation references type from same module but using full path, no module @@ -1186,19 +1074,11 @@ def test(arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType("argparse", "") - generate_c_function_stub( - mod, - "test", - test, - output=output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) - assert_equal(output, ["def test(arg0: Union[Action,None]) -> Tuple[Action,None]: ..."]) - assert_equal(imports, []) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub("test", test, output=output) + assert_equal(output, ["def test(arg0: Union[Action, None]) -> Tuple[Action, None]: ..."]) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_function_other_module_nested(self) -> None: """Test that if annotation references type from other module, module will be imported, @@ -1210,19 +1090,13 @@ def test(arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(self.__module__, "") - generate_c_function_stub( - mod, - "test", - test, - output=output, - imports=imports, - known_modules=["foo", "foo.spangle", "bar"], - sig_generators=get_sig_generators(parse_options([])), + gen = InspectionStubGenerator( + mod.__name__, known_modules=["foo", "foo.spangle", "bar"], module=mod ) + gen.generate_function_stub("test", test, output=output) assert_equal(output, ["def test(arg0: foo.bar.Action) -> other.Thing: ..."]) - assert_equal(set(imports), {"import foo", "import other"}) + assert_equal(gen.get_imports().splitlines(), ["import foo", "import other"]) def test_generate_c_function_no_crash_for_non_str_docstring(self) -> None: def test(arg0: str) -> None: @@ -1231,19 +1105,11 @@ def test(arg0: str) -> None: test.__doc__ = property(lambda self: "test(arg0: str) -> None") # type: ignore[assignment] output: list[str] = [] - imports: list[str] = [] mod = ModuleType(self.__module__, "") - generate_c_function_stub( - mod, - "test", - test, - output=output, - imports=imports, - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), - ) - assert_equal(output, ["def test(*args, **kwargs) -> Any: ..."]) - assert_equal(imports, []) + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub("test", test, output=output) + assert_equal(output, ["def test(*args, **kwargs): ..."]) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_property_with_pybind11(self) -> None: """Signatures included by PyBind11 inside property.fget are read.""" @@ -1258,13 +1124,15 @@ def get_attribute(self) -> None: readwrite_properties: list[str] = [] readonly_properties: list[str] = [] - generate_c_property_stub( + mod = ModuleType("module", "") # any module is fine + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_property_stub( "attribute", + TestClass.__dict__["attribute"], TestClass.attribute, [], readwrite_properties, readonly_properties, - is_c_property_readonly(TestClass.attribute), ) assert_equal(readwrite_properties, []) assert_equal(readonly_properties, ["@property", "def attribute(self) -> str: ..."]) @@ -1284,15 +1152,17 @@ def attribute(self, value: int) -> None: readwrite_properties: list[str] = [] readonly_properties: list[str] = [] - generate_c_property_stub( + mod = ModuleType("module", "") # any module is fine + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_property_stub( "attribute", - type(TestClass.attribute), + TestClass.__dict__["attribute"], + TestClass.attribute, [], readwrite_properties, readonly_properties, - is_c_property_readonly(TestClass.attribute), ) - assert_equal(readwrite_properties, ["attribute: Any"]) + assert_equal(readwrite_properties, ["attribute: Incomplete"]) assert_equal(readonly_properties, []) def test_generate_c_type_with_single_arg_generic(self) -> None: @@ -1303,22 +1173,16 @@ def test(self, arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) assert_equal(output, ["def test(self, arg0: List[int]) -> Any: ..."]) - assert_equal(imports, []) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_with_double_arg_generic(self) -> None: class TestClass: @@ -1328,22 +1192,16 @@ def test(self, arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) - assert_equal(output, ["def test(self, arg0: Dict[str,int]) -> Any: ..."]) - assert_equal(imports, []) + assert_equal(output, ["def test(self, arg0: Dict[str, int]) -> Any: ..."]) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_with_nested_generic(self) -> None: class TestClass: @@ -1353,22 +1211,16 @@ def test(self, arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) - assert_equal(output, ["def test(self, arg0: Dict[str,List[int]]) -> Any: ..."]) - assert_equal(imports, []) + assert_equal(output, ["def test(self, arg0: Dict[str, List[int]]) -> Any: ..."]) + assert_equal(gen.get_imports().splitlines(), []) def test_generate_c_type_with_generic_using_other_module_first(self) -> None: class TestClass: @@ -1378,22 +1230,16 @@ def test(self, arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) - assert_equal(output, ["def test(self, arg0: Dict[argparse.Action,int]) -> Any: ..."]) - assert_equal(imports, ["import argparse"]) + assert_equal(output, ["def test(self, arg0: Dict[argparse.Action, int]) -> Any: ..."]) + assert_equal(gen.get_imports().splitlines(), ["import argparse"]) def test_generate_c_type_with_generic_using_other_module_last(self) -> None: class TestClass: @@ -1403,22 +1249,16 @@ def test(self, arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "test", TestClass.test, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) - assert_equal(output, ["def test(self, arg0: Dict[str,argparse.Action]) -> Any: ..."]) - assert_equal(imports, ["import argparse"]) + assert_equal(output, ["def test(self, arg0: Dict[str, argparse.Action]) -> Any: ..."]) + assert_equal(gen.get_imports().splitlines(), ["import argparse"]) def test_generate_c_type_with_overload_pybind11(self) -> None: class TestClass: @@ -1433,19 +1273,13 @@ def __init__(self, arg0: str) -> None: """ output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "__init__", TestClass.__init__, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo(self_var="self", cls=TestClass, name="TestClass"), ) assert_equal( output, @@ -1458,7 +1292,7 @@ def __init__(self, arg0: str) -> None: "def __init__(self, *args, **kwargs) -> Any: ...", ], ) - assert_equal(set(imports), {"from typing import overload"}) + assert_equal(gen.get_imports().splitlines(), ["from typing import overload"]) def test_generate_c_type_with_overload_shiboken(self) -> None: class TestClass: @@ -1471,19 +1305,18 @@ def __init__(self, arg0: str) -> None: pass output: list[str] = [] - imports: list[str] = [] mod = ModuleType(TestClass.__module__, "") - generate_c_function_stub( - mod, + gen = InspectionStubGenerator(mod.__name__, known_modules=[mod.__name__], module=mod) + gen.generate_function_stub( "__init__", TestClass.__init__, output=output, - imports=imports, - self_var="self", - cls=TestClass, - class_name="TestClass", - known_modules=[mod.__name__], - sig_generators=get_sig_generators(parse_options([])), + class_info=ClassInfo( + self_var="self", + cls=TestClass, + name="TestClass", + docstring=getattr(TestClass, "__doc__", None), + ), ) assert_equal( output, @@ -1494,7 +1327,7 @@ def __init__(self, arg0: str) -> None: "def __init__(self, arg0: str, arg1: str) -> None: ...", ], ) - assert_equal(set(imports), {"from typing import overload"}) + assert_equal(gen.get_imports().splitlines(), ["from typing import overload"]) class ArgSigSuite(unittest.TestCase): diff --git a/mypy/traverser.py b/mypy/traverser.py index 2fcc376cfb7c..d11dd395f978 100644 --- a/mypy/traverser.py +++ b/mypy/traverser.py @@ -2,7 +2,7 @@ from __future__ import annotations -from mypy_extensions import mypyc_attr +from mypy_extensions import mypyc_attr, trait from mypy.nodes import ( REVEAL_TYPE, @@ -94,6 +94,7 @@ from mypy.visitor import NodeVisitor +@trait @mypyc_attr(allow_interpreted_subclasses=True) class TraverserVisitor(NodeVisitor[None]): """A parse tree visitor that traverses the parse tree during visiting. diff --git a/setup.py b/setup.py index dcbdc96b3ccf..e3ebe9dd62ec 100644 --- a/setup.py +++ b/setup.py @@ -112,7 +112,6 @@ def run(self): "stubtest.py", "stubgenc.py", "stubdoc.py", - "stubutil.py", ) ) + ( # Don't want to grab this accidentally diff --git a/test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/__init__.pyi b/test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/__init__.pyi index e69de29bb2d1..0cb252f00259 100644 --- a/test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/__init__.pyi +++ b/test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/__init__.pyi @@ -0,0 +1 @@ +from . import basics as basics diff --git a/test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/basics.pyi b/test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/basics.pyi index ab5a4f4e78d2..6527f5733eaf 100644 --- a/test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/basics.pyi +++ b/test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/basics.pyi @@ -1,7 +1,7 @@ -from typing import ClassVar +from typing import ClassVar, overload -from typing import overload PI: float +__version__: str class Point: class AngleUnit: @@ -11,12 +11,10 @@ class Point: radian: ClassVar[Point.AngleUnit] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property @@ -30,12 +28,10 @@ class Point: pixel: ClassVar[Point.LengthUnit] = ... def __init__(self, value: int) -> None: ... def __eq__(self, other: object) -> bool: ... - def __getstate__(self) -> int: ... def __hash__(self) -> int: ... def __index__(self) -> int: ... def __int__(self) -> int: ... def __ne__(self, other: object) -> bool: ... - def __setstate__(self, state: int) -> None: ... @property def name(self) -> str: ... @property diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 23dbf36a551b..d83d74306230 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -11,6 +11,11 @@ def f() -> None: ... [case testTwoFunctions] def f(a, b): + """ + this is a docstring + + more. + """ x = 1 def g(arg): pass @@ -37,11 +42,21 @@ def f(x=True, y=False): ... [out] def f(x: bool = ..., y: bool = ...) -> None: ... +[case testDefaultArgBool_inspect] +def f(x=True, y=False): ... +[out] +def f(x: bool = ..., y: bool = ...): ... + [case testDefaultArgStr] def f(x='foo'): ... [out] def f(x: str = ...) -> None: ... +[case testDefaultArgStr_inspect] +def f(x='foo'): ... +[out] +def f(x: str = ...): ... + [case testDefaultArgBytes] def f(x=b'foo'): ... [out] @@ -300,6 +315,7 @@ __all__ = [] __author__ = '' __version__ = '' [out] +__version__: str [case testBaseClass] class A: ... @@ -361,6 +377,24 @@ class A: def f(self, x) -> None: ... def h(self) -> None: ... +-- a read/write property is treated the same as an attribute +[case testProperty_inspect] +class A: + @property + def f(self): + return 1 + @f.setter + def f(self, x): ... + + def h(self): + self.f = 1 +[out] +from _typeshed import Incomplete + +class A: + f: Incomplete + def h(self): ... + [case testFunctoolsCachedProperty] import functools @@ -435,6 +469,15 @@ class A: @classmethod def f(cls) -> None: ... +[case testClassMethod_inspect] +class A: + @classmethod + def f(cls): ... +[out] +class A: + @classmethod + def f(cls): ... + [case testIfMainCheck] def a(): ... if __name__ == '__main__': @@ -472,6 +515,23 @@ class B: ... class C: def f(self) -> None: ... +[case testNoSpacesBetweenEmptyClasses_inspect] +class X: + def g(self): ... +class A: ... +class B: ... +class C: + def f(self): ... +[out] +class X: + def g(self): ... + +class A: ... +class B: ... + +class C: + def f(self): ... + [case testExceptionBaseClasses] class A(Exception): ... class B(ValueError): ... @@ -490,6 +550,17 @@ class A: class A: def __eq__(self): ... +[case testOmitSomeSpecialMethods_inspect] +class A: + def __str__(self): ... + def __repr__(self): ... + def __eq__(self): ... + def __getstate__(self): ... + def __setstate__(self, state): ... +[out] +class A: + def __eq__(self) -> bool: ... + -- Tests that will perform runtime imports of modules. -- Don't use `_import` suffix if there are unquoted forward references. @@ -507,6 +578,13 @@ def g(): ... [out] def f() -> None: ... +[case testOmitDefsNotInAll_inspect] +__all__ = [] + ['f'] +def f(): ... +def g(): ... +[out] +def f(): ... + [case testVarDefsNotInAll_import] __all__ = [] + ['f', 'g'] def f(): ... @@ -517,6 +595,16 @@ def g(): ... def f() -> None: ... def g() -> None: ... +[case testVarDefsNotInAll_inspect] +__all__ = [] + ['f', 'g'] +def f(): ... +x = 1 +y = 1 +def g(): ... +[out] +def f(): ... +def g(): ... + [case testIncludeClassNotInAll_import] __all__ = [] + ['f'] def f(): ... @@ -526,6 +614,15 @@ def f() -> None: ... class A: ... +[case testIncludeClassNotInAll_inspect] +__all__ = [] + ['f'] +def f(): ... +class A: ... +[out] +def f(): ... + +class A: ... + [case testAllAndClass_import] __all__ = ['A'] class A: @@ -636,6 +733,23 @@ class C: # Names in __all__ with no definition: # g +[case testCommentForUndefinedName_inspect] +__all__ = ['f', 'x', 'C', 'g'] +def f(): ... +x = 1 +class C: + def g(self): ... +[out] +def f(): ... + +x: int + +class C: + def g(self): ... + +# Names in __all__ with no definition: +# g + [case testIgnoreSlots] class A: __slots__ = () @@ -649,6 +763,13 @@ class A: [out] class A: ... +[case testSkipPrivateProperty_inspect] +class A: + @property + def _foo(self): ... +[out] +class A: ... + [case testIncludePrivateProperty] # flags: --include-private class A: @@ -659,6 +780,16 @@ class A: @property def _foo(self) -> None: ... +[case testIncludePrivateProperty_inspect] +# flags: --include-private +class A: + @property + def _foo(self): ... +[out] +class A: + @property + def _foo(self): ... + [case testSkipPrivateStaticAndClassMethod] class A: @staticmethod @@ -668,6 +799,15 @@ class A: [out] class A: ... +[case testSkipPrivateStaticAndClassMethod_inspect] +class A: + @staticmethod + def _foo(): ... + @classmethod + def _bar(cls): ... +[out] +class A: ... + [case testIncludePrivateStaticAndClassMethod] # flags: --include-private class A: @@ -682,6 +822,20 @@ class A: @classmethod def _bar(cls) -> None: ... +[case testIncludePrivateStaticAndClassMethod_inspect] +# flags: --include-private +class A: + @staticmethod + def _foo(): ... + @classmethod + def _bar(cls): ... +[out] +class A: + @staticmethod + def _foo(): ... + @classmethod + def _bar(cls): ... + [case testNamedtuple] import collections, typing, x X = collections.namedtuple('X', ['a', 'b']) @@ -1801,6 +1955,19 @@ class Outer: class Inner: ... A = Outer.Inner +-- needs improvement +[case testNestedClass_inspect] +class Outer: + class Inner: + pass + +A = Outer.Inner +[out] +class Outer: + class Inner: ... + +class A: ... + [case testFunctionAlias_semanal] from asyncio import coroutine @@ -2034,6 +2201,25 @@ class A: def f(x) -> None: ... def g(x, y: str): ... +class A: + def f(self, x) -> None: ... + +-- Same as above +[case testFunctionPartiallyAnnotated_inspect] +def f(x) -> None: + pass + +def g(x, y: str): + pass + +class A: + def f(self, x) -> None: + pass + +[out] +def f(x) -> None: ... +def g(x, y: str): ... + class A: def f(self, x) -> None: ... @@ -2054,6 +2240,24 @@ def f(x: Any): ... def g(x, y: Any) -> str: ... def h(x: Any) -> str: ... +-- Same as above +[case testExplicitAnyArg_inspect] +from typing import Any + +def f(x: Any): + pass +def g(x, y: Any) -> str: + pass +def h(x: Any) -> str: + pass + +[out] +from typing import Any + +def f(x: Any): ... +def g(x, y: Any) -> str: ... +def h(x: Any) -> str: ... + [case testExplicitReturnedAny] from typing import Any @@ -2385,6 +2589,28 @@ def g() -> None: ... +[case testTestFiles_inspect] +# modules: p p.x p.tests p.tests.test_foo + +[file p/__init__.py] +def f(): pass + +[file p/x.py] +def g(): pass + +[file p/tests/__init__.py] + +[file p/tests/test_foo.py] +def test_thing(): pass + +[out] +# p/__init__.pyi +def f(): ... +# p/x.pyi +def g(): ... + + + [case testVerboseFlag] # Just test that --verbose does not break anything in a basic test case. # flags: --verbose @@ -2686,6 +2912,8 @@ __uri__ = '' __version__ = '' [out] +from m import __version__ as __version__ + class A: ... [case testHideDunderModuleAttributesWithAll_import] @@ -2715,6 +2943,7 @@ __uri__ = '' __version__ = '' [out] +from m import __version__ as __version__ [case testAttrsClass_semanal] import attrs @@ -2949,7 +3178,6 @@ class A: @overload def f(self, x: Tuple[int, int]) -> int: ... - @overload def f(x: int, y: int) -> int: ... @overload @@ -2993,7 +3221,6 @@ class A: @overload def f(self, x: Tuple[int, int]) -> int: ... - @overload def f(x: int, y: int) -> int: ... @overload @@ -3068,7 +3295,6 @@ class A: @classmethod def g(cls, x: typing.Tuple[int, int]) -> int: ... - @typing.overload def f(x: int, y: int) -> int: ... @typing.overload @@ -3147,7 +3373,6 @@ class A: @classmethod def g(cls, x: t.Tuple[int, int]) -> int: ... - @t.overload def f(x: int, y: int) -> int: ... @t.overload @@ -3345,6 +3570,67 @@ class Some: def __float__(self) -> float: ... def __index__(self) -> int: ... +-- Same as above +[case testKnownMagicMethodsReturnTypes_inspect] +class Some: + def __len__(self): ... + def __length_hint__(self): ... + def __init__(self): ... + def __del__(self): ... + def __bool__(self): ... + def __bytes__(self): ... + def __format__(self, spec): ... + def __contains__(self, obj): ... + def __complex__(self): ... + def __int__(self): ... + def __float__(self): ... + def __index__(self): ... +[out] +class Some: + def __len__(self) -> int: ... + def __length_hint__(self) -> int: ... + def __init__(self) -> None: ... + def __del__(self) -> None: ... + def __bool__(self) -> bool: ... + def __bytes__(self) -> bytes: ... + def __format__(self, spec) -> str: ... + def __contains__(self, obj) -> bool: ... + def __complex__(self) -> complex: ... + def __int__(self) -> int: ... + def __float__(self) -> float: ... + def __index__(self) -> int: ... + + +[case testKnownMagicMethodsArgTypes] +class MismatchNames: + def __exit__(self, tp, val, tb): ... + +class MatchNames: + def __exit__(self, type, value, traceback): ... + +[out] +class MismatchNames: + def __exit__(self, tp: type[BaseException] | None, val: BaseException | None, tb: types.TracebackType | None) -> None: ... + +class MatchNames: + def __exit__(self, type: type[BaseException] | None, value: BaseException | None, traceback: types.TracebackType | None) -> None: ... + +-- Same as above (but can generate import statements) +[case testKnownMagicMethodsArgTypes_inspect] +class MismatchNames: + def __exit__(self, tp, val, tb): ... + +class MatchNames: + def __exit__(self, type, value, traceback): ... + +[out] +import types + +class MismatchNames: + def __exit__(self, tp: type[BaseException] | None, val: BaseException | None, tb: types.TracebackType | None): ... + +class MatchNames: + def __exit__(self, type: type[BaseException] | None, value: BaseException | None, traceback: types.TracebackType | None): ... [case testTypeVarPEP604Bound] from typing import TypeVar @@ -3397,7 +3683,7 @@ from typing import TypedDict X = TypedDict('X', a=int, b=str) Y = TypedDict('X', a=int, b=str, total=False) [out] -from typing import TypedDict +from typing_extensions import TypedDict class X(TypedDict): a: int