Skip to content

Commit

Permalink
Handle prefix/suffix in typevartuple *args support (#14112)
Browse files Browse the repository at this point in the history
This requires handling more cases in the various places that we
previously modified to support *args in general. We also need to refresh
the formals-to-actuals twice in checkexpr as now it can happen in the
infer_function_type_arguments_using_context call.

The handling here is kind of asymmetric, because we can convert prefices
into positional arguments, but there is no equivalent for suffices, so
we represent that as a Tuple[Unpack[...], <suffix>] and handle that case
separately in some spots.

We also support various edge cases like passing in a tuple without any
typevartuples involved.
  • Loading branch information
jhance committed Nov 17, 2022
1 parent 48c4a47 commit 885e361
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 62 deletions.
83 changes: 50 additions & 33 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@

import mypy.subtypes
from mypy.expandtype import expand_type, expand_unpack_with_variables
from mypy.nodes import ARG_POS, ARG_STAR, Context
from mypy.nodes import ARG_STAR, Context
from mypy.types import (
AnyType,
CallableType,
Parameters,
ParamSpecType,
PartialType,
TupleType,
Type,
TypeVarId,
TypeVarLikeType,
Expand All @@ -19,6 +20,7 @@
UnpackType,
get_proper_type,
)
from mypy.typevartuples import find_unpack_in_list, replace_starargs


def get_target_type(
Expand Down Expand Up @@ -114,39 +116,57 @@ def apply_generic_arguments(
# Apply arguments to argument types.
var_arg = callable.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
expanded = expand_unpack_with_variables(var_arg.typ, id_to_type)
assert isinstance(expanded, list)
# Handle other cases later.
for t in expanded:
assert not isinstance(t, UnpackType)
star_index = callable.arg_kinds.index(ARG_STAR)
arg_kinds = (
callable.arg_kinds[:star_index]
+ [ARG_POS] * len(expanded)
+ callable.arg_kinds[star_index + 1 :]
callable = callable.copy_modified(
arg_types=(
[
expand_type(at, id_to_type, allow_erased_callables)
for at in callable.arg_types[:star_index]
]
+ [callable.arg_types[star_index]]
+ [
expand_type(at, id_to_type, allow_erased_callables)
for at in callable.arg_types[star_index + 1 :]
]
)
)
arg_names = (
callable.arg_names[:star_index]
+ [None] * len(expanded)
+ callable.arg_names[star_index + 1 :]
)
arg_types = (
[
expand_type(at, id_to_type, allow_erased_callables)
for at in callable.arg_types[:star_index]
]
+ expanded
+ [
expand_type(at, id_to_type, allow_erased_callables)
for at in callable.arg_types[star_index + 1 :]

unpacked_type = get_proper_type(var_arg.typ.type)
if isinstance(unpacked_type, TupleType):
# Assuming for now that because we convert prefixes to positional arguments,
# the first argument is always an unpack.
expanded_tuple = expand_type(unpacked_type, id_to_type)
if isinstance(expanded_tuple, TupleType):
# TODO: handle the case where the tuple has an unpack. This will
# hit an assert below.
expanded_unpack = find_unpack_in_list(expanded_tuple.items)
if expanded_unpack is not None:
callable = callable.copy_modified(
arg_types=(
callable.arg_types[:star_index]
+ [expanded_tuple]
+ callable.arg_types[star_index + 1 :]
)
)
else:
callable = replace_starargs(callable, expanded_tuple.items)
else:
# TODO: handle the case for if we get a variable length tuple.
assert False, f"mypy bug: unimplemented case, {expanded_tuple}"
elif isinstance(unpacked_type, TypeVarTupleType):
expanded_tvt = expand_unpack_with_variables(var_arg.typ, id_to_type)
assert isinstance(expanded_tvt, list)
for t in expanded_tvt:
assert not isinstance(t, UnpackType)
callable = replace_starargs(callable, expanded_tvt)
else:
assert False, "mypy bug: unhandled case applying unpack"
else:
callable = callable.copy_modified(
arg_types=[
expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types
]
)
else:
arg_types = [
expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types
]
arg_kinds = callable.arg_kinds
arg_names = callable.arg_names

# Apply arguments to TypeGuard if any.
if callable.type_guard is not None:
Expand All @@ -158,10 +178,7 @@ def apply_generic_arguments(
remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type]

return callable.copy_modified(
arg_types=arg_types,
ret_type=expand_type(callable.ret_type, id_to_type, allow_erased_callables),
variables=remaining_tvars,
type_guard=type_guard,
arg_kinds=arg_kinds,
arg_names=arg_names,
)
17 changes: 11 additions & 6 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,12 +1178,17 @@ def check_func_def(
if isinstance(arg_type, ParamSpecType):
pass
elif isinstance(arg_type, UnpackType):
arg_type = TupleType(
[arg_type],
fallback=self.named_generic_type(
"builtins.tuple", [self.named_type("builtins.object")]
),
)
if isinstance(get_proper_type(arg_type.type), TupleType):
# Instead of using Tuple[Unpack[Tuple[...]]], just use
# Tuple[...]
arg_type = arg_type.type
else:
arg_type = TupleType(
[arg_type],
fallback=self.named_generic_type(
"builtins.tuple", [self.named_type("builtins.object")]
),
)
else:
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
Expand Down
83 changes: 74 additions & 9 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
TypeVarType,
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_unions,
get_proper_type,
get_proper_types,
Expand Down Expand Up @@ -1404,13 +1405,21 @@ def check_callable_call(
)
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(callee, context)
if need_refresh:
# Argument kinds etc. may have changed due to
# ParamSpec or TypeVarTuple variables being replaced with an arbitrary
# number of arguments; recalculate actual-to-formal map
formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
callee.arg_kinds,
callee.arg_names,
lambda i: self.accept(args[i]),
)
callee = self.infer_function_type_arguments(
callee, args, arg_kinds, formal_to_actual, context
)
if need_refresh:
# Argument kinds etc. may have changed due to
# ParamSpec variables being replaced with an arbitrary
# number of arguments; recalculate actual-to-formal map
formal_to_actual = map_actuals_to_formals(
arg_kinds,
arg_names,
Expand Down Expand Up @@ -1999,11 +2008,66 @@ def check_argument_types(
# Keep track of consumed tuple *arg items.
mapper = ArgTypeExpander(self.argument_infer_context())
for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
actual_type = arg_types[actual]
orig_callee_arg_type = get_proper_type(callee.arg_types[i])

# Checking the case that we have more than one item but the first argument
# is an unpack, so this would be something like:
# [Tuple[Unpack[Ts]], int]
#
# In this case we have to check everything together, we do this by re-unifying
# the suffices to the tuple, e.g. a single actual like
# Tuple[Unpack[Ts], int]
expanded_tuple = False
if len(actuals) > 1:
first_actual_arg_type = get_proper_type(arg_types[actuals[0]])
if (
isinstance(first_actual_arg_type, TupleType)
and len(first_actual_arg_type.items) == 1
and isinstance(get_proper_type(first_actual_arg_type.items[0]), UnpackType)
):
# TODO: use walrus operator
actual_types = [first_actual_arg_type.items[0]] + [
arg_types[a] for a in actuals[1:]
]
actual_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (len(actuals) - 1)

assert isinstance(orig_callee_arg_type, TupleType)
assert orig_callee_arg_type.items
callee_arg_types = orig_callee_arg_type.items
callee_arg_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (
len(orig_callee_arg_type.items) - 1
)
expanded_tuple = True

if not expanded_tuple:
actual_types = [arg_types[a] for a in actuals]
actual_kinds = [arg_kinds[a] for a in actuals]
if isinstance(orig_callee_arg_type, UnpackType):
unpacked_type = get_proper_type(orig_callee_arg_type.type)
# Only case we know of thus far.
assert isinstance(unpacked_type, TupleType)
actual_types = [arg_types[a] for a in actuals]
actual_kinds = [arg_kinds[a] for a in actuals]
callee_arg_types = unpacked_type.items
callee_arg_kinds = [ARG_POS] * len(actuals)
else:
callee_arg_types = [orig_callee_arg_type] * len(actuals)
callee_arg_kinds = [callee.arg_kinds[i]] * len(actuals)

assert len(actual_types) == len(actuals) == len(actual_kinds)

if len(callee_arg_types) != len(actual_types):
# TODO: Improve error message
self.chk.fail("Invalid number of arguments", context)
continue

assert len(callee_arg_types) == len(actual_types)
assert len(callee_arg_types) == len(callee_arg_kinds)
for actual, actual_type, actual_kind, callee_arg_type, callee_arg_kind in zip(
actuals, actual_types, actual_kinds, callee_arg_types, callee_arg_kinds
):
if actual_type is None:
continue # Some kind of error was already reported.
actual_kind = arg_kinds[actual]
# Check that a *arg is valid as varargs.
if actual_kind == nodes.ARG_STAR and not self.is_valid_var_arg(actual_type):
self.msg.invalid_var_arg(actual_type, context)
Expand All @@ -2013,13 +2077,13 @@ def check_argument_types(
is_mapping = is_subtype(actual_type, self.chk.named_type("typing.Mapping"))
self.msg.invalid_keyword_var_arg(actual_type, is_mapping, context)
expanded_actual = mapper.expand_actual_type(
actual_type, actual_kind, callee.arg_names[i], callee.arg_kinds[i]
actual_type, actual_kind, callee.arg_names[i], callee_arg_kind
)
check_arg(
expanded_actual,
actual_type,
arg_kinds[actual],
callee.arg_types[i],
actual_kind,
callee_arg_type,
actual + 1,
i + 1,
callee,
Expand Down Expand Up @@ -4719,6 +4783,7 @@ def is_valid_var_arg(self, typ: Type) -> bool:
)
or isinstance(typ, AnyType)
or isinstance(typ, ParamSpecType)
or isinstance(typ, UnpackType)
)

def is_valid_keyword_var_arg(self, typ: Type) -> bool:
Expand Down
22 changes: 20 additions & 2 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,26 @@ def infer_constraints_for_callable(
)
)

assert isinstance(unpack_type.type, TypeVarTupleType)
constraints.append(Constraint(unpack_type.type, SUPERTYPE_OF, TypeList(actual_types)))
unpacked_type = get_proper_type(unpack_type.type)
if isinstance(unpacked_type, TypeVarTupleType):
constraints.append(Constraint(unpacked_type, SUPERTYPE_OF, TypeList(actual_types)))
elif isinstance(unpacked_type, TupleType):
# Prefixes get converted to positional args, so technically the only case we
# should have here is like Tuple[Unpack[Ts], Y1, Y2, Y3]. If this turns out
# not to hold we can always handle the prefixes too.
inner_unpack = unpacked_type.items[0]
assert isinstance(inner_unpack, UnpackType)
inner_unpacked_type = get_proper_type(inner_unpack.type)
assert isinstance(inner_unpacked_type, TypeVarTupleType)
suffix_len = len(unpacked_type.items) - 1
constraints.append(
Constraint(
inner_unpacked_type, SUPERTYPE_OF, TypeList(actual_types[:-suffix_len])
)
)
else:
assert False, "mypy bug: unhandled constraint inference case"

else:
for actual in actuals:
actual_arg_type = arg_types[actual]
Expand Down
Loading

0 comments on commit 885e361

Please sign in to comment.