Support recursive named tuples (#13371)
This is a continuation of #13297 

The main change here is that although named tuples are still stored in symbol tables as `TypeInfo`s, when the type analyzer sees them, it creates a `TypeAliasType` targeting what it would have returned before (a `TupleType` with a fallback to an instance of that `TypeInfo`). Although this is a significant change, IMO it is the simplest yet still clean way to support recursive named tuples.
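For illustration, here is a minimal sketch of the kind of self-referential named tuple this change is meant to allow (assuming recursive types are enabled, e.g. via the `enable_recursive_aliases` option touched in this diff); the `TreeNode` name and the example itself are hypothetical, not taken from the PR:

```python
from typing import List, NamedTuple


class TreeNode(NamedTuple):
    value: int
    # The field annotation refers back to TreeNode itself; resolving this
    # self-reference is what the TypeAliasType indirection makes possible.
    children: List["TreeNode"]


root = TreeNode(value=1, children=[TreeNode(value=2, children=[])])
print(root.children[0].value)  # prints 2
```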

It is also very simple to extend this to TypedDicts, but I want to do that in a separate PR to minimize the scope of changes.
It would be great if someone could take a look at this PR soon.

Most of the code changes make named tuple semantic analysis idempotent: previously, named tuples were analyzed "for real" only once, when all their types were ready. That is no longer possible if we want them to be recursive, so I pass in `existing_info` everywhere and update it instead of creating a new `TypeInfo` every time.
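To make that concrete, here is a rough sketch of the idempotency pattern (hypothetical names such as `FakeTypeInfo` and `build_namedtuple_info`, not mypy's actual API): reuse the node created on an earlier pass and update it in place, so references taken during previous iterations stay valid.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class FakeTypeInfo:
    name: str
    fields: List[str] = field(default_factory=list)


def build_namedtuple_info(
    name: str,
    fields: List[str],
    existing_info: Optional[FakeTypeInfo],
    symbol_table: Dict[str, FakeTypeInfo],
) -> FakeTypeInfo:
    # Reuse the node created on a previous iteration, if any, and update it
    # in place, so earlier references (e.g. placeholders) remain valid.
    info = existing_info or FakeTypeInfo(name)
    info.fields = list(fields)
    symbol_table[name] = info
    return info


table: Dict[str, FakeTypeInfo] = {}
first = build_namedtuple_info("Point", ["x"], None, table)
second = build_namedtuple_info("Point", ["x", "y"], table.get("Point"), table)
assert first is second  # the same object is updated, not replaced
```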
ilevkivskyi committed Aug 10, 2022
1 parent 27c5a9e commit 03638dd
Showing 16 changed files with 611 additions and 67 deletions.
5 changes: 3 additions & 2 deletions mypy/checkmember.py
@@ -563,6 +563,7 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
The return type of the appropriate ``__get__`` overload for the descriptor.
"""
instance_type = get_proper_type(mx.original_type)
orig_descriptor_type = descriptor_type
descriptor_type = get_proper_type(descriptor_type)

if isinstance(descriptor_type, UnionType):
@@ -571,10 +572,10 @@ def analyze_descriptor_access(descriptor_type: Type, mx: MemberContext) -> Type:
[analyze_descriptor_access(typ, mx) for typ in descriptor_type.items]
)
elif not isinstance(descriptor_type, Instance):
return descriptor_type
return orig_descriptor_type

if not descriptor_type.type.has_readable_member("__get__"):
return descriptor_type
return orig_descriptor_type

dunder_get = descriptor_type.type.get_method("__get__")
if dunder_get is None:
7 changes: 7 additions & 0 deletions mypy/fixup.py
@@ -75,6 +75,7 @@ def visit_type_info(self, info: TypeInfo) -> None:
p.accept(self.type_fixer)
if info.tuple_type:
info.tuple_type.accept(self.type_fixer)
info.update_tuple_type(info.tuple_type)
if info.typeddict_type:
info.typeddict_type.accept(self.type_fixer)
if info.declared_metaclass:
@@ -337,6 +338,12 @@ def lookup_fully_qualified_alias(
node = stnode.node if stnode else None
if isinstance(node, TypeAlias):
return node
elif isinstance(node, TypeInfo):
if node.tuple_alias:
return node.tuple_alias
alias = TypeAlias.from_tuple_type(node)
node.tuple_alias = alias
return alias
else:
# Looks like a missing TypeAlias during an initial daemon load, put something there
assert (
5 changes: 5 additions & 0 deletions mypy/messages.py
@@ -2292,6 +2292,11 @@ def visit_instance(self, t: Instance) -> None:
self.instances.append(t)
super().visit_instance(t)

def visit_type_alias_type(self, t: TypeAliasType) -> None:
if t.alias and not t.is_recursive:
t.alias.target.accept(self)
super().visit_type_alias_type(t)


def find_type_overlaps(*types: Type) -> Set[str]:
"""Return a set of fullnames that share a short name and appear in either type.
25 changes: 25 additions & 0 deletions mypy/nodes.py
@@ -2656,6 +2657,7 @@ class is generic then it will be a type constructor of higher kind.
"bases",
"_promote",
"tuple_type",
"tuple_alias",
"is_named_tuple",
"typeddict_type",
"is_newtype",
@@ -2794,6 +2795,9 @@ class is generic then it will be a type constructor of higher kind.
# It is useful for plugins to add their data to save in the cache.
metadata: Dict[str, JsonDict]

# Store type alias representing this type (for named tuples).
tuple_alias: Optional["TypeAlias"]

FLAGS: Final = [
"is_abstract",
"is_enum",
@@ -2840,6 +2844,7 @@ def __init__(self, names: "SymbolTable", defn: ClassDef, module_name: str) -> No
self._promote = []
self.alt_promote = None
self.tuple_type = None
self.tuple_alias = None
self.is_named_tuple = False
self.typeddict_type = None
self.is_newtype = False
@@ -2970,6 +2975,15 @@ def direct_base_classes(self) -> "List[TypeInfo]":
"""
return [base.type for base in self.bases]

def update_tuple_type(self, typ: "mypy.types.TupleType") -> None:
"""Update tuple_type and tuple_alias as needed."""
self.tuple_type = typ
alias = TypeAlias.from_tuple_type(self)
if not self.tuple_alias:
self.tuple_alias = alias
else:
self.tuple_alias.target = alias.target

def __str__(self) -> str:
"""Return a string representation of the type.
@@ -3258,6 +3272,17 @@ def __init__(
self.eager = eager
super().__init__(line, column)

@classmethod
def from_tuple_type(cls, info: TypeInfo) -> "TypeAlias":
"""Generate an alias to the tuple type described by a given TypeInfo."""
assert info.tuple_type
return TypeAlias(
info.tuple_type.copy_modified(fallback=mypy.types.Instance(info, [])),
info.fullname,
info.line,
info.column,
)

@property
def name(self) -> str:
return self._fullname.split(".")[-1]
56 changes: 25 additions & 31 deletions mypy/semanal.py
@@ -221,11 +221,11 @@
PRIORITY_FALLBACKS,
SemanticAnalyzerInterface,
calculate_tuple_fallback,
has_placeholder,
set_callable_name as set_callable_name,
)
from mypy.semanal_typeddict import TypedDictAnalyzer
from mypy.tvar_scope import TypeVarLikeScope
from mypy.type_visitor import TypeQuery
from mypy.typeanal import (
TypeAnalyser,
TypeVarLikeList,
@@ -1425,7 +1425,12 @@ def analyze_class_body_common(self, defn: ClassDef) -> None:

def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
"""Check if this class can define a named tuple."""
if defn.info and defn.info.is_named_tuple:
if (
defn.info
and defn.info.is_named_tuple
and defn.info.tuple_type
and not has_placeholder(defn.info.tuple_type)
):
# Don't reprocess everything. We just need to process methods defined
# in the named tuple class body.
is_named_tuple, info = True, defn.info # type: bool, Optional[TypeInfo]
@@ -1782,10 +1787,9 @@ def configure_base_classes(
base_types: List[Instance] = []
info = defn.info

info.tuple_type = None
for base, base_expr in bases:
if isinstance(base, TupleType):
actual_base = self.configure_tuple_base_class(defn, base, base_expr)
actual_base = self.configure_tuple_base_class(defn, base)
base_types.append(actual_base)
elif isinstance(base, Instance):
if base.type.is_newtype:
@@ -1828,23 +1832,19 @@ def configure_base_classes(
return
self.calculate_class_mro(defn, self.object_type)

def configure_tuple_base_class(
self, defn: ClassDef, base: TupleType, base_expr: Expression
) -> Instance:
def configure_tuple_base_class(self, defn: ClassDef, base: TupleType) -> Instance:
info = defn.info

# There may be an existing valid tuple type from previous semanal iterations.
# Use equality to check if it is the case.
if info.tuple_type and info.tuple_type != base:
if info.tuple_type and info.tuple_type != base and not has_placeholder(info.tuple_type):
self.fail("Class has two incompatible bases derived from tuple", defn)
defn.has_incompatible_baseclass = True
info.tuple_type = base
if isinstance(base_expr, CallExpr):
defn.analyzed = NamedTupleExpr(base.partial_fallback.type)
defn.analyzed.line = defn.line
defn.analyzed.column = defn.column
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
self.defer(force_progress=True)
info.update_tuple_type(base)

if base.partial_fallback.type.fullname == "builtins.tuple":
if base.partial_fallback.type.fullname == "builtins.tuple" and not has_placeholder(base):
# Fallback can only be safely calculated after semantic analysis, since base
# classes may be incomplete. Postpone the calculation.
self.schedule_patch(PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(base))
@@ -2627,7 +2627,10 @@ def analyze_enum_assign(self, s: AssignmentStmt) -> bool:
def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
"""Check if s defines a namedtuple."""
if isinstance(s.rvalue, CallExpr) and isinstance(s.rvalue.analyzed, NamedTupleExpr):
return True # This is a valid and analyzed named tuple definition, nothing to do here.
if s.rvalue.analyzed.info.tuple_type and not has_placeholder(
s.rvalue.analyzed.info.tuple_type
):
return True # This is a valid and analyzed named tuple definition, nothing to do here.
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], (NameExpr, MemberExpr)):
return False
lvalue = s.lvalues[0]
@@ -3028,6 +3031,9 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
# unless using PEP 613 `cls: TypeAlias = A`
return False

if isinstance(s.rvalue, CallExpr) and s.rvalue.analyzed:
return False

existing = self.current_symbol_table().get(lvalue.name)
# Third rule: type aliases can't be re-defined. For example:
# A: Type[float] = int
@@ -3157,9 +3163,8 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> bool:
self.cannot_resolve_name(lvalue.name, "name", s)
return True
else:
self.progress = True
# We need to defer so that this change can get propagated to base classes.
self.defer(s)
self.defer(s, force_progress=True)
else:
self.add_symbol(lvalue.name, alias_node, s)
if isinstance(rvalue, RefExpr) and isinstance(rvalue.node, TypeAlias):
@@ -5484,7 +5489,7 @@ def tvar_scope_frame(self, frame: TypeVarLikeScope) -> Iterator[None]:
yield
self.tvar_scope = old_scope

def defer(self, debug_context: Optional[Context] = None) -> None:
def defer(self, debug_context: Optional[Context] = None, force_progress: bool = False) -> None:
"""Defer current analysis target to be analyzed again.
This must be called if something in the current target is
@@ -5498,6 +5503,8 @@ def defer(self, debug_context: Optional[Context] = None) -> None:
They are usually preferable to a direct defer() call.
"""
assert not self.final_iteration, "Must not defer during final iteration"
if force_progress:
self.progress = True
self.deferred = True
# Store debug info for this deferral.
line = (
@@ -5999,19 +6006,6 @@ def is_future_flag_set(self, flag: str) -> bool:
return self.modules[self.cur_mod_id].is_future_flag_set(flag)


class HasPlaceholders(TypeQuery[bool]):
def __init__(self) -> None:
super().__init__(any)

def visit_placeholder_type(self, t: PlaceholderType) -> bool:
return True


def has_placeholder(typ: Type) -> bool:
"""Check if a type contains any placeholder types (recursively)."""
return typ.accept(HasPlaceholders())


def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
if isinstance(sig, CallableType):
if len(sig.arg_types) == 0:
49 changes: 41 additions & 8 deletions mypy/semanal_namedtuple.py
@@ -44,6 +44,7 @@
PRIORITY_FALLBACKS,
SemanticAnalyzerInterface,
calculate_tuple_fallback,
has_placeholder,
set_callable_name,
)
from mypy.types import (
@@ -109,8 +110,11 @@ def analyze_namedtuple_classdef(
items, types, default_items = result
if is_func_scope and "@" not in defn.name:
defn.name += "@" + str(defn.line)
existing_info = None
if isinstance(defn.analyzed, NamedTupleExpr):
existing_info = defn.analyzed.info
info = self.build_namedtuple_typeinfo(
defn.name, items, types, default_items, defn.line
defn.name, items, types, default_items, defn.line, existing_info
)
defn.info = info
defn.analyzed = NamedTupleExpr(info, is_typed=True)
@@ -164,7 +168,14 @@ def check_namedtuple_classdef(
if stmt.type is None:
types.append(AnyType(TypeOfAny.unannotated))
else:
analyzed = self.api.anal_type(stmt.type)
# We never allow recursive types at function scope. Although it is
# possible to support this for named tuples, it is still tricky, and
# it would be inconsistent with type aliases.
analyzed = self.api.anal_type(
stmt.type,
allow_placeholder=self.options.enable_recursive_aliases
and not self.api.is_func_scope(),
)
if analyzed is None:
# Something is incomplete. We need to defer this named tuple.
return None
@@ -226,7 +237,7 @@ def check_namedtuple(
name += "@" + str(call.line)
else:
name = var_name = "namedtuple@" + str(call.line)
info = self.build_namedtuple_typeinfo(name, [], [], {}, node.line)
info = self.build_namedtuple_typeinfo(name, [], [], {}, node.line, None)
self.store_namedtuple_info(info, var_name, call, is_typed)
if name != var_name or is_func_scope:
# NOTE: we skip local namespaces since they are not serialized.
@@ -262,12 +273,22 @@ def check_namedtuple(
}
else:
default_items = {}
info = self.build_namedtuple_typeinfo(name, items, types, default_items, node.line)

existing_info = None
if isinstance(node.analyzed, NamedTupleExpr):
existing_info = node.analyzed.info
info = self.build_namedtuple_typeinfo(
name, items, types, default_items, node.line, existing_info
)

# If var_name is not None (i.e. this is not a base class expression), we always
# store the generated TypeInfo under var_name in the current scope, so that
# other definitions can use it.
if var_name:
self.store_namedtuple_info(info, var_name, call, is_typed)
else:
call.analyzed = NamedTupleExpr(info, is_typed=is_typed)
call.analyzed.set_line(call)
# There are three cases where we need to store the generated TypeInfo
# second time (for the purpose of serialization):
# * If there is a name mismatch like One = NamedTuple('Other', [...])
@@ -408,7 +429,12 @@ def parse_namedtuple_fields_with_types(
except TypeTranslationError:
self.fail("Invalid field type", type_node)
return None
analyzed = self.api.anal_type(type)
# We never allow recursive types at function scope.
analyzed = self.api.anal_type(
type,
allow_placeholder=self.options.enable_recursive_aliases
and not self.api.is_func_scope(),
)
# Workaround #4987 and avoid introducing a bogus UnboundType
if isinstance(analyzed, UnboundType):
analyzed = AnyType(TypeOfAny.from_error)
@@ -428,6 +454,7 @@ def build_namedtuple_typeinfo(
types: List[Type],
default_items: Mapping[str, Expression],
line: int,
existing_info: Optional[TypeInfo],
) -> TypeInfo:
strtype = self.api.named_type("builtins.str")
implicit_any = AnyType(TypeOfAny.special_form)
@@ -448,18 +475,23 @@ def build_namedtuple_typeinfo(
literals: List[Type] = [LiteralType(item, strtype) for item in items]
match_args_type = TupleType(literals, basetuple_type)

info = self.api.basic_new_typeinfo(name, fallback, line)
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
info.is_named_tuple = True
tuple_base = TupleType(types, fallback)
info.tuple_type = tuple_base
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
self.api.defer(force_progress=True)
info.update_tuple_type(tuple_base)
info.line = line
# For use by mypyc.
info.metadata["namedtuple"] = {"fields": items.copy()}

# We can't calculate the complete fallback type until after semantic
# analysis, since otherwise base classes might be incomplete. Postpone a
# callback function that patches the fallback.
self.api.schedule_patch(PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(tuple_base))
if not has_placeholder(tuple_base):
self.api.schedule_patch(
PRIORITY_FALLBACKS, lambda: calculate_tuple_fallback(tuple_base)
)

def add_field(
var: Var, is_initialized_in_class: bool = False, is_property: bool = False
@@ -489,6 +521,7 @@ def add_field(
if self.options.python_version >= (3, 10):
add_field(Var("__match_args__", match_args_type), is_initialized_in_class=True)

assert info.tuple_type is not None # Set by update_tuple_type() above.
tvd = TypeVarType(
SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], info.tuple_type
)

