Skip to content

Commit

Permalink
Filter overload items based on self type during type inference (#17873)
Browse files Browse the repository at this point in the history
Fix type argument inference for overloaded functions with explicit self
types. Filter out the overload items based on the declared and actual
types of self. The implementation is best effort and does the filtering
only in simple cases, to reduce the risk of regressions (primarily
performance, but I worry also about infinite recursion). I added a fast
path for the typical case, since without it the filtering was quite
expensive.

Note that the overload item filtering already worked in many contexts.
This only improves it in specific contexts -- at least when inferring
generic protocol compatibility.

This is a more localized (and thus lower-risk) fix compared to #14975
(thanks @tyralla!). #14975 might still be a good idea, but I'm not
comfortable merging it now, and I want a quick fix to unblock the mypy
1.12 release.

Fixes #15031. Fixes #17863.

Co-authored by @tyralla.
  • Loading branch information
JukkaL authored Oct 3, 2024
1 parent ac98ab5 commit 3c09b32
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 3 deletions.
62 changes: 59 additions & 3 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import (
ARG_OPT,
ARG_POS,
ARG_STAR,
ARG_STAR2,
Expand Down Expand Up @@ -305,9 +306,27 @@ class B(A): pass
"""
if isinstance(method, Overloaded):
items = [
bind_self(c, original_type, is_classmethod, ignore_instances) for c in method.items
]
items = []
original_type = get_proper_type(original_type)
for c in method.items:
if isinstance(original_type, Instance):
# Filter based on whether declared self type can match actual object type.
# For example, if self has type C[int] and method is accessed on a C[str] value,
# omit this item. This is best effort since bind_self can be called in many
# contexts, and doing complete validation might trigger infinite recursion.
#
# Note that overload item filtering normally happens elsewhere. This is needed
# at least during constraint inference.
keep = is_valid_self_type_best_effort(c, original_type)
else:
keep = True
if keep:
items.append(bind_self(c, original_type, is_classmethod, ignore_instances))
if len(items) == 0:
# If no item matches, returning all items helps avoid some spurious errors
items = [
bind_self(c, original_type, is_classmethod, ignore_instances) for c in method.items
]
return cast(F, Overloaded(items))
assert isinstance(method, CallableType)
func = method
Expand Down Expand Up @@ -379,6 +398,43 @@ class B(A): pass
return cast(F, res)


def is_valid_self_type_best_effort(c: CallableType, self_type: Instance) -> bool:
"""Quickly check if self_type might match the self in a callable.
Avoid performing any complex type operations. This is performance-critical.
Default to returning True if we don't know (or it would be too expensive).
"""
if (
self_type.args
and c.arg_types
and isinstance((arg_type := get_proper_type(c.arg_types[0])), Instance)
and c.arg_kinds[0] in (ARG_POS, ARG_OPT)
and arg_type.args
and self_type.type.fullname != "functools._SingleDispatchCallable"
):
if self_type.type is not arg_type.type:
# We can't map to supertype, since it could trigger expensive checks for
# protocol types, so we consevatively assume this is fine.
return True

# Fast path: no explicit annotation on self
if all(
(
type(arg) is TypeVarType
and type(arg.upper_bound) is Instance
and arg.upper_bound.type.fullname == "builtins.object"
)
for arg in arg_type.args
):
return True

from mypy.meet import is_overlapping_types

return is_overlapping_types(self_type, c.arg_types[0])
return True


def erase_to_bound(t: Type) -> Type:
# TODO: use value restrictions to produce a union?
t = get_proper_type(t)
Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6750,3 +6750,21 @@ def foo(x: object) -> str: ...
def bar(x: int) -> int: ...
@overload
def bar(x: Any) -> str: ...

[case testOverloadOnInvalidTypeArgument]
from typing import TypeVar, Self, Generic, overload

class C: pass

T = TypeVar("T", bound=C)

class D(Generic[T]):
@overload
def f(self, x: int) -> int: ...
@overload
def f(self, x: str) -> str: ...
def f(Self, x): ...

a: D[str] # E: Type argument "str" of "D" must be a subtype of "C"
reveal_type(a.f(1)) # N: Revealed type is "builtins.int"
reveal_type(a.f("x")) # N: Revealed type is "builtins.str"
88 changes: 88 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -4127,3 +4127,91 @@ class P(Protocol):

class C(P): ...
C(0) # OK

[case testTypeVarValueConstraintAgainstGenericProtocol]
from typing import TypeVar, Generic, Protocol, overload

T_contra = TypeVar("T_contra", contravariant=True)
AnyStr = TypeVar("AnyStr", str, bytes)

class SupportsWrite(Protocol[T_contra]):
def write(self, s: T_contra, /) -> None: ...

class Buffer: ...

class IO(Generic[AnyStr]):
@overload
def write(self: IO[bytes], s: Buffer, /) -> None: ...
@overload
def write(self, s: AnyStr, /) -> None: ...
def write(self, s): ...

def foo(fdst: SupportsWrite[AnyStr]) -> None: ...

x: IO[str]
foo(x)

[case testTypeVarValueConstraintAgainstGenericProtocol2]
from typing import Generic, Protocol, TypeVar, overload

AnyStr = TypeVar("AnyStr", str, bytes)
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)

class SupportsRead(Generic[T_co]):
def read(self) -> T_co: ...

class SupportsWrite(Protocol[T_contra]):
def write(self, s: T_contra) -> object: ...

def copyfileobj(fsrc: SupportsRead[AnyStr], fdst: SupportsWrite[AnyStr]) -> None: ...

class WriteToMe(Generic[AnyStr]):
@overload
def write(self: WriteToMe[str], s: str) -> int: ...
@overload
def write(self: WriteToMe[bytes], s: bytes) -> int: ...
def write(self, s): ...

class WriteToMeOrReadFromMe(WriteToMe[AnyStr], SupportsRead[AnyStr]): ...

copyfileobj(WriteToMeOrReadFromMe[bytes](), WriteToMe[bytes]())

[case testOverloadedMethodWithExplictSelfTypes]
from typing import Generic, overload, Protocol, TypeVar, Union

AnyStr = TypeVar("AnyStr", str, bytes)
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)

class SupportsRead(Protocol[T_co]):
def read(self) -> T_co: ...

class SupportsWrite(Protocol[T_contra]):
def write(self, s: T_contra) -> int: ...

class Input(Generic[AnyStr]):
def read(self) -> AnyStr: ...

class Output(Generic[AnyStr]):
@overload
def write(self: Output[str], s: str) -> int: ...
@overload
def write(self: Output[bytes], s: bytes) -> int: ...
def write(self, s: Union[str, bytes]) -> int: ...

def f(src: SupportsRead[AnyStr], dst: SupportsWrite[AnyStr]) -> None: ...

def g1(a: Input[bytes], b: Output[bytes]) -> None:
f(a, b)

def g2(a: Input[bytes], b: Output[bytes]) -> None:
f(a, b)

def g3(a: Input[str], b: Output[bytes]) -> None:
f(a, b) # E: Cannot infer type argument 1 of "f"

def g4(a: Input[bytes], b: Output[str]) -> None:
f(a, b) # E: Cannot infer type argument 1 of "f"

[builtins fixtures/tuple.pyi]

0 comments on commit 3c09b32

Please sign in to comment.