Skip to content

stubgen: Use NamedTuple class syntax #10625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def __init__(self,
# Disable implicit exports of package-internal imports?
self.export_less = export_less
# Add imports that could be implicitly generated
self.import_tracker.add_import_from("collections", [("namedtuple", None)])
self.import_tracker.add_import_from("typing", [("NamedTuple", None)])
# Names in __all__ are required
for name in _all_ or ():
if name not in IGNORED_DUNDERS:
Expand Down Expand Up @@ -900,18 +900,24 @@ def is_namedtuple(self, expr: Expression) -> bool:
def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
if self._state != EMPTY:
self.add('\n')
name = repr(getattr(rvalue.args[0], 'value', ERROR_MARKER))
if isinstance(rvalue.args[1], StrExpr):
items = repr(rvalue.args[1].value)
items = rvalue.args[1].value.split(" ")
elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)):
list_items = cast(List[StrExpr], rvalue.args[1].items)
items = '[%s]' % ', '.join(repr(item.value) for item in list_items)
items = [item.value for item in list_items]
else:
self.add('%s%s: Any' % (self._indent, lvalue.name))
self.import_tracker.require_name('Any')
return
self.import_tracker.require_name('namedtuple')
self.add('%s%s = namedtuple(%s, %s)\n' % (self._indent, lvalue.name, name, items))
self.import_tracker.require_name('NamedTuple')
self.add('{}class {}(NamedTuple):'.format(self._indent, lvalue.name))
if len(items) == 0:
self.add(' ...\n')
else:
self.import_tracker.require_name('Any')
self.add('\n')
for item in items:
self.add('{} {}: Any\n'.format(self._indent, item))
self._state = CLASS

def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool:
Expand Down
52 changes: 38 additions & 14 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -591,30 +591,44 @@ class A:
import collections, x
X = collections.namedtuple('X', ['a', 'b'])
[out]
from collections import namedtuple
from typing import Any, NamedTuple

class X(NamedTuple):
a: Any
b: Any

X = namedtuple('X', ['a', 'b'])
[case testEmptyNamedtuple]
import collections
X = collections.namedtuple('X', [])
[out]
from typing import NamedTuple

class X(NamedTuple): ...

[case testNamedtupleAltSyntax]
from collections import namedtuple, xx
X = namedtuple('X', 'a b')
xx
[out]
from collections import namedtuple
from typing import Any, NamedTuple

X = namedtuple('X', 'a b')
class X(NamedTuple):
a: Any
b: Any

[case testNamedtupleWithUnderscore]
from collections import namedtuple as _namedtuple
def f(): ...
X = _namedtuple('X', 'a b')
def g(): ...
[out]
from collections import namedtuple
from typing import Any, NamedTuple

def f() -> None: ...

X = namedtuple('X', 'a b')
class X(NamedTuple):
a: Any
b: Any

def g() -> None: ...

Expand All @@ -623,9 +637,11 @@ import collections, x
_X = collections.namedtuple('_X', ['a', 'b'])
class Y(_X): ...
[out]
from collections import namedtuple
from typing import Any, NamedTuple

_X = namedtuple('_X', ['a', 'b'])
class _X(NamedTuple):
a: Any
b: Any

class Y(_X): ...

Expand All @@ -636,13 +652,19 @@ Y = namedtuple('Y', ('a',))
Z = namedtuple('Z', ('a', 'b', 'c', 'd', 'e'))
xx
[out]
from collections import namedtuple
from typing import Any, NamedTuple

X = namedtuple('X', [])
class X(NamedTuple): ...

Y = namedtuple('Y', ['a'])
class Y(NamedTuple):
a: Any

Z = namedtuple('Z', ['a', 'b', 'c', 'd', 'e'])
class Z(NamedTuple):
a: Any
b: Any
c: Any
d: Any
e: Any

[case testDynamicNamedTuple]
from collections import namedtuple
Expand Down Expand Up @@ -2187,10 +2209,12 @@ from collections import namedtuple
class C:
N = namedtuple('N', ['x', 'y'])
[out]
from collections import namedtuple
from typing import Any, NamedTuple

class C:
N = namedtuple('N', ['x', 'y'])
class N(NamedTuple):
x: Any
y: Any

[case testImports_directImportsWithAlias]
import p.a as a
Expand Down