Skip to content

Commit

Permalink
Implement basic *args support for variadic generics (#13889)
Browse files Browse the repository at this point in the history
This implements the most basic support for the *args feature but various
edge cases are not handled in this PR because of the large volume of
places that needed to be modified to support this.

In particular, we need to special handle the ARG_STAR argument in
several places for the case where the type is a UnpackType. Finally when
we actually check a function we need to construct a TupleType instead of
a builtins.tuple.
  • Loading branch information
jhance authored Oct 17, 2022
1 parent f12faae commit c810a9c
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 41 deletions.
35 changes: 32 additions & 3 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Callable, Sequence

import mypy.subtypes
from mypy.expandtype import expand_type
from mypy.nodes import Context
from mypy.expandtype import expand_type, expand_unpack_with_variables
from mypy.nodes import ARG_POS, ARG_STAR, Context
from mypy.types import (
AnyType,
CallableType,
Expand All @@ -16,6 +16,7 @@
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnpackType,
get_proper_type,
)

Expand Down Expand Up @@ -110,7 +111,33 @@ def apply_generic_arguments(
callable = callable.expand_param_spec(nt)

# Apply arguments to argument types.
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
var_arg = callable.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
expanded = expand_unpack_with_variables(var_arg.typ, id_to_type)
assert isinstance(expanded, list)
# Handle other cases later.
for t in expanded:
assert not isinstance(t, UnpackType)
star_index = callable.arg_kinds.index(ARG_STAR)
arg_kinds = (
callable.arg_kinds[:star_index]
+ [ARG_POS] * len(expanded)
+ callable.arg_kinds[star_index + 1 :]
)
arg_names = (
callable.arg_names[:star_index]
+ [None] * len(expanded)
+ callable.arg_names[star_index + 1 :]
)
arg_types = (
[expand_type(at, id_to_type) for at in callable.arg_types[:star_index]]
+ expanded
+ [expand_type(at, id_to_type) for at in callable.arg_types[star_index + 1 :]]
)
else:
arg_types = [expand_type(at, id_to_type) for at in callable.arg_types]
arg_kinds = callable.arg_kinds
arg_names = callable.arg_names

# Apply arguments to TypeGuard if any.
if callable.type_guard is not None:
Expand All @@ -126,4 +153,6 @@ def apply_generic_arguments(
ret_type=expand_type(callable.ret_type, id_to_type),
variables=remaining_tvars,
type_guard=type_guard,
arg_kinds=arg_kinds,
arg_names=arg_names,
)
12 changes: 11 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_unions,
get_proper_type,
get_proper_types,
Expand Down Expand Up @@ -1170,7 +1171,16 @@ def check_func_def(
ctx = typ
self.fail(message_registry.FUNCTION_PARAMETER_CANNOT_BE_COVARIANT, ctx)
if typ.arg_kinds[i] == nodes.ARG_STAR:
if not isinstance(arg_type, ParamSpecType):
if isinstance(arg_type, ParamSpecType):
pass
elif isinstance(arg_type, UnpackType):
arg_type = TupleType(
[arg_type],
fallback=self.named_generic_type(
"builtins.tuple", [self.named_type("builtins.object")]
),
)
else:
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
Expand Down
5 changes: 4 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
UnionType,
Expand Down Expand Up @@ -1397,7 +1398,9 @@ def check_callable_call(
)

if callee.is_generic():
need_refresh = any(isinstance(v, ParamSpecType) for v in callee.variables)
need_refresh = any(
isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables
)
callee = freshen_function_type_vars(callee)
callee = self.infer_function_type_arguments_using_context(callee, context)
callee = self.infer_function_type_arguments(
Expand Down
44 changes: 34 additions & 10 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,41 @@ def infer_constraints_for_callable(
mapper = ArgTypeExpander(context)

for i, actuals in enumerate(formal_to_actual):
for actual in actuals:
actual_arg_type = arg_types[actual]
if actual_arg_type is None:
continue
if isinstance(callee.arg_types[i], UnpackType):
unpack_type = callee.arg_types[i]
assert isinstance(unpack_type, UnpackType)

# In this case we are binding all of the actuals to *args
# and we want a constraint that the typevar tuple being unpacked
# is equal to a type list of all the actuals.
actual_types = []
for actual in actuals:
actual_arg_type = arg_types[actual]
if actual_arg_type is None:
continue

actual_type = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)
actual_types.append(
mapper.expand_actual_type(
actual_arg_type,
arg_kinds[actual],
callee.arg_names[i],
callee.arg_kinds[i],
)
)

assert isinstance(unpack_type.type, TypeVarTupleType)
constraints.append(Constraint(unpack_type.type, SUPERTYPE_OF, TypeList(actual_types)))
else:
for actual in actuals:
actual_arg_type = arg_types[actual]
if actual_arg_type is None:
continue

actual_type = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)

return constraints

Expand Down Expand Up @@ -165,7 +190,6 @@ def infer_constraints(template: Type, actual: Type, direction: int) -> list[Cons


def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Constraint]:

orig_template = template
template = get_proper_type(template)
actual = get_proper_type(actual)
Expand Down
74 changes: 48 additions & 26 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload

from mypy.nodes import ARG_STAR
from mypy.types import (
AnyType,
CallableType,
Expand Down Expand Up @@ -213,31 +214,7 @@ def visit_unpack_type(self, t: UnpackType) -> Type:
assert False, "Mypy bug: unpacking must happen at a higher level"

def expand_unpack(self, t: UnpackType) -> list[Type] | Instance | AnyType | None:
"""May return either a list of types to unpack to, any, or a single
variable length tuple. The latter may not be valid in all contexts.
"""
if isinstance(t.type, TypeVarTupleType):
repl = get_proper_type(self.variables.get(t.type.id, t))
if isinstance(repl, TupleType):
return repl.items
if isinstance(repl, TypeList):
return repl.items
elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple":
return repl
elif isinstance(repl, AnyType):
# tuple[Any, ...] would be better, but we don't have
# the type info to construct that type here.
return repl
elif isinstance(repl, TypeVarTupleType):
return [UnpackType(typ=repl)]
elif isinstance(repl, UnpackType):
return [repl]
elif isinstance(repl, UninhabitedType):
return None
else:
raise NotImplementedError(f"Invalid type replacement to expand: {repl}")
else:
raise NotImplementedError(f"Invalid type to expand: {t.type}")
return expand_unpack_with_variables(t, self.variables)

def visit_parameters(self, t: Parameters) -> Type:
return t.copy_modified(arg_types=self.expand_types(t.arg_types))
Expand Down Expand Up @@ -267,8 +244,23 @@ def visit_callable_type(self, t: CallableType) -> Type:
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)

var_arg = t.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
expanded = self.expand_unpack(var_arg.typ)
# Handle other cases later.
assert isinstance(expanded, list)
assert len(expanded) == 1 and isinstance(expanded[0], UnpackType)
star_index = t.arg_kinds.index(ARG_STAR)
arg_types = (
self.expand_types(t.arg_types[:star_index])
+ expanded
+ self.expand_types(t.arg_types[star_index + 1 :])
)
else:
arg_types = self.expand_types(t.arg_types)

return t.copy_modified(
arg_types=self.expand_types(t.arg_types),
arg_types=arg_types,
ret_type=t.ret_type.accept(self),
type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
)
Expand Down Expand Up @@ -361,3 +353,33 @@ def expand_types(self, types: Iterable[Type]) -> list[Type]:
for t in types:
a.append(t.accept(self))
return a


def expand_unpack_with_variables(
t: UnpackType, variables: Mapping[TypeVarId, Type]
) -> list[Type] | Instance | AnyType | None:
"""May return either a list of types to unpack to, any, or a single
variable length tuple. The latter may not be valid in all contexts.
"""
if isinstance(t.type, TypeVarTupleType):
repl = get_proper_type(variables.get(t.type.id, t))
if isinstance(repl, TupleType):
return repl.items
if isinstance(repl, TypeList):
return repl.items
elif isinstance(repl, Instance) and repl.type.fullname == "builtins.tuple":
return repl
elif isinstance(repl, AnyType):
# tuple[Any, ...] would be better, but we don't have
# the type info to construct that type here.
return repl
elif isinstance(repl, TypeVarTupleType):
return [UnpackType(typ=repl)]
elif isinstance(repl, UnpackType):
return [repl]
elif isinstance(repl, UninhabitedType):
return None
else:
raise NotImplementedError(f"Invalid type replacement to expand: {repl}")
else:
raise NotImplementedError(f"Invalid type to expand: {t.type}")
4 changes: 4 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
TypedDictType,
TypeOfAny,
TypeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UninhabitedType,
Expand Down Expand Up @@ -2263,6 +2264,9 @@ def format_literal_value(typ: LiteralType) -> str:
elif isinstance(typ, TypeVarType):
# This is similar to non-generic instance types.
return typ.name
elif isinstance(typ, TypeVarTupleType):
# This is similar to non-generic instance types.
return typ.name
elif isinstance(typ, ParamSpecType):
# Concatenate[..., P]
if typ.prefix.arg_types:
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-typevar-tuple.test
Original file line number Diff line number Diff line change
Expand Up @@ -346,4 +346,20 @@ expect_variadic_array(u)
expect_variadic_array_2(u)


[builtins fixtures/tuple.pyi]

[case testPep646TypeVarStarArgs]
from typing import Tuple
from typing_extensions import TypeVarTuple, Unpack

Ts = TypeVarTuple("Ts")

# TODO: add less trivial tests with prefix/suffix etc.
# TODO: add tests that call with a type var tuple instead of just args.
def args_to_tuple(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]:
reveal_type(args) # N: Revealed type is "Tuple[Unpack[Ts`-1]]"
return args

reveal_type(args_to_tuple(1, 'a')) # N: Revealed type is "Tuple[Literal[1]?, Literal['a']?]"

[builtins fixtures/tuple.pyi]

0 comments on commit c810a9c

Please sign in to comment.