Skip to content

Commit

Permalink
Properly use proper subtyping for callables (#16343)
Browse files Browse the repository at this point in the history
Fixes #16338

This is kind of a major change, but it is technically correct: we should
not treat `(*args: Any, **kwargs: Any)` special in `is_proper_subtype()`
(only in `is_subtype()`). Unfortunately, this requires an additional
flag for `is_callable_compatible()`, since currently we are passing the
subtype kind information via a callback, which is not applicable to
handling argument kinds.
  • Loading branch information
ilevkivskyi authored Oct 27, 2023
1 parent 4f05dd5 commit 5c6ca5c
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 9 deletions.
11 changes: 8 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:

# Is the overload alternative's arguments subtypes of the implementation's?
if not is_callable_compatible(
impl, sig1, is_compat=is_subtype, ignore_return=True
impl, sig1, is_compat=is_subtype, is_proper_subtype=False, ignore_return=True
):
self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl)

Expand Down Expand Up @@ -7685,6 +7685,7 @@ def is_unsafe_overlapping_overload_signatures(
signature,
other,
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
is_proper_subtype=False,
is_compat_return=lambda l, r: not is_subtype_no_promote(l, r),
ignore_return=False,
check_args_covariantly=True,
Expand All @@ -7694,6 +7695,7 @@ def is_unsafe_overlapping_overload_signatures(
other,
signature,
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
is_proper_subtype=False,
is_compat_return=lambda l, r: not is_subtype_no_promote(r, l),
ignore_return=False,
check_args_covariantly=False,
Expand Down Expand Up @@ -7744,7 +7746,7 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo
signature, {tvar.id: erase_def_to_union_or_bound(tvar) for tvar in signature.variables}
)
return is_callable_compatible(
exp_signature, other, is_compat=is_more_precise, ignore_return=True
exp_signature, other, is_compat=is_more_precise, is_proper_subtype=True, ignore_return=True
)


Expand All @@ -7754,7 +7756,9 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
# general than one with fewer items (or just one item)?
if isinstance(t, CallableType):
if isinstance(s, CallableType):
return is_callable_compatible(t, s, is_compat=is_proper_subtype, ignore_return=True)
return is_callable_compatible(
t, s, is_compat=is_proper_subtype, is_proper_subtype=True, ignore_return=True
)
elif isinstance(t, FunctionLike):
if isinstance(s, FunctionLike):
if len(t.items) == len(s.items):
Expand All @@ -7769,6 +7773,7 @@ def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool:
t,
s,
is_compat=is_same_type,
is_proper_subtype=True,
ignore_return=True,
check_args_covariantly=True,
ignore_pos_arg_names=True,
Expand Down
12 changes: 10 additions & 2 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,11 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType)
# Return type may be indeterminate in the template, so ignore it when performing a
# subtype check.
if mypy.subtypes.is_callable_compatible(
item, template, is_compat=mypy.subtypes.is_subtype, ignore_return=True
item,
template,
is_compat=mypy.subtypes.is_subtype,
is_proper_subtype=False,
ignore_return=True,
):
return item
# Fall back to the first item if we can't find a match. This is totally arbitrary --
Expand All @@ -1370,7 +1374,11 @@ def find_matching_overload_items(
# Return type may be indeterminate in the template, so ignore it when performing a
# subtype check.
if mypy.subtypes.is_callable_compatible(
item, template, is_compat=mypy.subtypes.is_subtype, ignore_return=True
item,
template,
is_compat=mypy.subtypes.is_subtype,
is_proper_subtype=False,
ignore_return=True,
):
res.append(item)
if not res:
Expand Down
1 change: 1 addition & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
left,
right,
is_compat=_is_overlapping_types,
is_proper_subtype=False,
ignore_pos_arg_names=True,
allow_partial_overlap=True,
)
Expand Down
14 changes: 11 additions & 3 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,8 @@ def visit_parameters(self, left: Parameters) -> bool:
left,
self.right,
is_compat=self._is_subtype,
# TODO: this should pass the current value, but then couple tests fail.
is_proper_subtype=False,
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
)
else:
Expand All @@ -677,6 +679,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
left,
right,
is_compat=self._is_subtype,
is_proper_subtype=self.proper_subtype,
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
strict_concatenate=(self.options.extra_checks or self.options.strict_concatenate)
if self.options
Expand Down Expand Up @@ -932,6 +935,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
left_item,
right_item,
is_compat=self._is_subtype,
is_proper_subtype=self.proper_subtype,
ignore_return=True,
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
strict_concatenate=strict_concat,
Expand All @@ -940,6 +944,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
right_item,
left_item,
is_compat=self._is_subtype,
is_proper_subtype=self.proper_subtype,
ignore_return=True,
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
strict_concatenate=strict_concat,
Expand Down Expand Up @@ -1358,6 +1363,7 @@ def is_callable_compatible(
right: CallableType,
*,
is_compat: Callable[[Type, Type], bool],
is_proper_subtype: bool,
is_compat_return: Callable[[Type, Type], bool] | None = None,
ignore_return: bool = False,
ignore_pos_arg_names: bool = False,
Expand Down Expand Up @@ -1517,6 +1523,7 @@ def g(x: int) -> int: ...
left,
right,
is_compat=is_compat,
is_proper_subtype=is_proper_subtype,
ignore_pos_arg_names=ignore_pos_arg_names,
allow_partial_overlap=allow_partial_overlap,
strict_concatenate_check=strict_concatenate_check,
Expand Down Expand Up @@ -1552,12 +1559,13 @@ def are_parameters_compatible(
right: Parameters | NormalizedCallableType,
*,
is_compat: Callable[[Type, Type], bool],
is_proper_subtype: bool,
ignore_pos_arg_names: bool = False,
allow_partial_overlap: bool = False,
strict_concatenate_check: bool = False,
) -> bool:
"""Helper function for is_callable_compatible, used for Parameter compatibility"""
if right.is_ellipsis_args:
if right.is_ellipsis_args and not is_proper_subtype:
return True

left_star = left.var_arg()
Expand All @@ -1566,9 +1574,9 @@ def are_parameters_compatible(
right_star2 = right.kw_arg()

# Treat "def _(*a: Any, **kw: Any) -> X" similarly to "Callable[..., X]"
if are_trivial_parameters(right):
if are_trivial_parameters(right) and not is_proper_subtype:
return True
trivial_suffix = is_trivial_suffix(right)
trivial_suffix = is_trivial_suffix(right) and not is_proper_subtype

# Match up corresponding arguments and check them for compatibility. In
# every pair (argL, argR) of corresponding arguments from L and R, argL must
Expand Down
22 changes: 21 additions & 1 deletion test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6501,7 +6501,7 @@ eggs = lambda: 'eggs'
reveal_type(func(eggs)) # N: Revealed type is "def (builtins.str) -> builtins.str"

spam: Callable[..., str] = lambda x, y: 'baz'
reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> builtins.str"
reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> Any"
[builtins fixtures/paramspec.pyi]

[case testGenericOverloadOverlapWithType]
Expand Down Expand Up @@ -6673,3 +6673,23 @@ c2 = MyCallable("test")
reveal_type(c2) # N: Revealed type is "__main__.MyCallable[builtins.str]"
reveal_type(c2()) # should be int # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testOverloadWithStarAnyFallback]
from typing import overload, Any

class A:
@overload
def f(self, e: str) -> str: ...
@overload
def f(self, *args: Any, **kwargs: Any) -> Any: ...
def f(self, *args, **kwargs):
pass

class B:
@overload
def f(self, e: str, **kwargs: Any) -> str: ...
@overload
def f(self, *args: Any, **kwargs: Any) -> Any: ...
def f(self, *args, **kwargs):
pass
[builtins fixtures/tuple.pyi]

0 comments on commit 5c6ca5c

Please sign in to comment.