Skip to content

Commit

Permalink
Fix inference of protocol against overloaded function (#12227)
Browse files Browse the repository at this point in the history
We used to infer a callable in a protocol against all overload
items. This could result in incorrect results, if only one
of the overload items would actually match the protocol.

Fix the issue by only considering the first matching overload
item.

This seems to help with protocols involving `__getitem__`.
In particular, this fixes regressions related to
`SupportsLenAndGetItem`, which is used for `random.choice`.
  • Loading branch information
JukkaL authored Feb 22, 2022
1 parent 2c9a8e7 commit a8b6d6f
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 1 deletion.
25 changes: 24 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,8 +658,12 @@ def infer_against_any(self, types: Iterable[Type], any_type: AnyType) -> List[Co
return res

def visit_overloaded(self, template: Overloaded) -> List[Constraint]:
if isinstance(self.actual, CallableType):
items = find_matching_overload_items(template, self.actual)
else:
items = template.items
res: List[Constraint] = []
for t in template.items:
for t in items:
res.extend(infer_constraints(t, self.actual, self.direction))
return res

Expand Down Expand Up @@ -701,3 +705,22 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType)
# Fall back to the first item if we can't find a match. This is totally arbitrary --
# maybe we should just bail out at this point.
return items[0]


def find_matching_overload_items(overloaded: Overloaded,
template: CallableType) -> List[CallableType]:
"""Like find_matching_overload_item, but return all matches, not just the first."""
items = overloaded.items
res = []
for item in items:
# Return type may be indeterminate in the template, so ignore it when performing a
# subtype check.
if mypy.subtypes.is_callable_compatible(item, template,
is_compat=mypy.subtypes.is_subtype,
ignore_return=True):
res.append(item)
if not res:
# Falling back to all items if we can't find a match is pretty arbitrary, but
# it maintains backward compatibility.
res = items[:]
return res
83 changes: 83 additions & 0 deletions test-data/unit/check-protocols.test
Original file line number Diff line number Diff line change
Expand Up @@ -2806,3 +2806,86 @@ class MyClass:
assert isinstance(self, MyProtocol)
[builtins fixtures/isinstance.pyi]
[typing fixtures/typing-full.pyi]

[case testMatchProtocolAgainstOverloadWithAmbiguity]
from typing import TypeVar, Protocol, Union, Generic, overload

T = TypeVar("T", covariant=True)

class slice: pass

class GetItem(Protocol[T]):
def __getitem__(self, k: int) -> T: ...

class Str: # Resembles 'str'
def __getitem__(self, k: Union[int, slice]) -> Str: ...

class Lst(Generic[T]): # Resembles 'list'
def __init__(self, x: T): ...
@overload
def __getitem__(self, k: int) -> T: ...
@overload
def __getitem__(self, k: slice) -> Lst[T]: ...
def __getitem__(self, k): pass

def f(x: GetItem[GetItem[Str]]) -> None: ...

a: Lst[Str]
f(Lst(a))

class Lst2(Generic[T]):
def __init__(self, x: T): ...
# The overload items are tweaked but still compatible
@overload
def __getitem__(self, k: Str) -> None: ...
@overload
def __getitem__(self, k: slice) -> Lst2[T]: ...
@overload
def __getitem__(self, k: Union[int, str]) -> T: ...
def __getitem__(self, k): pass

b: Lst2[Str]
f(Lst2(b))

class Lst3(Generic[T]): # Resembles 'list'
def __init__(self, x: T): ...
# The overload items are no longer compatible (too narrow argument type)
@overload
def __getitem__(self, k: slice) -> Lst3[T]: ...
@overload
def __getitem__(self, k: bool) -> T: ...
def __getitem__(self, k): pass

c: Lst3[Str]
f(Lst3(c)) # E: Argument 1 to "f" has incompatible type "Lst3[Lst3[Str]]"; expected "GetItem[GetItem[Str]]" \
# N: Following member(s) of "Lst3[Lst3[Str]]" have conflicts: \
# N: Expected: \
# N: def __getitem__(self, int) -> GetItem[Str] \
# N: Got: \
# N: @overload \
# N: def __getitem__(self, slice) -> Lst3[Lst3[Str]] \
# N: @overload \
# N: def __getitem__(self, bool) -> Lst3[Str]

[builtins fixtures/list.pyi]
[typing fixtures/typing-full.pyi]

[case testMatchProtocolAgainstOverloadWithMultipleMatchingItems]
from typing import Protocol, overload, TypeVar, Any

_T_co = TypeVar("_T_co", covariant=True)
_T = TypeVar("_T")

class SupportsRound(Protocol[_T_co]):
@overload
def __round__(self) -> int: ...
@overload
def __round__(self, __ndigits: int) -> _T_co: ...

class C:
# This matches both overload items of SupportsRound
def __round__(self, __ndigits: int = ...) -> int: ...

def round(number: SupportsRound[_T], ndigits: int) -> _T: ...

round(C(), 1)

0 comments on commit a8b6d6f

Please sign in to comment.