Skip to content

Commit

Permalink
Support inferring Unpack mixed with other items (#12769)
Browse files Browse the repository at this point in the history
The main substance here modifies mypy/constraints.py to not assume that
template.items has length 1 in the case that there is an unpack. We
instead assume that that there is only a singular unpack, and do a
former pass to find what index it is in, and then resolve the unpack to
the corresponding subset of whatever tuple we are matching against.
  • Loading branch information
jhance authored Jun 10, 2022
1 parent 3833277 commit 9ccd081
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 17 deletions.
69 changes: 55 additions & 14 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,20 +702,46 @@ def visit_tuple_type(self, template: TupleType) -> List[Constraint]:
isinstance(actual, Instance)
and actual.type.fullname == "builtins.tuple"
)
if len(template.items) == 1:
item = get_proper_type(template.items[0])
if isinstance(item, UnpackType):
unpacked_type = get_proper_type(item.type)
if isinstance(unpacked_type, TypeVarTupleType):
if (
isinstance(actual, (TupleType, AnyType))
or is_varlength_tuple
):
return [Constraint(
type_var=unpacked_type.id,
op=self.direction,
target=actual,
)]
unpack_index = find_unpack_in_tuple(template)

if unpack_index is not None:
unpack_item = get_proper_type(template.items[unpack_index])
assert isinstance(unpack_item, UnpackType)

unpacked_type = get_proper_type(unpack_item.type)
if isinstance(unpacked_type, TypeVarTupleType):
if is_varlength_tuple:
# This case is only valid when the unpack is the only
# item in the tuple.
#
# TODO: We should support this in the case that all the items
# in the tuple besides the unpack have the same type as the
# varlength tuple's type. E.g. Tuple[int, ...] should be valid
# where we expect Tuple[int, Unpack[Ts]], but not for Tuple[str, Unpack[Ts]].
assert len(template.items) == 1

if (
isinstance(actual, (TupleType, AnyType))
or is_varlength_tuple
):
modified_actual = actual
if isinstance(actual, TupleType):
# Exclude the items from before and after the unpack index.
head = unpack_index
tail = len(template.items) - unpack_index - 1
if tail:
modified_actual = actual.copy_modified(
items=actual.items[head:-tail],
)
else:
modified_actual = actual.copy_modified(
items=actual.items[head:],
)
return [Constraint(
type_var=unpacked_type.id,
op=self.direction,
target=modified_actual,
)]

if isinstance(actual, TupleType) and len(actual.items) == len(template.items):
res: List[Constraint] = []
Expand Down Expand Up @@ -828,3 +854,18 @@ def find_matching_overload_items(overloaded: Overloaded,
# it maintains backward compatibility.
res = items[:]
return res


def find_unpack_in_tuple(t: TupleType) -> Optional[int]:
unpack_index: Optional[int] = None
for i, item in enumerate(t.items):
proper_item = get_proper_type(item)
if isinstance(proper_item, UnpackType):
# We cannot fail here, so we must check this in an earlier
# semanal phase.
# Funky code here avoids mypyc narrowing the type of unpack_index.
old_index = unpack_index
assert old_index is None
# Don't return so that we can also sanity check there is only one.
unpack_index = i
return unpack_index
2 changes: 2 additions & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ def expand_unpack(self, t: UnpackType) -> Optional[Union[List[Type], Instance, A
return repl
elif isinstance(repl, TypeVarTupleType):
return [UnpackType(typ=repl)]
elif isinstance(repl, UnpackType):
return [repl]
elif isinstance(repl, UninhabitedType):
return None
else:
Expand Down
2 changes: 1 addition & 1 deletion mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def visit_partial_type(self, t: PartialType) -> Type:
return t

def visit_unpack_type(self, t: UnpackType) -> Type:
return t.type.accept(self)
return UnpackType(t.type.accept(self))

def visit_callable_type(self, t: CallableType) -> Type:
return t.copy_modified(arg_types=self.translate_types(t.arg_types),
Expand Down
20 changes: 18 additions & 2 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
TupleType, Instance, FunctionLike, Type, CallableType, TypeVarLikeType, Overloaded,
TypeVarType, UninhabitedType, FormalArgument, UnionType, NoneType,
AnyType, TypeOfAny, TypeType, ProperType, LiteralType, get_proper_type, get_proper_types,
TypeAliasType, TypeQuery, ParamSpecType, Parameters, ENUM_REMOVED_PROPS
TypeAliasType, TypeQuery, ParamSpecType, Parameters, UnpackType, TypeVarTupleType,
ENUM_REMOVED_PROPS,
)
from mypy.nodes import (
FuncBase, FuncItem, FuncDef, OverloadedFuncDef, TypeInfo, ARG_STAR, ARG_STAR2, ARG_POS,
Expand Down Expand Up @@ -42,7 +43,22 @@ def tuple_fallback(typ: TupleType) -> Instance:
info = typ.partial_fallback.type
if info.fullname != 'builtins.tuple':
return typ.partial_fallback
return Instance(info, [join_type_list(typ.items)])
items = []
for item in typ.items:
proper_type = get_proper_type(item)
if isinstance(proper_type, UnpackType):
unpacked_type = get_proper_type(proper_type.type)
if isinstance(unpacked_type, TypeVarTupleType):
items.append(unpacked_type.upper_bound)
elif isinstance(unpacked_type, TupleType):
# TODO: might make sense to do recursion here to support nested unpacks
# of tuple constants
items.extend(unpacked_type.items)
else:
raise NotImplementedError
else:
items.append(item)
return Instance(info, [join_type_list(items)])


def type_object_type_from_function(signature: FunctionLike,
Expand Down
67 changes: 67 additions & 0 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,70 @@ reveal_type(g(args, args2)) # N: Revealed type is "Tuple[builtins.int, builtins
reveal_type(g(args, args3)) # N: Revealed type is "builtins.tuple[builtins.object, ...]"
reveal_type(g(any, any)) # N: Revealed type is "Any"
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleMixed]
from typing import Tuple
from typing_extensions import Unpack, TypeVarTuple

Ts = TypeVarTuple("Ts")

def to_str(i: int) -> str:
...

def f(a: Tuple[int, Unpack[Ts]]) -> Tuple[str, Unpack[Ts]]:
return (to_str(a[0]),) + a[1:]

def g(a: Tuple[Unpack[Ts], int]) -> Tuple[Unpack[Ts], str]:
return a[:-1] + (to_str(a[-1]),)

def h(a: Tuple[bool, int, Unpack[Ts], str, object]) -> Tuple[Unpack[Ts]]:
return a[2:-2]

empty = ()
bad_args: Tuple[str, str]
var_len_tuple: Tuple[int, ...]

f_args: Tuple[int, str]
f_args2: Tuple[int]
f_args3: Tuple[int, str, bool]

reveal_type(f(f_args)) # N: Revealed type is "Tuple[builtins.str, builtins.str]"
reveal_type(f(f_args2)) # N: Revealed type is "Tuple[builtins.str]"
reveal_type(f(f_args3)) # N: Revealed type is "Tuple[builtins.str, builtins.str, builtins.bool]"
f(empty) # E: Argument 1 to "f" has incompatible type "Tuple[]"; expected "Tuple[int]"
f(bad_args) # E: Argument 1 to "f" has incompatible type "Tuple[str, str]"; expected "Tuple[int, str]"
# TODO: This hits a crash where we assert len(templates.items) == 1. See visit_tuple_type
# in mypy/constraints.py.
#f(var_len_tuple)

g_args: Tuple[str, int]
reveal_type(g(g_args)) # N: Revealed type is "Tuple[builtins.str, builtins.str]"

h_args: Tuple[bool, int, str, int, str, object]
reveal_type(h(h_args)) # N: Revealed type is "Tuple[builtins.str, builtins.int]"
[builtins fixtures/tuple.pyi]

[case testTypeVarTupleChaining]
from typing import Tuple
from typing_extensions import Unpack, TypeVarTuple

Ts = TypeVarTuple("Ts")

def to_str(i: int) -> str:
...

def f(a: Tuple[int, Unpack[Ts]]) -> Tuple[str, Unpack[Ts]]:
return (to_str(a[0]),) + a[1:]

def g(a: Tuple[bool, int, Unpack[Ts], str, object]) -> Tuple[str, Unpack[Ts]]:
return f(a[1:-2])

def h(a: Tuple[bool, int, Unpack[Ts], str, object]) -> Tuple[str, Unpack[Ts]]:
x = f(a[1:-2])
return x

args: Tuple[bool, int, str, int, str, object]
reveal_type(g(args)) # N: Revealed type is "Tuple[builtins.str, builtins.str, builtins.int]"
reveal_type(h(args)) # N: Revealed type is "Tuple[builtins.str, builtins.str, builtins.int]"
[builtins fixtures/tuple.pyi]

0 comments on commit 9ccd081

Please sign in to comment.