Skip to content

Commit

Permalink
Support recursive TypedDicts (#13373)
Browse files Browse the repository at this point in the history
This is a continuation of #13297
Depends on #13371 

It was actually quite easy, essentially just a 1-to-1 mapping from the other PR.
  • Loading branch information
ilevkivskyi authored Aug 11, 2022
1 parent 601802c commit cba07d7
Show file tree
Hide file tree
Showing 11 changed files with 354 additions and 48 deletions.
15 changes: 11 additions & 4 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,17 @@ def lookup_fully_qualified_alias(
if isinstance(node, TypeAlias):
return node
elif isinstance(node, TypeInfo):
if node.tuple_alias:
return node.tuple_alias
alias = TypeAlias.from_tuple_type(node)
node.tuple_alias = alias
if node.special_alias:
# Already fixed up.
return node.special_alias
if node.tuple_type:
alias = TypeAlias.from_tuple_type(node)
elif node.typeddict_type:
alias = TypeAlias.from_typeddict_type(node)
else:
assert allow_missing
return missing_alias()
node.special_alias = alias
return alias
else:
# Looks like a missing TypeAlias during an initial daemon load, put something there
Expand Down
44 changes: 36 additions & 8 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2656,7 +2656,7 @@ class is generic then it will be a type constructor of higher kind.
"bases",
"_promote",
"tuple_type",
"tuple_alias",
"special_alias",
"is_named_tuple",
"typeddict_type",
"is_newtype",
Expand Down Expand Up @@ -2795,8 +2795,16 @@ class is generic then it will be a type constructor of higher kind.
# It is useful for plugins to add their data to save in the cache.
metadata: Dict[str, JsonDict]

# Store type alias representing this type (for named tuples).
tuple_alias: Optional["TypeAlias"]
# Store type alias representing this type (for named tuples and TypedDicts).
# Although definitions of these types are stored in symbol tables as TypeInfo,
# when a type analyzer will find them, it should construct a TupleType, or
# a TypedDict type. However, we can't use the plain types, since if the definition
# is recursive, this will create an actual recursive structure of types (i.e. as
# internal Python objects) causing infinite recursions everywhere during type checking.
# To overcome this, we create a TypeAlias node, that will point to these types.
# We store this node in the `special_alias` attribute, because it must be the same node
# in case we are doing multiple semantic analysis passes.
special_alias: Optional["TypeAlias"]

FLAGS: Final = [
"is_abstract",
Expand Down Expand Up @@ -2844,7 +2852,7 @@ def __init__(self, names: "SymbolTable", defn: ClassDef, module_name: str) -> No
self._promote = []
self.alt_promote = None
self.tuple_type = None
self.tuple_alias = None
self.special_alias = None
self.is_named_tuple = False
self.typeddict_type = None
self.is_newtype = False
Expand Down Expand Up @@ -2976,13 +2984,22 @@ def direct_base_classes(self) -> "List[TypeInfo]":
return [base.type for base in self.bases]

def update_tuple_type(self, typ: "mypy.types.TupleType") -> None:
"""Update tuple_type and tuple_alias as needed."""
"""Update tuple_type and special_alias as needed."""
self.tuple_type = typ
alias = TypeAlias.from_tuple_type(self)
if not self.tuple_alias:
self.tuple_alias = alias
if not self.special_alias:
self.special_alias = alias
else:
self.tuple_alias.target = alias.target
self.special_alias.target = alias.target

def update_typeddict_type(self, typ: "mypy.types.TypedDictType") -> None:
"""Update typeddict_type and special_alias as needed."""
self.typeddict_type = typ
alias = TypeAlias.from_typeddict_type(self)
if not self.special_alias:
self.special_alias = alias
else:
self.special_alias.target = alias.target

def __str__(self) -> str:
"""Return a string representation of the type.
Expand Down Expand Up @@ -3283,6 +3300,17 @@ def from_tuple_type(cls, info: TypeInfo) -> "TypeAlias":
info.column,
)

@classmethod
def from_typeddict_type(cls, info: TypeInfo) -> "TypeAlias":
"""Generate an alias to the TypedDict type described by a given TypeInfo."""
assert info.typeddict_type
return TypeAlias(
info.typeddict_type.copy_modified(fallback=mypy.types.Instance(info, [])),
info.fullname,
info.line,
info.column,
)

@property
def name(self) -> str:
return self._fullname.split(".")[-1]
Expand Down
47 changes: 34 additions & 13 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,17 +1378,7 @@ def analyze_class(self, defn: ClassDef) -> None:
self.mark_incomplete(defn.name, defn)
return

is_typeddict, info = self.typed_dict_analyzer.analyze_typeddict_classdef(defn)
if is_typeddict:
for decorator in defn.decorators:
decorator.accept(self)
if isinstance(decorator, RefExpr):
if decorator.fullname in FINAL_DECORATOR_NAMES:
self.fail("@final cannot be used with TypedDict", decorator)
if info is None:
self.mark_incomplete(defn.name, defn)
else:
self.prepare_class_def(defn, info)
if self.analyze_typeddict_classdef(defn):
return

if self.analyze_namedtuple_classdef(defn):
Expand Down Expand Up @@ -1423,6 +1413,28 @@ def analyze_class_body_common(self, defn: ClassDef) -> None:
self.apply_class_plugin_hooks(defn)
self.leave_class()

def analyze_typeddict_classdef(self, defn: ClassDef) -> bool:
if (
defn.info
and defn.info.typeddict_type
and not has_placeholder(defn.info.typeddict_type)
):
# This is a valid TypedDict, and it is fully analyzed.
return True
is_typeddict, info = self.typed_dict_analyzer.analyze_typeddict_classdef(defn)
if is_typeddict:
for decorator in defn.decorators:
decorator.accept(self)
if isinstance(decorator, RefExpr):
if decorator.fullname in FINAL_DECORATOR_NAMES:
self.fail("@final cannot be used with TypedDict", decorator)
if info is None:
self.mark_incomplete(defn.name, defn)
else:
self.prepare_class_def(defn, info)
return True
return False

def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
"""Check if this class can define a named tuple."""
if (
Expand Down Expand Up @@ -1840,7 +1852,7 @@ def configure_tuple_base_class(self, defn: ClassDef, base: TupleType) -> Instanc
if info.tuple_type and info.tuple_type != base and not has_placeholder(info.tuple_type):
self.fail("Class has two incompatible bases derived from tuple", defn)
defn.has_incompatible_baseclass = True
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
if info.special_alias and has_placeholder(info.special_alias.target):
self.defer(force_progress=True)
info.update_tuple_type(base)

Expand Down Expand Up @@ -2660,7 +2672,11 @@ def analyze_namedtuple_assign(self, s: AssignmentStmt) -> bool:
def analyze_typeddict_assign(self, s: AssignmentStmt) -> bool:
"""Check if s defines a typed dict."""
if isinstance(s.rvalue, CallExpr) and isinstance(s.rvalue.analyzed, TypedDictExpr):
return True # This is a valid and analyzed typed dict definition, nothing to do here.
if s.rvalue.analyzed.info.typeddict_type and not has_placeholder(
s.rvalue.analyzed.info.typeddict_type
):
# This is a valid and analyzed typed dict definition, nothing to do here.
return True
if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], (NameExpr, MemberExpr)):
return False
lvalue = s.lvalues[0]
Expand Down Expand Up @@ -5504,6 +5520,11 @@ def defer(self, debug_context: Optional[Context] = None, force_progress: bool =
"""
assert not self.final_iteration, "Must not defer during final iteration"
if force_progress:
# Usually, we report progress if we have replaced a placeholder node
# with an actual valid node. However, sometimes we need to update an
# existing node *in-place*. For example, this is used by type aliases
# in context of forward references and/or recursive aliases, and in
# similar situations (recursive named tuples etc).
self.progress = True
self.deferred = True
# Store debug info for this deferral.
Expand Down
2 changes: 1 addition & 1 deletion mypy/semanal_namedtuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def build_namedtuple_typeinfo(
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
info.is_named_tuple = True
tuple_base = TupleType(types, fallback)
if info.tuple_alias and has_placeholder(info.tuple_alias.target):
if info.special_alias and has_placeholder(info.special_alias.target):
self.api.defer(force_progress=True)
info.update_tuple_type(tuple_base)
info.line = line
Expand Down
7 changes: 5 additions & 2 deletions mypy/semanal_newtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ def process_newtype_declaration(self, s: AssignmentStmt) -> bool:

old_type, should_defer = self.check_newtype_args(var_name, call, s)
old_type = get_proper_type(old_type)
if not call.analyzed:
if not isinstance(call.analyzed, NewTypeExpr):
call.analyzed = NewTypeExpr(var_name, old_type, line=call.line, column=call.column)
else:
call.analyzed.old_type = old_type
if old_type is None:
if should_defer:
# Base type is not ready.
Expand Down Expand Up @@ -230,6 +232,7 @@ def build_newtype_typeinfo(
existing_info: Optional[TypeInfo],
) -> TypeInfo:
info = existing_info or self.api.basic_new_typeinfo(name, base_type, line)
info.bases = [base_type] # Update in case there were nested placeholders.
info.is_newtype = True

# Add __init__ method
Expand All @@ -250,7 +253,7 @@ def build_newtype_typeinfo(
init_func._fullname = info.fullname + ".__init__"
info.names["__init__"] = SymbolTableNode(MDEF, init_func)

if info.tuple_type and has_placeholder(info.tuple_type):
if has_placeholder(old_type) or info.tuple_type and has_placeholder(info.tuple_type):
self.api.defer(force_progress=True)
return info

Expand Down
49 changes: 39 additions & 10 deletions mypy/semanal_typeddict.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
TypeInfo,
)
from mypy.options import Options
from mypy.semanal_shared import SemanticAnalyzerInterface
from mypy.semanal_shared import SemanticAnalyzerInterface, has_placeholder
from mypy.typeanal import check_for_explicit_any, has_any_from_unimported_type
from mypy.types import TPDICT_NAMES, AnyType, RequiredType, Type, TypedDictType, TypeOfAny

Expand Down Expand Up @@ -66,6 +66,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
if base_expr.fullname in TPDICT_NAMES or self.is_typeddict(base_expr):
possible = True
if possible:
existing_info = None
if isinstance(defn.analyzed, TypedDictExpr):
existing_info = defn.analyzed.info
if (
len(defn.base_type_exprs) == 1
and isinstance(defn.base_type_exprs[0], RefExpr)
Expand All @@ -76,7 +79,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
if fields is None:
return True, None # Defer
info = self.build_typeddict_typeinfo(
defn.name, fields, types, required_keys, defn.line
defn.name, fields, types, required_keys, defn.line, existing_info
)
defn.analyzed = TypedDictExpr(info)
defn.analyzed.line = defn.line
Expand Down Expand Up @@ -128,7 +131,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> Tuple[bool, Optional[Typ
keys.extend(new_keys)
types.extend(new_types)
required_keys.update(new_required_keys)
info = self.build_typeddict_typeinfo(defn.name, keys, types, required_keys, defn.line)
info = self.build_typeddict_typeinfo(
defn.name, keys, types, required_keys, defn.line, existing_info
)
defn.analyzed = TypedDictExpr(info)
defn.analyzed.line = defn.line
defn.analyzed.column = defn.column
Expand Down Expand Up @@ -173,7 +178,12 @@ def analyze_typeddict_classdef_fields(
if stmt.type is None:
types.append(AnyType(TypeOfAny.unannotated))
else:
analyzed = self.api.anal_type(stmt.type, allow_required=True)
analyzed = self.api.anal_type(
stmt.type,
allow_required=True,
allow_placeholder=self.options.enable_recursive_aliases
and not self.api.is_func_scope(),
)
if analyzed is None:
return None, [], set() # Need to defer
types.append(analyzed)
Expand Down Expand Up @@ -232,7 +242,7 @@ def check_typeddict(
name, items, types, total, ok = res
if not ok:
# Error. Construct dummy return value.
info = self.build_typeddict_typeinfo("TypedDict", [], [], set(), call.line)
info = self.build_typeddict_typeinfo("TypedDict", [], [], set(), call.line, None)
else:
if var_name is not None and name != var_name:
self.fail(
Expand All @@ -254,7 +264,12 @@ def check_typeddict(
types = [ # unwrap Required[T] to just T
t.item if isinstance(t, RequiredType) else t for t in types # type: ignore[misc]
]
info = self.build_typeddict_typeinfo(name, items, types, required_keys, call.line)
existing_info = None
if isinstance(node.analyzed, TypedDictExpr):
existing_info = node.analyzed.info
info = self.build_typeddict_typeinfo(
name, items, types, required_keys, call.line, existing_info
)
info.line = node.line
# Store generated TypeInfo under both names, see semanal_namedtuple for more details.
if name != var_name or is_func_scope:
Expand Down Expand Up @@ -357,7 +372,12 @@ def parse_typeddict_fields_with_types(
else:
self.fail_typeddict_arg("Invalid field type", field_type_expr)
return [], [], False
analyzed = self.api.anal_type(type, allow_required=True)
analyzed = self.api.anal_type(
type,
allow_required=True,
allow_placeholder=self.options.enable_recursive_aliases
and not self.api.is_func_scope(),
)
if analyzed is None:
return None
types.append(analyzed)
Expand All @@ -370,7 +390,13 @@ def fail_typeddict_arg(
return "", [], [], True, False

def build_typeddict_typeinfo(
self, name: str, items: List[str], types: List[Type], required_keys: Set[str], line: int
self,
name: str,
items: List[str],
types: List[Type],
required_keys: Set[str],
line: int,
existing_info: Optional[TypeInfo],
) -> TypeInfo:
# Prefer typing then typing_extensions if available.
fallback = (
Expand All @@ -379,8 +405,11 @@ def build_typeddict_typeinfo(
or self.api.named_type_or_none("mypy_extensions._TypedDict", [])
)
assert fallback is not None
info = self.api.basic_new_typeinfo(name, fallback, line)
info.typeddict_type = TypedDictType(dict(zip(items, types)), required_keys, fallback)
info = existing_info or self.api.basic_new_typeinfo(name, fallback, line)
typeddict_type = TypedDictType(dict(zip(items, types)), required_keys, fallback)
if info.special_alias and has_placeholder(info.special_alias.target):
self.api.defer(force_progress=True)
info.update_typeddict_type(typeddict_type)
return info

# Helpers
Expand Down
16 changes: 8 additions & 8 deletions mypy/server/astmerge.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def replacement_map_from_symbol_table(
node.node.names, new_node.node.names, prefix
)
replacements.update(type_repl)
if node.node.tuple_alias and new_node.node.tuple_alias:
replacements[new_node.node.tuple_alias] = node.node.tuple_alias
if node.node.special_alias and new_node.node.special_alias:
replacements[new_node.node.special_alias] = node.node.special_alias
return replacements


Expand Down Expand Up @@ -338,10 +338,10 @@ def fixup(self, node: SN) -> SN:
new = self.replacements[node]
skip_slots: Tuple[str, ...] = ()
if isinstance(node, TypeInfo) and isinstance(new, TypeInfo):
# Special case: tuple_alias is not exposed in symbol tables, but may appear
# Special case: special_alias is not exposed in symbol tables, but may appear
# in external types (e.g. named tuples), so we need to update it manually.
skip_slots = ("tuple_alias",)
replace_object_state(new.tuple_alias, node.tuple_alias)
skip_slots = ("special_alias",)
replace_object_state(new.special_alias, node.special_alias)
replace_object_state(new, node, skip_slots=skip_slots)
return cast(SN, new)
return node
Expand Down Expand Up @@ -372,8 +372,8 @@ def process_type_info(self, info: Optional[TypeInfo]) -> None:
self.fixup_type(target)
self.fixup_type(info.tuple_type)
self.fixup_type(info.typeddict_type)
if info.tuple_alias:
self.fixup_type(info.tuple_alias.target)
if info.special_alias:
self.fixup_type(info.special_alias.target)
info.defn.info = self.fixup(info)
replace_nodes_in_symbol_table(info.names, self.replacements)
for i, item in enumerate(info.mro):
Expand Down Expand Up @@ -547,7 +547,7 @@ def replace_nodes_in_symbol_table(
new = replacements[node.node]
old = node.node
# Needed for TypeInfo, see comment in fixup() above.
replace_object_state(new, old, skip_slots=("tuple_alias",))
replace_object_state(new, old, skip_slots=("special_alias",))
node.node = new
if isinstance(node.node, (Var, TypeAlias)):
# Handle them here just in case these aren't exposed through the AST.
Expand Down
Loading

0 comments on commit cba07d7

Please sign in to comment.