From 885e361b1cf97260d80e9dfa4e494ff362f0edff Mon Sep 17 00:00:00 2001 From: jhance Date: Thu, 17 Nov 2022 11:41:49 -0800 Subject: [PATCH] Handle prefix/suffix in typevartuple *args support (#14112) This requires handling more cases in the various places that we previously modified to support *args in general. We also need to refresh the formals-to-actuals twice in checkexpr as now it can happen in the infer_function_type_arguments_using_context call. The handling here is kind of asymmetric, because we can convert prefices into positional arguments, but there is no equivalent for suffices, so we represent that as a Tuple[Unpack[...], ] and handle that case separately in some spots. We also support various edge cases like passing in a tuple without any typevartuples involved. --- mypy/applytype.py | 83 +++++++++++++---------- mypy/checker.py | 17 +++-- mypy/checkexpr.py | 83 ++++++++++++++++++++--- mypy/constraints.py | 22 ++++++- mypy/expandtype.py | 88 +++++++++++++++++++++---- mypy/typevartuples.py | 20 +++++- test-data/unit/check-typevar-tuple.test | 33 ++++++++++ 7 files changed, 284 insertions(+), 62 deletions(-) diff --git a/mypy/applytype.py b/mypy/applytype.py index d7f31b36c244..a81ed3cd1f16 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -4,13 +4,14 @@ import mypy.subtypes from mypy.expandtype import expand_type, expand_unpack_with_variables -from mypy.nodes import ARG_POS, ARG_STAR, Context +from mypy.nodes import ARG_STAR, Context from mypy.types import ( AnyType, CallableType, Parameters, ParamSpecType, PartialType, + TupleType, Type, TypeVarId, TypeVarLikeType, @@ -19,6 +20,7 @@ UnpackType, get_proper_type, ) +from mypy.typevartuples import find_unpack_in_list, replace_starargs def get_target_type( @@ -114,39 +116,57 @@ def apply_generic_arguments( # Apply arguments to argument types. var_arg = callable.var_arg() if var_arg is not None and isinstance(var_arg.typ, UnpackType): - expanded = expand_unpack_with_variables(var_arg.typ, id_to_type) - assert isinstance(expanded, list) - # Handle other cases later. - for t in expanded: - assert not isinstance(t, UnpackType) star_index = callable.arg_kinds.index(ARG_STAR) - arg_kinds = ( - callable.arg_kinds[:star_index] - + [ARG_POS] * len(expanded) - + callable.arg_kinds[star_index + 1 :] + callable = callable.copy_modified( + arg_types=( + [ + expand_type(at, id_to_type, allow_erased_callables) + for at in callable.arg_types[:star_index] + ] + + [callable.arg_types[star_index]] + + [ + expand_type(at, id_to_type, allow_erased_callables) + for at in callable.arg_types[star_index + 1 :] + ] + ) ) - arg_names = ( - callable.arg_names[:star_index] - + [None] * len(expanded) - + callable.arg_names[star_index + 1 :] - ) - arg_types = ( - [ - expand_type(at, id_to_type, allow_erased_callables) - for at in callable.arg_types[:star_index] - ] - + expanded - + [ - expand_type(at, id_to_type, allow_erased_callables) - for at in callable.arg_types[star_index + 1 :] + + unpacked_type = get_proper_type(var_arg.typ.type) + if isinstance(unpacked_type, TupleType): + # Assuming for now that because we convert prefixes to positional arguments, + # the first argument is always an unpack. + expanded_tuple = expand_type(unpacked_type, id_to_type) + if isinstance(expanded_tuple, TupleType): + # TODO: handle the case where the tuple has an unpack. This will + # hit an assert below. + expanded_unpack = find_unpack_in_list(expanded_tuple.items) + if expanded_unpack is not None: + callable = callable.copy_modified( + arg_types=( + callable.arg_types[:star_index] + + [expanded_tuple] + + callable.arg_types[star_index + 1 :] + ) + ) + else: + callable = replace_starargs(callable, expanded_tuple.items) + else: + # TODO: handle the case for if we get a variable length tuple. + assert False, f"mypy bug: unimplemented case, {expanded_tuple}" + elif isinstance(unpacked_type, TypeVarTupleType): + expanded_tvt = expand_unpack_with_variables(var_arg.typ, id_to_type) + assert isinstance(expanded_tvt, list) + for t in expanded_tvt: + assert not isinstance(t, UnpackType) + callable = replace_starargs(callable, expanded_tvt) + else: + assert False, "mypy bug: unhandled case applying unpack" + else: + callable = callable.copy_modified( + arg_types=[ + expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types ] ) - else: - arg_types = [ - expand_type(at, id_to_type, allow_erased_callables) for at in callable.arg_types - ] - arg_kinds = callable.arg_kinds - arg_names = callable.arg_names # Apply arguments to TypeGuard if any. if callable.type_guard is not None: @@ -158,10 +178,7 @@ def apply_generic_arguments( remaining_tvars = [tv for tv in tvars if tv.id not in id_to_type] return callable.copy_modified( - arg_types=arg_types, ret_type=expand_type(callable.ret_type, id_to_type, allow_erased_callables), variables=remaining_tvars, type_guard=type_guard, - arg_kinds=arg_kinds, - arg_names=arg_names, ) diff --git a/mypy/checker.py b/mypy/checker.py index 57725bd9186b..c7de4911501a 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1178,12 +1178,17 @@ def check_func_def( if isinstance(arg_type, ParamSpecType): pass elif isinstance(arg_type, UnpackType): - arg_type = TupleType( - [arg_type], - fallback=self.named_generic_type( - "builtins.tuple", [self.named_type("builtins.object")] - ), - ) + if isinstance(get_proper_type(arg_type.type), TupleType): + # Instead of using Tuple[Unpack[Tuple[...]]], just use + # Tuple[...] + arg_type = arg_type.type + else: + arg_type = TupleType( + [arg_type], + fallback=self.named_generic_type( + "builtins.tuple", [self.named_type("builtins.object")] + ), + ) else: # builtins.tuple[T] is typing.Tuple[T, ...] arg_type = self.named_generic_type("builtins.tuple", [arg_type]) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3d2c69073bc0..b41a38825fb3 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -150,6 +150,7 @@ TypeVarType, UninhabitedType, UnionType, + UnpackType, flatten_nested_unions, get_proper_type, get_proper_types, @@ -1404,13 +1405,21 @@ def check_callable_call( ) callee = freshen_function_type_vars(callee) callee = self.infer_function_type_arguments_using_context(callee, context) + if need_refresh: + # Argument kinds etc. may have changed due to + # ParamSpec or TypeVarTuple variables being replaced with an arbitrary + # number of arguments; recalculate actual-to-formal map + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee.arg_kinds, + callee.arg_names, + lambda i: self.accept(args[i]), + ) callee = self.infer_function_type_arguments( callee, args, arg_kinds, formal_to_actual, context ) if need_refresh: - # Argument kinds etc. may have changed due to - # ParamSpec variables being replaced with an arbitrary - # number of arguments; recalculate actual-to-formal map formal_to_actual = map_actuals_to_formals( arg_kinds, arg_names, @@ -1999,11 +2008,66 @@ def check_argument_types( # Keep track of consumed tuple *arg items. mapper = ArgTypeExpander(self.argument_infer_context()) for i, actuals in enumerate(formal_to_actual): - for actual in actuals: - actual_type = arg_types[actual] + orig_callee_arg_type = get_proper_type(callee.arg_types[i]) + + # Checking the case that we have more than one item but the first argument + # is an unpack, so this would be something like: + # [Tuple[Unpack[Ts]], int] + # + # In this case we have to check everything together, we do this by re-unifying + # the suffices to the tuple, e.g. a single actual like + # Tuple[Unpack[Ts], int] + expanded_tuple = False + if len(actuals) > 1: + first_actual_arg_type = get_proper_type(arg_types[actuals[0]]) + if ( + isinstance(first_actual_arg_type, TupleType) + and len(first_actual_arg_type.items) == 1 + and isinstance(get_proper_type(first_actual_arg_type.items[0]), UnpackType) + ): + # TODO: use walrus operator + actual_types = [first_actual_arg_type.items[0]] + [ + arg_types[a] for a in actuals[1:] + ] + actual_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (len(actuals) - 1) + + assert isinstance(orig_callee_arg_type, TupleType) + assert orig_callee_arg_type.items + callee_arg_types = orig_callee_arg_type.items + callee_arg_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * ( + len(orig_callee_arg_type.items) - 1 + ) + expanded_tuple = True + + if not expanded_tuple: + actual_types = [arg_types[a] for a in actuals] + actual_kinds = [arg_kinds[a] for a in actuals] + if isinstance(orig_callee_arg_type, UnpackType): + unpacked_type = get_proper_type(orig_callee_arg_type.type) + # Only case we know of thus far. + assert isinstance(unpacked_type, TupleType) + actual_types = [arg_types[a] for a in actuals] + actual_kinds = [arg_kinds[a] for a in actuals] + callee_arg_types = unpacked_type.items + callee_arg_kinds = [ARG_POS] * len(actuals) + else: + callee_arg_types = [orig_callee_arg_type] * len(actuals) + callee_arg_kinds = [callee.arg_kinds[i]] * len(actuals) + + assert len(actual_types) == len(actuals) == len(actual_kinds) + + if len(callee_arg_types) != len(actual_types): + # TODO: Improve error message + self.chk.fail("Invalid number of arguments", context) + continue + + assert len(callee_arg_types) == len(actual_types) + assert len(callee_arg_types) == len(callee_arg_kinds) + for actual, actual_type, actual_kind, callee_arg_type, callee_arg_kind in zip( + actuals, actual_types, actual_kinds, callee_arg_types, callee_arg_kinds + ): if actual_type is None: continue # Some kind of error was already reported. - actual_kind = arg_kinds[actual] # Check that a *arg is valid as varargs. if actual_kind == nodes.ARG_STAR and not self.is_valid_var_arg(actual_type): self.msg.invalid_var_arg(actual_type, context) @@ -2013,13 +2077,13 @@ def check_argument_types( is_mapping = is_subtype(actual_type, self.chk.named_type("typing.Mapping")) self.msg.invalid_keyword_var_arg(actual_type, is_mapping, context) expanded_actual = mapper.expand_actual_type( - actual_type, actual_kind, callee.arg_names[i], callee.arg_kinds[i] + actual_type, actual_kind, callee.arg_names[i], callee_arg_kind ) check_arg( expanded_actual, actual_type, - arg_kinds[actual], - callee.arg_types[i], + actual_kind, + callee_arg_type, actual + 1, i + 1, callee, @@ -4719,6 +4783,7 @@ def is_valid_var_arg(self, typ: Type) -> bool: ) or isinstance(typ, AnyType) or isinstance(typ, ParamSpecType) + or isinstance(typ, UnpackType) ) def is_valid_keyword_var_arg(self, typ: Type) -> bool: diff --git a/mypy/constraints.py b/mypy/constraints.py index 7123c590b7ef..4e78e5ff1117 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -133,8 +133,26 @@ def infer_constraints_for_callable( ) ) - assert isinstance(unpack_type.type, TypeVarTupleType) - constraints.append(Constraint(unpack_type.type, SUPERTYPE_OF, TypeList(actual_types))) + unpacked_type = get_proper_type(unpack_type.type) + if isinstance(unpacked_type, TypeVarTupleType): + constraints.append(Constraint(unpacked_type, SUPERTYPE_OF, TypeList(actual_types))) + elif isinstance(unpacked_type, TupleType): + # Prefixes get converted to positional args, so technically the only case we + # should have here is like Tuple[Unpack[Ts], Y1, Y2, Y3]. If this turns out + # not to hold we can always handle the prefixes too. + inner_unpack = unpacked_type.items[0] + assert isinstance(inner_unpack, UnpackType) + inner_unpacked_type = get_proper_type(inner_unpack.type) + assert isinstance(inner_unpacked_type, TypeVarTupleType) + suffix_len = len(unpacked_type.items) - 1 + constraints.append( + Constraint( + inner_unpacked_type, SUPERTYPE_OF, TypeList(actual_types[:-suffix_len]) + ) + ) + else: + assert False, "mypy bug: unhandled constraint inference case" + else: for actual in actuals: actual_arg_type = arg_types[actual] diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 70fa62291aa3..43f4e6bcd75b 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -2,7 +2,7 @@ from typing import Iterable, Mapping, Sequence, TypeVar, cast, overload -from mypy.nodes import ARG_STAR, Var +from mypy.nodes import ARG_POS, ARG_STAR, Var from mypy.type_visitor import TypeTranslator from mypy.types import ( AnyType, @@ -36,7 +36,11 @@ UnpackType, get_proper_type, ) -from mypy.typevartuples import split_with_instance, split_with_prefix_and_suffix +from mypy.typevartuples import ( + find_unpack_in_list, + split_with_instance, + split_with_prefix_and_suffix, +) @overload @@ -282,21 +286,83 @@ def visit_callable_type(self, t: CallableType) -> Type: var_arg = t.var_arg() if var_arg is not None and isinstance(var_arg.typ, UnpackType): - expanded = self.expand_unpack(var_arg.typ) - # Handle other cases later. - assert isinstance(expanded, list) - assert len(expanded) == 1 and isinstance(expanded[0], UnpackType) star_index = t.arg_kinds.index(ARG_STAR) - arg_types = ( - self.expand_types(t.arg_types[:star_index]) - + expanded - + self.expand_types(t.arg_types[star_index + 1 :]) - ) + + # We have something like Unpack[Tuple[X1, X2, Unpack[Ts], Y1, Y2]] + if isinstance(get_proper_type(var_arg.typ.type), TupleType): + expanded_tuple = get_proper_type(var_arg.typ.type.accept(self)) + # TODO: handle the case that expanded_tuple is a variable length tuple. + assert isinstance(expanded_tuple, TupleType) + expanded_unpack_index = find_unpack_in_list(expanded_tuple.items) + # This is the case where we just have Unpack[Tuple[X1, X2, X3]] + # (for example if either the tuple had no unpacks, or the unpack in the + # tuple got fully expanded to something with fixed length) + if expanded_unpack_index is None: + arg_names = ( + t.arg_names[:star_index] + + [None] * len(expanded_tuple.items) + + t.arg_names[star_index + 1 :] + ) + arg_kinds = ( + t.arg_kinds[:star_index] + + [ARG_POS] * len(expanded_tuple.items) + + t.arg_kinds[star_index + 1 :] + ) + arg_types = ( + self.expand_types(t.arg_types[:star_index]) + + expanded_tuple.items + + self.expand_types(t.arg_types[star_index + 1 :]) + ) + else: + # If Unpack[Ts] simplest form still has an unpack or is a + # homogenous tuple, then only the prefix can be represented as + # positional arguments, and we pass Tuple[Unpack[Ts-1], Y1, Y2] + # as the star arg, for example. + prefix_len = expanded_unpack_index + arg_names = ( + t.arg_names[:star_index] + [None] * prefix_len + t.arg_names[star_index:] + ) + arg_kinds = ( + t.arg_kinds[:star_index] + + [ARG_POS] * prefix_len + + t.arg_kinds[star_index:] + ) + arg_types = ( + self.expand_types(t.arg_types[:star_index]) + + expanded_tuple.items[:prefix_len] + # Constructing the Unpack containing the tuple without the prefix. + + [ + UnpackType( + expanded_tuple.copy_modified( + items=expanded_tuple.items[prefix_len:] + ) + ) + ] + + self.expand_types(t.arg_types[star_index + 1 :]) + ) + else: + expanded = self.expand_unpack(var_arg.typ) + # Handle other cases later. + assert isinstance(expanded, list) + assert len(expanded) == 1 and isinstance(expanded[0], UnpackType) + + # In this case we keep the arg as ARG_STAR. + arg_names = t.arg_names + arg_kinds = t.arg_kinds + arg_types = ( + self.expand_types(t.arg_types[:star_index]) + + expanded + + self.expand_types(t.arg_types[star_index + 1 :]) + ) else: arg_types = self.expand_types(t.arg_types) + arg_names = t.arg_names + arg_kinds = t.arg_kinds return t.copy_modified( arg_types=arg_types, + arg_names=arg_names, + arg_kinds=arg_kinds, ret_type=t.ret_type.accept(self), type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None), ) diff --git a/mypy/typevartuples.py b/mypy/typevartuples.py index e93f99d8a825..4b3b5cc2dca7 100644 --- a/mypy/typevartuples.py +++ b/mypy/typevartuples.py @@ -4,7 +4,8 @@ from typing import Sequence, TypeVar -from mypy.types import Instance, ProperType, Type, UnpackType, get_proper_type +from mypy.nodes import ARG_POS, ARG_STAR +from mypy.types import CallableType, Instance, ProperType, Type, UnpackType, get_proper_type def find_unpack_in_list(items: Sequence[Type]) -> int | None: @@ -150,3 +151,20 @@ def extract_unpack(types: Sequence[Type]) -> ProperType | None: if isinstance(proper_type, UnpackType): return get_proper_type(proper_type.type) return None + + +def replace_starargs(callable: CallableType, types: list[Type]) -> CallableType: + star_index = callable.arg_kinds.index(ARG_STAR) + arg_kinds = ( + callable.arg_kinds[:star_index] + + [ARG_POS] * len(types) + + callable.arg_kinds[star_index + 1 :] + ) + arg_names = ( + callable.arg_names[:star_index] + + [None] * len(types) + + callable.arg_names[star_index + 1 :] + ) + arg_types = callable.arg_types[:star_index] + types + callable.arg_types[star_index + 1 :] + + return callable.copy_modified(arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds) diff --git a/test-data/unit/check-typevar-tuple.test b/test-data/unit/check-typevar-tuple.test index d8f6cde10441..d85990293aea 100644 --- a/test-data/unit/check-typevar-tuple.test +++ b/test-data/unit/check-typevar-tuple.test @@ -381,4 +381,37 @@ def args_to_tuple(*args: Unpack[Ts]) -> Tuple[Unpack[Ts]]: reveal_type(args_to_tuple(1, 'a')) # N: Revealed type is "Tuple[Literal[1]?, Literal['a']?]" +def with_prefix_suffix(*args: Unpack[Tuple[bool, str, Unpack[Ts], int]]) -> Tuple[bool, str, Unpack[Ts], int]: + reveal_type(args) # N: Revealed type is "Tuple[builtins.bool, builtins.str, Unpack[Ts`-1], builtins.int]" + return args + +reveal_type(with_prefix_suffix(True, "bar", "foo", 5)) # N: Revealed type is "Tuple[builtins.bool, builtins.str, Literal['foo']?, builtins.int]" +reveal_type(with_prefix_suffix(True, "bar", 5)) # N: Revealed type is "Tuple[builtins.bool, builtins.str, builtins.int]" + +with_prefix_suffix(True, "bar", "foo", 1.0) # E: Argument 4 to "with_prefix_suffix" has incompatible type "float"; expected "int" +with_prefix_suffix(True, "bar") # E: Too few arguments for "with_prefix_suffix" + +t = (True, "bar", "foo", 5) +reveal_type(with_prefix_suffix(*t)) # N: Revealed type is "Tuple[builtins.bool, builtins.str, builtins.str, builtins.int]" +reveal_type(with_prefix_suffix(True, *("bar", "foo"), 5)) # N: Revealed type is "Tuple[builtins.bool, builtins.str, Literal['foo']?, builtins.int]" + +# TODO: handle list case +#reveal_type(with_prefix_suffix(True, "bar", *["foo1", "foo2"], 5)) + +bad_t = (True, "bar") +with_prefix_suffix(*bad_t) # E: Too few arguments for "with_prefix_suffix" + +def foo(*args: Unpack[Ts]) -> None: + reveal_type(with_prefix_suffix(True, "bar", *args, 5)) # N: Revealed type is "Tuple[builtins.bool, builtins.str, Unpack[Ts`-1], builtins.int]" + +def concrete(*args: Unpack[Tuple[int, str]]) -> None: + reveal_type(args) # N: Revealed type is "Tuple[builtins.int, builtins.str]" + +concrete(0, "foo") +concrete(0, 1) # E: Argument 2 to "concrete" has incompatible type "int"; expected "Unpack[Tuple[int, str]]" +concrete("foo", "bar") # E: Argument 1 to "concrete" has incompatible type "str"; expected "Unpack[Tuple[int, str]]" +concrete(0, "foo", 1) # E: Invalid number of arguments +concrete(0) # E: Invalid number of arguments +concrete() # E: Invalid number of arguments + [builtins fixtures/tuple.pyi]