Skip to content

Commit

Permalink
Fix inference for overloaded __call__ with generic self (#16053)
Browse files Browse the repository at this point in the history
Fixes #8283

Co-authored-by: ilevkivskyi
  • Loading branch information
hauntsaninja committed Sep 19, 2023
1 parent ba978f4 commit 249f3f8
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 30 deletions.
4 changes: 3 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,6 +1475,7 @@ def check_call(
callable_node: Expression | None = None,
callable_name: str | None = None,
object_type: Type | None = None,
original_type: Type | None = None,
) -> tuple[Type, Type]:
"""Type check a call.
Expand Down Expand Up @@ -1537,7 +1538,7 @@ def check_call(
is_super=False,
is_operator=True,
msg=self.msg,
original_type=callee,
original_type=original_type or callee,
chk=self.chk,
in_literal_context=self.is_literal_context(),
)
Expand Down Expand Up @@ -1578,6 +1579,7 @@ def check_call(
callable_node,
callable_name,
object_type,
original_type=callee,
)
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)
Expand Down
13 changes: 6 additions & 7 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,12 @@ def analyze_instance_member_access(
signature = method.type
signature = freshen_all_functions_type_vars(signature)
if not method.is_static:
if name != "__call__":
# TODO: use proper treatment of special methods on unions instead
# of this hack here and below (i.e. mx.self_type).
dispatched_type = meet.meet_types(mx.original_type, typ)
signature = check_self_arg(
signature, dispatched_type, method.is_class, mx.context, name, mx.msg
)
# TODO: use proper treatment of special methods on unions instead
# of this hack here and below (i.e. mx.self_type).
dispatched_type = meet.meet_types(mx.original_type, typ)
signature = check_self_arg(
signature, dispatched_type, method.is_class, mx.context, name, mx.msg
)
signature = bind_self(signature, mx.self_type, is_classmethod=method.is_class)
# TODO: should we skip these steps for static methods as well?
# Since generic static methods should not be allowed.
Expand Down
51 changes: 29 additions & 22 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,19 +454,22 @@ def visit_instance(self, left: Instance) -> bool:
if isinstance(unpacked, Instance):
return self._is_subtype(left, unpacked)
if left.type.has_base(right.partial_fallback.type.fullname):
# Special case to consider Foo[*tuple[Any, ...]] (i.e. bare Foo) a
# subtype of Foo[<whatever>], when Foo is user defined variadic tuple type.
mapped = map_instance_to_supertype(left, right.partial_fallback.type)
if len(mapped.args) == 1 and isinstance(mapped.args[0], UnpackType):
unpacked = get_proper_type(mapped.args[0].type)
if isinstance(unpacked, Instance):
assert unpacked.type.fullname == "builtins.tuple"
if isinstance(get_proper_type(unpacked.args[0]), AnyType):
return not self.proper_subtype
if mapped.type.fullname == "builtins.tuple" and isinstance(
get_proper_type(mapped.args[0]), AnyType
):
return not self.proper_subtype
if not self.proper_subtype:
# Special case to consider Foo[*tuple[Any, ...]] (i.e. bare Foo) a
# subtype of Foo[<whatever>], when Foo is user defined variadic tuple type.
mapped = map_instance_to_supertype(left, right.partial_fallback.type)
for arg in map(get_proper_type, mapped.args):
if isinstance(arg, UnpackType):
unpacked = get_proper_type(arg.type)
if not isinstance(unpacked, Instance):
break
assert unpacked.type.fullname == "builtins.tuple"
if not isinstance(get_proper_type(unpacked.args[0]), AnyType):
break
elif not isinstance(arg, AnyType):
break
else:
return True
return False
if isinstance(right, TypeVarTupleType):
# tuple[Any, ...] is like Any in the world of tuples (see special case above).
Expand Down Expand Up @@ -534,15 +537,19 @@ def visit_instance(self, left: Instance) -> bool:
right_args = (
right_prefix + (TupleType(list(right_middle), fallback),) + right_suffix
)
if len(t.args) == 1 and isinstance(t.args[0], UnpackType):
unpacked = get_proper_type(t.args[0].type)
if isinstance(unpacked, Instance):
assert unpacked.type.fullname == "builtins.tuple"
if (
isinstance(get_proper_type(unpacked.args[0]), AnyType)
and not self.proper_subtype
):
return True
if not self.proper_subtype:
for arg in map(get_proper_type, t.args):
if isinstance(arg, UnpackType):
unpacked = get_proper_type(arg.type)
if not isinstance(unpacked, Instance):
break
assert unpacked.type.fullname == "builtins.tuple"
if not isinstance(get_proper_type(unpacked.args[0]), AnyType):
break
elif not isinstance(arg, AnyType):
break
else:
return True
type_params = zip(left_args, right_args, right.type.defn.type_vars)
else:
type_params = zip(t.args, right.args, right.type.defn.type_vars)
Expand Down
24 changes: 24 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6650,3 +6650,27 @@ def d(x: int) -> int: ...
def d(f: int, *, x: int) -> str: ...
def d(*args, **kwargs): ...
[builtins fixtures/tuple.pyi]

[case testOverloadCallableGenericSelf]
from typing import Any, TypeVar, Generic, overload, reveal_type

T = TypeVar("T")

class MyCallable(Generic[T]):
def __init__(self, t: T):
self.t = t

@overload
def __call__(self: "MyCallable[int]") -> str: ...
@overload
def __call__(self: "MyCallable[str]") -> int: ...
def __call__(self): ...

c = MyCallable(5)
reveal_type(c) # N: Revealed type is "__main__.MyCallable[builtins.int]"
reveal_type(c()) # N: Revealed type is "builtins.str"

c2 = MyCallable("test")
reveal_type(c2) # N: Revealed type is "__main__.MyCallable[builtins.str]"
reveal_type(c2()) # should be int # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]
14 changes: 14 additions & 0 deletions test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,21 @@ def foo(o: CallableTuple) -> int:
class CallableTuple(Tuple[str, int]):
def __call__(self, n: int, m: int) -> int:
return n
[builtins fixtures/tuple.pyi]

[case testTypeTupleGenericCall]
from typing import Generic, Tuple, TypeVar

T = TypeVar('T')

def foo(o: CallableTuple[int]) -> int:
reveal_type(o) # N: Revealed type is "Tuple[builtins.str, builtins.int, fallback=__main__.CallableTuple[builtins.int]]"
reveal_type(o.count(3)) # N: Revealed type is "builtins.int"
return o(1, 2)

class CallableTuple(Tuple[str, T]):
def __call__(self, n: int, m: int) -> int:
return n
[builtins fixtures/tuple.pyi]

[case testTupleCompatibleWithSequence]
Expand Down

0 comments on commit 249f3f8

Please sign in to comment.