diff --git a/mypy/stubgen.py b/mypy/stubgen.py index 2e86eb3b9cee7..071a238b57141 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -104,7 +104,6 @@ TupleExpr, TypeInfo, UnaryExpr, - is_StrExpr_list, ) from mypy.options import Options as MypyOptions from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures @@ -129,6 +128,7 @@ from mypy.types import ( OVERLOAD_NAMES, TPDICT_NAMES, + TYPED_NAMEDTUPLE_NAMES, AnyType, CallableType, Instance, @@ -400,10 +400,12 @@ def visit_str_expr(self, node: StrExpr) -> str: def visit_index_expr(self, node: IndexExpr) -> str: base = node.base.accept(self) index = node.index.accept(self) + if len(index) > 2 and index.startswith("(") and index.endswith(")"): + index = index[1:-1] return f"{base}[{index}]" def visit_tuple_expr(self, node: TupleExpr) -> str: - return ", ".join(n.accept(self) for n in node.items) + return f"({', '.join(n.accept(self) for n in node.items)})" def visit_list_expr(self, node: ListExpr) -> str: return f"[{', '.join(n.accept(self) for n in node.items)}]" @@ -1010,6 +1012,37 @@ def get_base_types(self, cdef: ClassDef) -> list[str]: elif isinstance(base, IndexExpr): p = AliasPrinter(self) base_types.append(base.accept(p)) + elif isinstance(base, CallExpr): + # namedtuple(typename, fields), NamedTuple(typename, fields) calls can + # be used as a base class. The first argument is a string literal that + # is usually the same as the class name. + # + # Note: + # A call-based named tuple as a base class cannot be safely converted to + # a class-based NamedTuple definition because class attributes defined + # in the body of the class inheriting from the named tuple call are not + # namedtuple fields at runtime. + if self.is_namedtuple(base): + nt_fields = self._get_namedtuple_fields(base) + assert isinstance(base.args[0], StrExpr) + typename = base.args[0].value + if nt_fields is not None: + # A valid namedtuple() call, use NamedTuple() instead with + # Incomplete as field types + fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields) + base_types.append(f"NamedTuple({typename!r}, [{fields_str}])") + self.add_typing_import("NamedTuple") + else: + # Invalid namedtuple() call, cannot determine fields + base_types.append("Incomplete") + elif self.is_typed_namedtuple(base): + p = AliasPrinter(self) + base_types.append(base.accept(p)) + else: + # At this point, we don't know what the base class is, so we + # just use Incomplete as the base class. + base_types.append("Incomplete") + self.import_tracker.require_name("Incomplete") return base_types def visit_block(self, o: Block) -> None: @@ -1022,8 +1055,11 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: foundl = [] for lvalue in o.lvalues: - if isinstance(lvalue, NameExpr) and self.is_namedtuple(o.rvalue): - assert isinstance(o.rvalue, CallExpr) + if ( + isinstance(lvalue, NameExpr) + and isinstance(o.rvalue, CallExpr) + and (self.is_namedtuple(o.rvalue) or self.is_typed_namedtuple(o.rvalue)) + ): self.process_namedtuple(lvalue, o.rvalue) continue if ( @@ -1069,37 +1105,79 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None: if all(foundl): self._state = VAR - def is_namedtuple(self, expr: Expression) -> bool: - if not isinstance(expr, CallExpr): - return False + def is_namedtuple(self, expr: CallExpr) -> bool: callee = expr.callee - return (isinstance(callee, NameExpr) and callee.name.endswith("namedtuple")) or ( - isinstance(callee, MemberExpr) and callee.name == "namedtuple" + return ( + isinstance(callee, NameExpr) + and (self.refers_to_fullname(callee.name, "collections.namedtuple")) + ) or ( + isinstance(callee, MemberExpr) + and isinstance(callee.expr, NameExpr) + and f"{callee.expr.name}.{callee.name}" == "collections.namedtuple" ) + def is_typed_namedtuple(self, expr: CallExpr) -> bool: + callee = expr.callee + return ( + isinstance(callee, NameExpr) + and self.refers_to_fullname(callee.name, TYPED_NAMEDTUPLE_NAMES) + ) or ( + isinstance(callee, MemberExpr) + and isinstance(callee.expr, NameExpr) + and f"{callee.expr.name}.{callee.name}" in TYPED_NAMEDTUPLE_NAMES + ) + + def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None: + if self.is_namedtuple(call): + fields_arg = call.args[1] + if isinstance(fields_arg, StrExpr): + field_names = fields_arg.value.replace(",", " ").split() + elif isinstance(fields_arg, (ListExpr, TupleExpr)): + field_names = [] + for field in fields_arg.items: + if not isinstance(field, StrExpr): + return None + field_names.append(field.value) + else: + return None # Invalid namedtuple fields type + if field_names: + self.import_tracker.require_name("Incomplete") + return [(field_name, "Incomplete") for field_name in field_names] + elif self.is_typed_namedtuple(call): + fields_arg = call.args[1] + if not isinstance(fields_arg, (ListExpr, TupleExpr)): + return None + fields: list[tuple[str, str]] = [] + b = AliasPrinter(self) + for field in fields_arg.items: + if not (isinstance(field, TupleExpr) and len(field.items) == 2): + return None + field_name, field_type = field.items + if not isinstance(field_name, StrExpr): + return None + fields.append((field_name.value, field_type.accept(b))) + return fields + else: + return None # Not a named tuple call + def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: if self._state != EMPTY: self.add("\n") - if isinstance(rvalue.args[1], StrExpr): - items = rvalue.args[1].value.replace(",", " ").split() - elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)): - list_items = rvalue.args[1].items - assert is_StrExpr_list(list_items) - items = [item.value for item in list_items] - else: + fields = self._get_namedtuple_fields(rvalue) + if fields is None: self.add(f"{self._indent}{lvalue.name}: Incomplete") self.import_tracker.require_name("Incomplete") return self.import_tracker.require_name("NamedTuple") self.add(f"{self._indent}class {lvalue.name}(NamedTuple):") - if not items: + if len(fields) == 0: self.add(" ...\n") + self._state = EMPTY_CLASS else: - self.import_tracker.require_name("Incomplete") self.add("\n") - for item in items: - self.add(f"{self._indent} {item}: Incomplete\n") - self._state = CLASS + for f_name, f_type in fields: + self.add(f"{self._indent} {f_name}: {f_type}\n") + self._state = CLASS def is_typeddict(self, expr: CallExpr) -> bool: callee = expr.callee diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index 5e934a52f8e83..eba838f32cd5a 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -641,8 +641,9 @@ class A: def _bar(cls) -> None: ... [case testNamedtuple] -import collections, x +import collections, typing, x X = collections.namedtuple('X', ['a', 'b']) +Y = typing.NamedTuple('Y', [('a', int), ('b', str)]) [out] from _typeshed import Incomplete from typing import NamedTuple @@ -651,14 +652,21 @@ class X(NamedTuple): a: Incomplete b: Incomplete +class Y(NamedTuple): + a: int + b: str + [case testEmptyNamedtuple] -import collections +import collections, typing X = collections.namedtuple('X', []) +Y = typing.NamedTuple('Y', []) [out] from typing import NamedTuple class X(NamedTuple): ... +class Y(NamedTuple): ... + [case testNamedtupleAltSyntax] from collections import namedtuple, xx X = namedtuple('X', 'a b') @@ -697,8 +705,10 @@ class X(NamedTuple): [case testNamedtupleWithUnderscore] from collections import namedtuple as _namedtuple +from typing import NamedTuple as _NamedTuple def f(): ... X = _namedtuple('X', 'a b') +Y = _NamedTuple('Y', [('a', int), ('b', str)]) def g(): ... [out] from _typeshed import Incomplete @@ -710,6 +720,10 @@ class X(NamedTuple): a: Incomplete b: Incomplete +class Y(NamedTuple): + a: int + b: str + def g() -> None: ... [case testNamedtupleBaseClass] @@ -728,10 +742,14 @@ class Y(_X): ... [case testNamedtupleAltSyntaxFieldsTuples] from collections import namedtuple, xx +from typing import NamedTuple X = namedtuple('X', ()) Y = namedtuple('Y', ('a',)) Z = namedtuple('Z', ('a', 'b', 'c', 'd', 'e')) xx +R = NamedTuple('R', ()) +S = NamedTuple('S', (('a', int),)) +T = NamedTuple('T', (('a', int), ('b', str))) [out] from _typeshed import Incomplete from typing import NamedTuple @@ -748,13 +766,62 @@ class Z(NamedTuple): d: Incomplete e: Incomplete +class R(NamedTuple): ... + +class S(NamedTuple): + a: int + +class T(NamedTuple): + a: int + b: str + [case testDynamicNamedTuple] from collections import namedtuple +from typing import NamedTuple N = namedtuple('N', ['x', 'y'] + ['z']) +M = NamedTuple('M', [('x', int), ('y', str)] + [('z', float)]) +class X(namedtuple('X', ['a', 'b'] + ['c'])): ... [out] from _typeshed import Incomplete N: Incomplete +M: Incomplete +class X(Incomplete): ... + +[case testNamedTupleInClassBases] +import collections, typing +from collections import namedtuple +from typing import NamedTuple +class X(namedtuple('X', ['a', 'b'])): ... +class Y(NamedTuple('Y', [('a', int), ('b', str)])): ... +class R(collections.namedtuple('R', ['a', 'b'])): ... +class S(typing.NamedTuple('S', [('a', int), ('b', str)])): ... +[out] +import typing +from _typeshed import Incomplete +from typing import NamedTuple + +class X(NamedTuple('X', [('a', Incomplete), ('b', Incomplete)])): ... +class Y(NamedTuple('Y', [('a', int), ('b', str)])): ... +class R(NamedTuple('R', [('a', Incomplete), ('b', Incomplete)])): ... +class S(typing.NamedTuple('S', [('a', int), ('b', str)])): ... + +[case testNotNamedTuple] +from not_collections import namedtuple +from not_typing import NamedTuple +from collections import notnamedtuple +from typing import NotNamedTuple +X = namedtuple('X', ['a', 'b']) +Y = notnamedtuple('Y', ['a', 'b']) +Z = NamedTuple('Z', [('a', int), ('b', str)]) +W = NotNamedTuple('W', [('a', int), ('b', str)]) +[out] +from _typeshed import Incomplete + +X: Incomplete +Y: Incomplete +Z: Incomplete +W: Incomplete [case testArbitraryBaseClass] import x