From 9c90681dd63f87f047d863e899733892cba47391 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Wed, 24 Jan 2024 12:28:08 +0100 Subject: [PATCH] Use TypeVar defaults instead of Any when fixing instance types (PEP 696) --- mypy/messages.py | 15 ++- mypy/typeanal.py | 87 ++++++++++----- test-data/unit/check-typevar-defaults.test | 123 +++++++++++++++++++++ 3 files changed, 189 insertions(+), 36 deletions(-) diff --git a/mypy/messages.py b/mypy/messages.py index 450faf4c1688..75b8eeb3174c 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -3017,12 +3017,15 @@ def for_function(callee: CallableType) -> str: return "" -def wrong_type_arg_count(n: int, act: str, name: str) -> str: - s = f"{n} type arguments" - if n == 0: - s = "no type arguments" - elif n == 1: - s = "1 type argument" +def wrong_type_arg_count(low: int, high: int, act: str, name: str) -> str: + if low == high: + s = f"{low} type arguments" + if low == 0: + s = "no type arguments" + elif low == 1: + s = "1 type argument" + else: + s = f"between {low} and {high} type arguments" if act == "0": act = "none" return f'"{name}" expects {s}, but {act} given' diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 8a840424f76f..d10f26f5199a 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -9,6 +9,7 @@ from mypy import errorcodes as codes, message_registry, nodes from mypy.errorcodes import ErrorCode +from mypy.expandtype import expand_type from mypy.messages import MessageBuilder, format_type_bare, quote_type_string, wrong_type_arg_count from mypy.nodes import ( ARG_NAMED, @@ -75,6 +76,7 @@ TypeOfAny, TypeQuery, TypeType, + TypeVarId, TypeVarLikeType, TypeVarTupleType, TypeVarType, @@ -1834,14 +1836,14 @@ def get_omitted_any( return any_type -def fix_type_var_tuple_argument(any_type: Type, t: Instance) -> None: +def fix_type_var_tuple_argument(t: Instance) -> None: if t.type.has_type_var_tuple_type: args = list(t.args) assert t.type.type_var_tuple_prefix is not None tvt = t.type.defn.type_vars[t.type.type_var_tuple_prefix] assert isinstance(tvt, TypeVarTupleType) args[t.type.type_var_tuple_prefix] = UnpackType( - Instance(tvt.tuple_fallback.type, [any_type]) + Instance(tvt.tuple_fallback.type, [args[t.type.type_var_tuple_prefix]]) ) t.args = tuple(args) @@ -1855,26 +1857,42 @@ def fix_instance( use_generic_error: bool = False, unexpanded_type: Type | None = None, ) -> None: - """Fix a malformed instance by replacing all type arguments with Any. + """Fix a malformed instance by replacing all type arguments with TypeVar default or Any. Also emit a suitable error if this is not due to implicit Any's. """ - if len(t.args) == 0: - if use_generic_error: - fullname: str | None = None - else: - fullname = t.type.fullname - any_type = get_omitted_any(disallow_any, fail, note, t, options, fullname, unexpanded_type) - t.args = (any_type,) * len(t.type.type_vars) - fix_type_var_tuple_argument(any_type, t) - return - # Construct the correct number of type arguments, as - # otherwise the type checker may crash as it expects - # things to be right. - any_type = AnyType(TypeOfAny.from_error) - t.args = tuple(any_type for _ in t.type.type_vars) - fix_type_var_tuple_argument(any_type, t) - t.invalid = True + arg_count = len(t.args) + min_tv_count = sum(not tv.has_default() for tv in t.type.defn.type_vars) + max_tv_count = len(t.type.type_vars) + if arg_count < min_tv_count or arg_count > max_tv_count: + # Don't use existing args if arg_count doesn't match + t.args = () + + args: list[Type] = [*(t.args[:max_tv_count])] + any_type: AnyType | None = None + env: dict[TypeVarId, Type] = {} + + for tv, arg in itertools.zip_longest(t.type.defn.type_vars, t.args, fillvalue=None): + if tv is None: + continue + if arg is None: + if tv.has_default(): + arg = tv.default + else: + if any_type is None: + fullname = None if use_generic_error else t.type.fullname + any_type = get_omitted_any( + disallow_any, fail, note, t, options, fullname, unexpanded_type + ) + arg = any_type + args.append(arg) + env[tv.id] = arg + t.args = tuple(args) + fix_type_var_tuple_argument(t) + if not t.type.has_type_var_tuple_type: + fixed = expand_type(t, env) + assert isinstance(fixed, Instance) + t.args = fixed.args def instantiate_type_alias( @@ -1963,7 +1981,7 @@ def instantiate_type_alias( if use_standard_error: # This is used if type alias is an internal representation of another type, # for example a generic TypedDict or NamedTuple. - msg = wrong_type_arg_count(exp_len, str(act_len), node.name) + msg = wrong_type_arg_count(exp_len, exp_len, str(act_len), node.name) else: if node.tvar_tuple_index is not None: exp_len_str = f"at least {exp_len - 1}" @@ -2217,24 +2235,27 @@ def validate_instance(t: Instance, fail: MsgCallback, empty_tuple_index: bool) - # TODO: is it OK to fill with TypeOfAny.from_error instead of special form? return False if t.type.has_type_var_tuple_type: - correct = len(t.args) >= len(t.type.type_vars) - 1 + min_tv_count = sum( + not tv.has_default() and not isinstance(tv, TypeVarTupleType) + for tv in t.type.defn.type_vars + ) + correct = len(t.args) >= min_tv_count if any( isinstance(a, UnpackType) and isinstance(get_proper_type(a.type), Instance) for a in t.args ): correct = True - if not correct: - exp_len = f"at least {len(t.type.type_vars) - 1}" + if not t.args: + if not (empty_tuple_index and len(t.type.type_vars) == 1): + # The Any arguments should be set by the caller. + return False + elif not correct: fail( - f"Bad number of arguments, expected: {exp_len}, given: {len(t.args)}", + f"Bad number of arguments, expected: at least {min_tv_count}, given: {len(t.args)}", t, code=codes.TYPE_ARG, ) return False - elif not t.args: - if not (empty_tuple_index and len(t.type.type_vars) == 1): - # The Any arguments should be set by the caller. - return False else: # We also need to check if we are not performing a type variable tuple split. unpack = find_unpack_in_list(t.args) @@ -2254,15 +2275,21 @@ def validate_instance(t: Instance, fail: MsgCallback, empty_tuple_index: bool) - elif any(isinstance(a, UnpackType) for a in t.args): # A variadic unpack in fixed size instance (fixed unpacks must be flattened by the caller) fail(message_registry.INVALID_UNPACK_POSITION, t, code=codes.VALID_TYPE) + t.args = () return False elif len(t.args) != len(t.type.type_vars): # Invalid number of type parameters. - if t.args: + arg_count = len(t.args) + min_tv_count = sum(not tv.has_default() for tv in t.type.defn.type_vars) + max_tv_count = len(t.type.type_vars) + if arg_count and (arg_count < min_tv_count or arg_count > max_tv_count): fail( - wrong_type_arg_count(len(t.type.type_vars), str(len(t.args)), t.type.name), + wrong_type_arg_count(min_tv_count, max_tv_count, str(arg_count), t.type.name), t, code=codes.TYPE_ARG, ) + t.args = () + t.invalid = True return False return True diff --git a/test-data/unit/check-typevar-defaults.test b/test-data/unit/check-typevar-defaults.test index 9015d353fa08..c4d258d50ee5 100644 --- a/test-data/unit/check-typevar-defaults.test +++ b/test-data/unit/check-typevar-defaults.test @@ -116,3 +116,126 @@ def func_c1(x: Union[int, Callable[[Unpack[Ts1]], None]]) -> Tuple[Unpack[Ts1]]: # reveal_type(func_c1(callback1)) # Revealed type is "builtins.tuple[str]" # TODO # reveal_type(func_c1(2)) # Revealed type is "builtins.tuple[builtins.int, builtins.str]" # TODO [builtins fixtures/tuple.pyi] + +[case testTypeVarDefaultsClass1] +from typing import Generic, TypeVar + +T1 = TypeVar("T1") +T2 = TypeVar("T2", default=int) +T3 = TypeVar("T3", default=str) + +class ClassA1(Generic[T2, T3]): ... + +def func_a1( + a: ClassA1, + b: ClassA1[float], + c: ClassA1[float, float], + d: ClassA1[float, float, float], # E: "ClassA1" expects between 0 and 2 type arguments, but 3 given +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassA1[builtins.int, builtins.str]" + reveal_type(b) # N: Revealed type is "__main__.ClassA1[builtins.float, builtins.str]" + reveal_type(c) # N: Revealed type is "__main__.ClassA1[builtins.float, builtins.float]" + reveal_type(d) # N: Revealed type is "__main__.ClassA1[builtins.int, builtins.str]" + +class ClassA2(Generic[T1, T2, T3]): ... + +def func_a2( + a: ClassA2, + b: ClassA2[float], + c: ClassA2[float, float], + d: ClassA2[float, float, float], + e: ClassA2[float, float, float, float], # E: "ClassA2" expects between 1 and 3 type arguments, but 4 given +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassA2[Any, builtins.int, builtins.str]" + reveal_type(b) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.int, builtins.str]" + reveal_type(c) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.float, builtins.str]" + reveal_type(d) # N: Revealed type is "__main__.ClassA2[builtins.float, builtins.float, builtins.float]" + reveal_type(e) # N: Revealed type is "__main__.ClassA2[Any, builtins.int, builtins.str]" + +[case testTypeVarDefaultsClass2] +from typing import Generic, ParamSpec + +P1 = ParamSpec("P1") +P2 = ParamSpec("P2", default=[int, str]) +P3 = ParamSpec("P3", default=...) + +class ClassB1(Generic[P2, P3]): ... + +def func_b1( + a: ClassB1, + b: ClassB1[[float]], + c: ClassB1[[float], [float]], + d: ClassB1[[float], [float], [float]], # E: "ClassB1" expects between 0 and 2 type arguments, but 3 given +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassB1[[builtins.int, builtins.str], ...]" + reveal_type(b) # N: Revealed type is "__main__.ClassB1[[builtins.float], ...]" + reveal_type(c) # N: Revealed type is "__main__.ClassB1[[builtins.float], [builtins.float]]" + reveal_type(d) # N: Revealed type is "__main__.ClassB1[[builtins.int, builtins.str], ...]" + +class ClassB2(Generic[P1, P2]): ... + +def func_b2( + a: ClassB2, + b: ClassB2[[float]], + c: ClassB2[[float], [float]], + d: ClassB2[[float], [float], [float]], # E: "ClassB2" expects between 1 and 2 type arguments, but 3 given +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassB2[Any, [builtins.int, builtins.str]]" + reveal_type(b) # N: Revealed type is "__main__.ClassB2[[builtins.float], [builtins.int, builtins.str]]" + reveal_type(c) # N: Revealed type is "__main__.ClassB2[[builtins.float], [builtins.float]]" + reveal_type(d) # N: Revealed type is "__main__.ClassB2[Any, [builtins.int, builtins.str]]" + +[case testTypeVarDefaultsClass3] +from typing import Generic, Tuple, TypeVar +from typing_extensions import TypeVarTuple, Unpack + +T1 = TypeVar("T1") +T3 = TypeVar("T3", default=str) + +Ts1 = TypeVarTuple("Ts1") +Ts2 = TypeVarTuple("Ts2", default=Unpack[Tuple[int, str]]) +Ts3 = TypeVarTuple("Ts3", default=Unpack[Tuple[float, ...]]) +Ts4 = TypeVarTuple("Ts4", default=Unpack[Tuple[()]]) + +class ClassC1(Generic[Unpack[Ts2]]): ... + +def func_c1( + a: ClassC1, + b: ClassC1[float], +) -> None: + # reveal_type(a) # Revealed type is "__main__.ClassC1[builtins.int, builtins.str]" # TODO + reveal_type(b) # N: Revealed type is "__main__.ClassC1[builtins.float]" + +class ClassC2(Generic[T3, Unpack[Ts3]]): ... + +def func_c2( + a: ClassC2, + b: ClassC2[int], + c: ClassC2[int, Unpack[Tuple[()]]], +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassC2[builtins.str, Unpack[builtins.tuple[builtins.float, ...]]]" + # reveal_type(b) # Revealed type is "__main__.ClassC2[builtins.int, Unpack[builtins.tuple[builtins.float, ...]]]" # TODO + reveal_type(c) # N: Revealed type is "__main__.ClassC2[builtins.int]" + +class ClassC3(Generic[T3, Unpack[Ts4]]): ... + +def func_c3( + a: ClassC3, + b: ClassC3[int], + c: ClassC3[int, Unpack[Tuple[float]]] +) -> None: + # reveal_type(a) # Revealed type is "__main__.ClassC3[builtins.str]" # TODO + reveal_type(b) # N: Revealed type is "__main__.ClassC3[builtins.int]" + reveal_type(c) # N: Revealed type is "__main__.ClassC3[builtins.int, builtins.float]" + +class ClassC4(Generic[T1, Unpack[Ts1], T3]): ... + +def func_c4( + a: ClassC4, + b: ClassC4[int], + c: ClassC4[int, float], +) -> None: + reveal_type(a) # N: Revealed type is "__main__.ClassC4[Any, Unpack[builtins.tuple[Any, ...]], builtins.str]" + # reveal_type(b) # Revealed type is "__main__.ClassC4[builtins.int, builtins.str]" # TODO + reveal_type(c) # N: Revealed type is "__main__.ClassC4[builtins.int, builtins.float]" +[builtins fixtures/tuple.pyi]