From 249f3f8285d9d2a0f77273ace805dac0eef684c6 Mon Sep 17 00:00:00 2001 From: Shantanu <12621235+hauntsaninja@users.noreply.github.com> Date: Mon, 18 Sep 2023 23:53:31 -0700 Subject: [PATCH] Fix inference for overloaded __call__ with generic self (#16053) Fixes #8283 Co-authored-by: ilevkivskyi --- mypy/checkexpr.py | 4 ++- mypy/checkmember.py | 13 ++++--- mypy/subtypes.py | 51 +++++++++++++++------------ test-data/unit/check-overloading.test | 24 +++++++++++++ test-data/unit/check-tuples.test | 14 ++++++++ 5 files changed, 76 insertions(+), 30 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index f46c8cb15c6f..7b9b84938930 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1475,6 +1475,7 @@ def check_call( callable_node: Expression | None = None, callable_name: str | None = None, object_type: Type | None = None, + original_type: Type | None = None, ) -> tuple[Type, Type]: """Type check a call. @@ -1537,7 +1538,7 @@ def check_call( is_super=False, is_operator=True, msg=self.msg, - original_type=callee, + original_type=original_type or callee, chk=self.chk, in_literal_context=self.is_literal_context(), ) @@ -1578,6 +1579,7 @@ def check_call( callable_node, callable_name, object_type, + original_type=callee, ) else: return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 4316a59281c3..1557b62917dc 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -331,13 +331,12 @@ def analyze_instance_member_access( signature = method.type signature = freshen_all_functions_type_vars(signature) if not method.is_static: - if name != "__call__": - # TODO: use proper treatment of special methods on unions instead - # of this hack here and below (i.e. mx.self_type). - dispatched_type = meet.meet_types(mx.original_type, typ) - signature = check_self_arg( - signature, dispatched_type, method.is_class, mx.context, name, mx.msg - ) + # TODO: use proper treatment of special methods on unions instead + # of this hack here and below (i.e. mx.self_type). + dispatched_type = meet.meet_types(mx.original_type, typ) + signature = check_self_arg( + signature, dispatched_type, method.is_class, mx.context, name, mx.msg + ) signature = bind_self(signature, mx.self_type, is_classmethod=method.is_class) # TODO: should we skip these steps for static methods as well? # Since generic static methods should not be allowed. diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 9ed2e4af4051..c5399db0a494 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -454,19 +454,22 @@ def visit_instance(self, left: Instance) -> bool: if isinstance(unpacked, Instance): return self._is_subtype(left, unpacked) if left.type.has_base(right.partial_fallback.type.fullname): - # Special case to consider Foo[*tuple[Any, ...]] (i.e. bare Foo) a - # subtype of Foo[], when Foo is user defined variadic tuple type. - mapped = map_instance_to_supertype(left, right.partial_fallback.type) - if len(mapped.args) == 1 and isinstance(mapped.args[0], UnpackType): - unpacked = get_proper_type(mapped.args[0].type) - if isinstance(unpacked, Instance): - assert unpacked.type.fullname == "builtins.tuple" - if isinstance(get_proper_type(unpacked.args[0]), AnyType): - return not self.proper_subtype - if mapped.type.fullname == "builtins.tuple" and isinstance( - get_proper_type(mapped.args[0]), AnyType - ): - return not self.proper_subtype + if not self.proper_subtype: + # Special case to consider Foo[*tuple[Any, ...]] (i.e. bare Foo) a + # subtype of Foo[], when Foo is user defined variadic tuple type. + mapped = map_instance_to_supertype(left, right.partial_fallback.type) + for arg in map(get_proper_type, mapped.args): + if isinstance(arg, UnpackType): + unpacked = get_proper_type(arg.type) + if not isinstance(unpacked, Instance): + break + assert unpacked.type.fullname == "builtins.tuple" + if not isinstance(get_proper_type(unpacked.args[0]), AnyType): + break + elif not isinstance(arg, AnyType): + break + else: + return True return False if isinstance(right, TypeVarTupleType): # tuple[Any, ...] is like Any in the world of tuples (see special case above). @@ -534,15 +537,19 @@ def visit_instance(self, left: Instance) -> bool: right_args = ( right_prefix + (TupleType(list(right_middle), fallback),) + right_suffix ) - if len(t.args) == 1 and isinstance(t.args[0], UnpackType): - unpacked = get_proper_type(t.args[0].type) - if isinstance(unpacked, Instance): - assert unpacked.type.fullname == "builtins.tuple" - if ( - isinstance(get_proper_type(unpacked.args[0]), AnyType) - and not self.proper_subtype - ): - return True + if not self.proper_subtype: + for arg in map(get_proper_type, t.args): + if isinstance(arg, UnpackType): + unpacked = get_proper_type(arg.type) + if not isinstance(unpacked, Instance): + break + assert unpacked.type.fullname == "builtins.tuple" + if not isinstance(get_proper_type(unpacked.args[0]), AnyType): + break + elif not isinstance(arg, AnyType): + break + else: + return True type_params = zip(left_args, right_args, right.type.defn.type_vars) else: type_params = zip(t.args, right.args, right.type.defn.type_vars) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 4546c7171856..443a6fb5cb10 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6650,3 +6650,27 @@ def d(x: int) -> int: ... def d(f: int, *, x: int) -> str: ... def d(*args, **kwargs): ... [builtins fixtures/tuple.pyi] + +[case testOverloadCallableGenericSelf] +from typing import Any, TypeVar, Generic, overload, reveal_type + +T = TypeVar("T") + +class MyCallable(Generic[T]): + def __init__(self, t: T): + self.t = t + + @overload + def __call__(self: "MyCallable[int]") -> str: ... + @overload + def __call__(self: "MyCallable[str]") -> int: ... + def __call__(self): ... + +c = MyCallable(5) +reveal_type(c) # N: Revealed type is "__main__.MyCallable[builtins.int]" +reveal_type(c()) # N: Revealed type is "builtins.str" + +c2 = MyCallable("test") +reveal_type(c2) # N: Revealed type is "__main__.MyCallable[builtins.str]" +reveal_type(c2()) # should be int # N: Revealed type is "builtins.int" +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-tuples.test b/test-data/unit/check-tuples.test index 391fa20db738..ed2c3550a04e 100644 --- a/test-data/unit/check-tuples.test +++ b/test-data/unit/check-tuples.test @@ -1434,7 +1434,21 @@ def foo(o: CallableTuple) -> int: class CallableTuple(Tuple[str, int]): def __call__(self, n: int, m: int) -> int: return n +[builtins fixtures/tuple.pyi] + +[case testTypeTupleGenericCall] +from typing import Generic, Tuple, TypeVar + +T = TypeVar('T') +def foo(o: CallableTuple[int]) -> int: + reveal_type(o) # N: Revealed type is "Tuple[builtins.str, builtins.int, fallback=__main__.CallableTuple[builtins.int]]" + reveal_type(o.count(3)) # N: Revealed type is "builtins.int" + return o(1, 2) + +class CallableTuple(Tuple[str, T]): + def __call__(self, n: int, m: int) -> int: + return n [builtins fixtures/tuple.pyi] [case testTupleCompatibleWithSequence]