Skip to content

Commit

Permalink
Enable generic TypedDicts (#13389)
Browse files Browse the repository at this point in the history
Fixes #3863

This builds on top of some infra I added for recursive types (Ref #13297). Implementation is quite straightforward. The only non-trivial thing is that when extending/merging TypedDicts, the item types need to me mapped to supertype during semantic analysis. This means we can't call `is_subtype()` etc., and can in theory get types like `Union[int, int]`. But OTOH this equally applies to type aliases, and doesn't seem to cause problems.
  • Loading branch information
ilevkivskyi authored Aug 15, 2022
1 parent 8deeaf3 commit dc5f891
Show file tree
Hide file tree
Showing 18 changed files with 630 additions and 129 deletions.
1 change: 1 addition & 0 deletions misc/proper_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def is_special_target(right: ProperType) -> bool:
"mypy.types.PartialType",
"mypy.types.ErasedType",
"mypy.types.DeletedType",
"mypy.types.RequiredType",
):
# Special case: these are not valid targets for a type alias and thus safe.
# TODO: introduce a SyntheticType base to simplify this?
Expand Down
164 changes: 141 additions & 23 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,12 @@ def __init__(self, chk: mypy.checker.TypeChecker, msg: MessageBuilder, plugin: P

self.resolved_type = {}

# Callee in a call expression is in some sense both runtime context and
# type context, because we support things like C[int](...). Store information
# on whether current expression is a callee, to give better error messages
# related to type context.
self.is_callee = False

def reset(self) -> None:
self.resolved_type = {}

Expand Down Expand Up @@ -319,7 +325,11 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
result = node.type
elif isinstance(node, TypeInfo):
# Reference to a type object.
result = type_object_type(node, self.named_type)
if node.typeddict_type:
# We special-case TypedDict, because they don't define any constructor.
result = self.typeddict_callable(node)
else:
result = type_object_type(node, self.named_type)
if isinstance(result, CallableType) and isinstance( # type: ignore
result.ret_type, Instance
):
Expand Down Expand Up @@ -386,17 +396,29 @@ def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
return self.accept(e.analyzed, self.type_context[-1])
return self.visit_call_expr_inner(e, allow_none_return=allow_none_return)

def refers_to_typeddict(self, base: Expression) -> bool:
if not isinstance(base, RefExpr):
return False
if isinstance(base.node, TypeInfo) and base.node.typeddict_type is not None:
# Direct reference.
return True
return isinstance(base.node, TypeAlias) and isinstance(
get_proper_type(base.node.target), TypedDictType
)

def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) -> Type:
if (
isinstance(e.callee, RefExpr)
and isinstance(e.callee.node, TypeInfo)
and e.callee.node.typeddict_type is not None
self.refers_to_typeddict(e.callee)
or isinstance(e.callee, IndexExpr)
and self.refers_to_typeddict(e.callee.base)
):
# Use named fallback for better error messages.
typeddict_type = e.callee.node.typeddict_type.copy_modified(
fallback=Instance(e.callee.node, [])
)
return self.check_typeddict_call(typeddict_type, e.arg_kinds, e.arg_names, e.args, e)
typeddict_callable = get_proper_type(self.accept(e.callee, is_callee=True))
if isinstance(typeddict_callable, CallableType):
typeddict_type = get_proper_type(typeddict_callable.ret_type)
assert isinstance(typeddict_type, TypedDictType)
return self.check_typeddict_call(
typeddict_type, e.arg_kinds, e.arg_names, e.args, e, typeddict_callable
)
if (
isinstance(e.callee, NameExpr)
and e.callee.name in ("isinstance", "issubclass")
Expand Down Expand Up @@ -457,7 +479,9 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
ret_type=self.object_type(),
fallback=self.named_type("builtins.function"),
)
callee_type = get_proper_type(self.accept(e.callee, type_context, always_allow_any=True))
callee_type = get_proper_type(
self.accept(e.callee, type_context, always_allow_any=True, is_callee=True)
)
if (
self.chk.options.disallow_untyped_calls
and self.chk.in_checked_function()
Expand Down Expand Up @@ -628,28 +652,33 @@ def check_typeddict_call(
arg_names: Sequence[Optional[str]],
args: List[Expression],
context: Context,
orig_callee: Optional[Type],
) -> Type:
if len(args) >= 1 and all([ak == ARG_NAMED for ak in arg_kinds]):
# ex: Point(x=42, y=1337)
assert all(arg_name is not None for arg_name in arg_names)
item_names = cast(List[str], arg_names)
item_args = args
return self.check_typeddict_call_with_kwargs(
callee, dict(zip(item_names, item_args)), context
callee, dict(zip(item_names, item_args)), context, orig_callee
)

if len(args) == 1 and arg_kinds[0] == ARG_POS:
unique_arg = args[0]
if isinstance(unique_arg, DictExpr):
# ex: Point({'x': 42, 'y': 1337})
return self.check_typeddict_call_with_dict(callee, unique_arg, context)
return self.check_typeddict_call_with_dict(
callee, unique_arg, context, orig_callee
)
if isinstance(unique_arg, CallExpr) and isinstance(unique_arg.analyzed, DictExpr):
# ex: Point(dict(x=42, y=1337))
return self.check_typeddict_call_with_dict(callee, unique_arg.analyzed, context)
return self.check_typeddict_call_with_dict(
callee, unique_arg.analyzed, context, orig_callee
)

if len(args) == 0:
# ex: EmptyDict()
return self.check_typeddict_call_with_kwargs(callee, {}, context)
return self.check_typeddict_call_with_kwargs(callee, {}, context, orig_callee)

self.chk.fail(message_registry.INVALID_TYPEDDICT_ARGS, context)
return AnyType(TypeOfAny.from_error)
Expand Down Expand Up @@ -683,18 +712,59 @@ def match_typeddict_call_with_dict(
return False

def check_typeddict_call_with_dict(
self, callee: TypedDictType, kwargs: DictExpr, context: Context
self,
callee: TypedDictType,
kwargs: DictExpr,
context: Context,
orig_callee: Optional[Type],
) -> Type:
validated_kwargs = self.validate_typeddict_kwargs(kwargs=kwargs)
if validated_kwargs is not None:
return self.check_typeddict_call_with_kwargs(
callee, kwargs=validated_kwargs, context=context
callee, kwargs=validated_kwargs, context=context, orig_callee=orig_callee
)
else:
return AnyType(TypeOfAny.from_error)

def typeddict_callable(self, info: TypeInfo) -> CallableType:
"""Construct a reasonable type for a TypedDict type in runtime context.
If it appears as a callee, it will be special-cased anyway, e.g. it is
also allowed to accept a single positional argument if it is a dict literal.
Note it is not safe to move this to type_object_type() since it will crash
on plugin-generated TypedDicts, that may not have the special_alias.
"""
assert info.special_alias is not None
target = info.special_alias.target
assert isinstance(target, ProperType) and isinstance(target, TypedDictType)
expected_types = list(target.items.values())
kinds = [ArgKind.ARG_NAMED] * len(expected_types)
names = list(target.items.keys())
return CallableType(
expected_types,
kinds,
names,
target,
self.named_type("builtins.type"),
variables=info.defn.type_vars,
)

def typeddict_callable_from_context(self, callee: TypedDictType) -> CallableType:
return CallableType(
list(callee.items.values()),
[ArgKind.ARG_NAMED] * len(callee.items),
list(callee.items.keys()),
callee,
self.named_type("builtins.type"),
)

def check_typeddict_call_with_kwargs(
self, callee: TypedDictType, kwargs: Dict[str, Expression], context: Context
self,
callee: TypedDictType,
kwargs: Dict[str, Expression],
context: Context,
orig_callee: Optional[Type],
) -> Type:
if not (callee.required_keys <= set(kwargs.keys()) <= set(callee.items.keys())):
expected_keys = [
Expand All @@ -708,7 +778,38 @@ def check_typeddict_call_with_kwargs(
)
return AnyType(TypeOfAny.from_error)

for (item_name, item_expected_type) in callee.items.items():
orig_callee = get_proper_type(orig_callee)
if isinstance(orig_callee, CallableType):
infer_callee = orig_callee
else:
# Try reconstructing from type context.
if callee.fallback.type.special_alias is not None:
infer_callee = self.typeddict_callable(callee.fallback.type)
else:
# Likely a TypedDict type generated by a plugin.
infer_callee = self.typeddict_callable_from_context(callee)

# We don't show any errors, just infer types in a generic TypedDict type,
# a custom error message will be given below, if there are errors.
with self.msg.filter_errors(), self.chk.local_type_map():
orig_ret_type, _ = self.check_callable_call(
infer_callee,
list(kwargs.values()),
[ArgKind.ARG_NAMED] * len(kwargs),
context,
list(kwargs.keys()),
None,
None,
None,
)

ret_type = get_proper_type(orig_ret_type)
if not isinstance(ret_type, TypedDictType):
# If something went really wrong, type-check call with original type,
# this may give a better error message.
ret_type = callee

for (item_name, item_expected_type) in ret_type.items.items():
if item_name in kwargs:
item_value = kwargs[item_name]
self.chk.check_simple_assignment(
Expand All @@ -721,7 +822,7 @@ def check_typeddict_call_with_kwargs(
code=codes.TYPEDDICT_ITEM,
)

return callee
return orig_ret_type

def get_partial_self_var(self, expr: MemberExpr) -> Optional[Var]:
"""Get variable node for a partial self attribute.
Expand Down Expand Up @@ -2547,7 +2648,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr, is_lvalue: bool) -> Type
return self.analyze_ref_expr(e)
else:
# This is a reference to a non-module attribute.
original_type = self.accept(e.expr)
original_type = self.accept(e.expr, is_callee=self.is_callee)
base = e.expr
module_symbol_table = None

Expand Down Expand Up @@ -3670,6 +3771,8 @@ def visit_type_application(self, tapp: TypeApplication) -> Type:
elif isinstance(item, TupleType) and item.partial_fallback.type.is_named_tuple:
tp = type_object_type(item.partial_fallback.type, self.named_type)
return self.apply_type_arguments_to_callable(tp, item.partial_fallback.args, tapp)
elif isinstance(item, TypedDictType):
return self.typeddict_callable_from_context(item)
else:
self.chk.fail(message_registry.ONLY_CLASS_APPLICATION, tapp)
return AnyType(TypeOfAny.from_error)
Expand Down Expand Up @@ -3723,7 +3826,12 @@ class LongName(Generic[T]): ...
# For example:
# A = List[Tuple[T, T]]
# x = A() <- same as List[Tuple[Any, Any]], see PEP 484.
item = get_proper_type(set_any_tvars(alias, ctx.line, ctx.column))
disallow_any = self.chk.options.disallow_any_generics and self.is_callee
item = get_proper_type(
set_any_tvars(
alias, ctx.line, ctx.column, disallow_any=disallow_any, fail=self.msg.fail
)
)
if isinstance(item, Instance):
# Normally we get a callable type (or overloaded) with .is_type_obj() true
# representing the class's constructor
Expand All @@ -3738,6 +3846,8 @@ class LongName(Generic[T]): ...
tuple_fallback(item).type.fullname != "builtins.tuple"
):
return type_object_type(tuple_fallback(item).type, self.named_type)
elif isinstance(item, TypedDictType):
return self.typeddict_callable_from_context(item)
elif isinstance(item, AnyType):
return AnyType(TypeOfAny.from_another_any, source_any=item)
else:
Expand Down Expand Up @@ -3962,7 +4072,12 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
# to avoid the second error, we always return TypedDict type that was requested
typeddict_context = self.find_typeddict_context(self.type_context[-1], e)
if typeddict_context:
self.check_typeddict_call_with_dict(callee=typeddict_context, kwargs=e, context=e)
orig_ret_type = self.check_typeddict_call_with_dict(
callee=typeddict_context, kwargs=e, context=e, orig_callee=None
)
ret_type = get_proper_type(orig_ret_type)
if isinstance(ret_type, TypedDictType):
return ret_type.copy_modified()
return typeddict_context.copy_modified()

# fast path attempt
Expand Down Expand Up @@ -4494,6 +4609,7 @@ def accept(
type_context: Optional[Type] = None,
allow_none_return: bool = False,
always_allow_any: bool = False,
is_callee: bool = False,
) -> Type:
"""Type check a node in the given type context. If allow_none_return
is True and this expression is a call, allow it to return None. This
Expand All @@ -4502,6 +4618,8 @@ def accept(
if node in self.type_overrides:
return self.type_overrides[node]
self.type_context.append(type_context)
old_is_callee = self.is_callee
self.is_callee = is_callee
try:
if allow_none_return and isinstance(node, CallExpr):
typ = self.visit_call_expr(node, allow_none_return=True)
Expand All @@ -4517,7 +4635,7 @@ def accept(
report_internal_error(
err, self.chk.errors.file, node.line, self.chk.errors, self.chk.options
)

self.is_callee = old_is_callee
self.type_context.pop()
assert typ is not None
self.chk.store_type(node, typ)
Expand Down
2 changes: 2 additions & 0 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,8 @@ def analyze_type_callable_member_access(name: str, typ: FunctionLike, mx: Member
assert isinstance(ret_type, ProperType)
if isinstance(ret_type, TupleType):
ret_type = tuple_fallback(ret_type)
if isinstance(ret_type, TypedDictType):
ret_type = ret_type.fallback
if isinstance(ret_type, Instance):
if not mx.is_operator:
# When Python sees an operator (eg `3 == 4`), it automatically translates that
Expand Down
6 changes: 5 additions & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,11 @@ def visit_tuple_type(self, t: TupleType) -> Type:
return items

def visit_typeddict_type(self, t: TypedDictType) -> Type:
return t.copy_modified(item_types=self.expand_types(t.items.values()))
fallback = t.fallback.accept(self)
fallback = get_proper_type(fallback)
if not isinstance(fallback, Instance):
fallback = t.fallback
return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)

def visit_literal_type(self, t: LiteralType) -> Type:
# TODO: Verify this implementation is correct
Expand Down
1 change: 1 addition & 0 deletions mypy/fixup.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def visit_type_info(self, info: TypeInfo) -> None:
info.update_tuple_type(info.tuple_type)
if info.typeddict_type:
info.typeddict_type.accept(self.type_fixer)
info.update_typeddict_type(info.typeddict_type)
if info.declared_metaclass:
info.declared_metaclass.accept(self.type_fixer)
if info.metaclass_type:
Expand Down
8 changes: 5 additions & 3 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3192,8 +3192,8 @@ class TypeAlias(SymbolNode):
following:
1. An alias targeting a generic class without explicit variables act as
the given class (this doesn't apply to Tuple and Callable, which are not proper
classes but special type constructors):
the given class (this doesn't apply to TypedDict, Tuple and Callable, which
are not proper classes but special type constructors):
A = List
AA = List[Any]
Expand Down Expand Up @@ -3305,7 +3305,9 @@ def from_typeddict_type(cls, info: TypeInfo) -> TypeAlias:
"""Generate an alias to the TypedDict type described by a given TypeInfo."""
assert info.typeddict_type
return TypeAlias(
info.typeddict_type.copy_modified(fallback=mypy.types.Instance(info, [])),
info.typeddict_type.copy_modified(
fallback=mypy.types.Instance(info, info.defn.type_vars)
),
info.fullname,
info.line,
info.column,
Expand Down
Loading

0 comments on commit dc5f891

Please sign in to comment.