Skip to content

Commit

Permalink
Re-work overload overlap logic (#17392)
Browse files Browse the repository at this point in the history
Fixes #5510

OK, so I noticed during last couple years, that every other time I
change something about type variables, a few unsafe overload overlap
errors either appears or disappears. At some point I almost stopped
looking at them. The problem is that unsafe overload overlap detection
for generic callables is currently ad-hoc. However, as I started working
on it, I discovered a bunch of foundational problems (and few smaller
issues), so I decided to re-work the unsafe overload overlap detection.
Here is a detailed summary:

* Currently return type compatibility is decided using regular subtype
check. Although it is technically
correct, in most cases there is nothing wrong if first overload returns
`list[Subtype]` and second returns `list[Supertype]`. All the unsafe
overload story is about runtime values, not static types, so we should
use `is_subset()` instead of `is_subtype()`, which is IIUC easy to
implement: we simply need to consider all invariant types covariant.
* Current implementation only checks for overlap between parameters,
i.e. it checks if there are some calls that are valid for both
overloads. But we also need to check that those common calls will not be
always caught by the first overload. I assume it was not checked
because, naively, we already check elsewhere that first overload doesn't
completely shadow the second one. But this is not the same: first
overload may be not more general overall, but when narrowed to common
calls, it may be more general. Example of such false-positive (this is
an oversimplified version of what is often used in situations with many
optional positional arguments):
  ```python
  @overload
  def foo(x: object) -> object: ...
  @overload
  def foo(x: int = ...) -> int: ...
  ```
* Currently overlap for generic callables is decided using some weird
two-way unification procedure, where we actually keep going on (with
non-unified variables, and/or `<never>`) if the right to left
unification fails. TBH I never understood this. What we need is to find
some set of type variable values that makes two overloads unsafely
overlapping. Constraint inference may be used as a (good) source of such
guesses, but is not decisive in any way. So instead I simply try all
combinations of upper bounds and values. The main benefit of such
approach is that it is guaranteed false-positive free. If such algorithm
finds an overlap it is definitely an overlap. There are however false
negatives, but we can incrementally tighten them in the future.
* I am making `Any` overlap nothing when considering overloads.
Currently it overlaps everything (i.e. it is not different from
`object`), but this violates the rule that replacing a precise type with
`Any` should not generate an error. IOW I essentially treat `Any` as
"too dynamic or not imported".
* I extend `None` special-casing to be more uniform. Now essentially it
only overlaps with explicitly optional types. This is important for
descriptor-like signatures.
* Finally, I did a cleanup in `is_overlapping_types()`, most notably
flags were not passed down to various (recursive) helpers, and
`ParamSpec`/`Parameters` were treated a bit arbitrary.

Pros/cons of the outcome:
* Pro: simple (even if not 100% accurate) mental model
* Pro: all major classes of false positives eliminated
* Pro: couple minor false negatives fixed
* Con: two new false negatives added, more details below

So here a two new false negatives and motivation on why I think they are
OK. First example is
```python
T = TypeVar("T")

@overload
def foo(x: str) -> int: ...
@overload
def foo(x: T) -> T: ...
def foo(x):
    if isinstance(x, str):
        return 0
    return x
```
This is obviously unsafe (consider `T = float`), but not flagged after
this PR. I think this is ~fine for two reasons:
* There is no good alternative for a user, the error is not very
actionable. Using types like `(str | T) -> int | T` is a bad idea
because unions with type variables are not only imprecise, but also
highly problematic for inference.
* The false negative is mostly affecting unbounded type variables, if a
"suspicious" bound is used (like `bound=float` in this example), the
error will be still reported.

Second example is signatures like
```python
@overload
def foo(x: str, y: str) -> str: ...
@overload
def foo(*args: str) -> int: ...

@overload
def bar(*, x: str, y: str) -> str: ...
@overload
def bar(**kwds: str) -> int: ...
```
These are also unsafe because one can fool mypy with `x: tuple[str, ...]
= ("x", "y"); foo(*x)` and `x: dict[str, str] = {"x": "x", "y": "y"};
bar(**x)`. I think this is OK because while such unsafe calls are quite
rare, this kind of catch-all fallback as last overload is relatively
common.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 19, 2024
1 parent e1ff8aa commit 7cb733a
Show file tree
Hide file tree
Showing 16 changed files with 352 additions and 275 deletions.
195 changes: 122 additions & 73 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,6 @@
false_only,
fixup_partial_type,
function_type,
get_type_vars,
is_literal_type_like,
is_singleton_type,
make_simplified_union,
Expand Down Expand Up @@ -787,7 +786,16 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
type_vars = current_class.defn.type_vars if current_class else []
with state.strict_optional_set(True):
if is_unsafe_overlapping_overload_signatures(sig1, sig2, type_vars):
self.msg.overloaded_signatures_overlap(i + 1, i + j + 2, item.func)
flip_note = (
j == 0
and not is_unsafe_overlapping_overload_signatures(
sig2, sig1, type_vars
)
and not overload_can_never_match(sig2, sig1)
)
self.msg.overloaded_signatures_overlap(
i + 1, i + j + 2, flip_note, item.func
)

if impl_type is not None:
assert defn.impl is not None
Expand Down Expand Up @@ -1764,6 +1772,8 @@ def is_unsafe_overlapping_op(
# second operand is the right argument -- we switch the order of
# the arguments of the reverse method.

# TODO: this manipulation is dangerous if callables are generic.
# Shuffling arguments between callables can create meaningless types.
forward_tweaked = forward_item.copy_modified(
arg_types=[forward_base_erased, forward_item.arg_types[0]],
arg_kinds=[nodes.ARG_POS] * 2,
Expand All @@ -1790,7 +1800,9 @@ def is_unsafe_overlapping_op(

current_class = self.scope.active_class()
type_vars = current_class.defn.type_vars if current_class else []
return is_unsafe_overlapping_overload_signatures(first, second, type_vars)
return is_unsafe_overlapping_overload_signatures(
first, second, type_vars, partial_only=False
)

def check_inplace_operator_method(self, defn: FuncBase) -> None:
"""Check an inplace operator method such as __iadd__.
Expand Down Expand Up @@ -2185,7 +2197,7 @@ def get_op_other_domain(self, tp: FunctionLike) -> Type | None:
if isinstance(tp, CallableType):
if tp.arg_kinds and tp.arg_kinds[0] == ARG_POS:
# For generic methods, domain comparison is tricky, as a first
# approximation erase all remaining type variables to bounds.
# approximation erase all remaining type variables.
return erase_typevars(tp.arg_types[0], {v.id for v in tp.variables})
return None
elif isinstance(tp, Overloaded):
Expand Down Expand Up @@ -7827,68 +7839,112 @@ def are_argument_counts_overlapping(t: CallableType, s: CallableType) -> bool:
return min_args <= max_args


def expand_callable_variants(c: CallableType) -> list[CallableType]:
"""Expand a generic callable using all combinations of type variables' values/bounds."""
for tv in c.variables:
# We need to expand self-type before other variables, because this is the only
# type variable that can have other type variables in the upper bound.
if tv.id.is_self():
c = expand_type(c, {tv.id: tv.upper_bound}).copy_modified(
variables=[v for v in c.variables if not v.id.is_self()]
)
break

if not c.is_generic():
# Fast path.
return [c]

tvar_values = []
for tvar in c.variables:
if isinstance(tvar, TypeVarType) and tvar.values:
tvar_values.append(tvar.values)
else:
tvar_values.append([tvar.upper_bound])

variants = []
for combination in itertools.product(*tvar_values):
tvar_map = {tv.id: subst for (tv, subst) in zip(c.variables, combination)}
variants.append(expand_type(c, tvar_map).copy_modified(variables=[]))
return variants


def is_unsafe_overlapping_overload_signatures(
signature: CallableType, other: CallableType, class_type_vars: list[TypeVarLikeType]
signature: CallableType,
other: CallableType,
class_type_vars: list[TypeVarLikeType],
partial_only: bool = True,
) -> bool:
"""Check if two overloaded signatures are unsafely overlapping or partially overlapping.
We consider two functions 's' and 't' to be unsafely overlapping if both
of the following are true:
We consider two functions 's' and 't' to be unsafely overlapping if three
conditions hold:
1. s's parameters are partially overlapping with t's. i.e. there are calls that are
valid for both signatures.
2. for these common calls, some of t's parameters types are wider that s's.
3. s's return type is NOT a subset of t's.
1. s's parameters are all more precise or partially overlapping with t's
2. s's return type is NOT a subtype of t's.
Note that we use subset rather than subtype relationship in these checks because:
* Overload selection happens at runtime, not statically.
* This results in more lenient behavior.
This can cause false negatives (e.g. if overloaded function returns an externally
visible attribute with invariant type), but such situations are rare. In general,
overloads in Python are generally unsafe, so we intentionally try to avoid giving
non-actionable errors (see more details in comments below).
Assumes that 'signature' appears earlier in the list of overload
alternatives then 'other' and that their argument counts are overlapping.
"""
# Try detaching callables from the containing class so that all TypeVars
# are treated as being free.
#
# This lets us identify cases where the two signatures use completely
# incompatible types -- e.g. see the testOverloadingInferUnionReturnWithMixedTypevars
# test case.
# are treated as being free, i.e. the signature is as seen from inside the class,
# where "self" is not yet bound to anything.
signature = detach_callable(signature, class_type_vars)
other = detach_callable(other, class_type_vars)

# Note: We repeat this check twice in both directions due to a slight
# asymmetry in 'is_callable_compatible'. When checking for partial overlaps,
# we attempt to unify 'signature' and 'other' both against each other.
#
# If 'signature' cannot be unified with 'other', we end early. However,
# if 'other' cannot be modified with 'signature', the function continues
# using the older version of 'other'.
#
# This discrepancy is unfortunately difficult to get rid of, so we repeat the
# checks twice in both directions for now.
#
# Note that we ignore possible overlap between type variables and None. This
# is technically unsafe, but unsafety is tiny and this prevents some common
# use cases like:
# @overload
# def foo(x: None) -> None: ..
# @overload
# def foo(x: T) -> Foo[T]: ...
return is_callable_compatible(
signature,
other,
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
is_proper_subtype=False,
is_compat_return=lambda l, r: not is_subtype_no_promote(l, r),
ignore_return=False,
check_args_covariantly=True,
allow_partial_overlap=True,
no_unify_none=True,
) or is_callable_compatible(
other,
signature,
is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none,
is_proper_subtype=False,
is_compat_return=lambda l, r: not is_subtype_no_promote(r, l),
ignore_return=False,
check_args_covariantly=False,
allow_partial_overlap=True,
no_unify_none=True,
)
# Note: We repeat this check twice in both directions compensate for slight
# asymmetries in 'is_callable_compatible'.

for sig_variant in expand_callable_variants(signature):
for other_variant in expand_callable_variants(other):
# Using only expanded callables may cause false negatives, we can add
# more variants (e.g. using inference between callables) in the future.
if is_subset_no_promote(sig_variant.ret_type, other_variant.ret_type):
continue
if not (
is_callable_compatible(
sig_variant,
other_variant,
is_compat=is_overlapping_types_for_overload,
check_args_covariantly=False,
is_proper_subtype=False,
is_compat_return=lambda l, r: not is_subset_no_promote(l, r),
allow_partial_overlap=True,
)
or is_callable_compatible(
other_variant,
sig_variant,
is_compat=is_overlapping_types_for_overload,
check_args_covariantly=True,
is_proper_subtype=False,
is_compat_return=lambda l, r: not is_subset_no_promote(r, l),
allow_partial_overlap=True,
)
):
continue
# Using the same `allow_partial_overlap` flag as before, can cause false
# negatives in case where star argument is used in a catch-all fallback overload.
# But again, practicality beats purity here.
if not partial_only or not is_callable_compatible(
other_variant,
sig_variant,
is_compat=is_subset_no_promote,
check_args_covariantly=True,
is_proper_subtype=False,
ignore_return=True,
allow_partial_overlap=True,
):
return True
return False


def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -> CallableType:
Expand All @@ -7897,21 +7953,11 @@ def detach_callable(typ: CallableType, class_type_vars: list[TypeVarLikeType]) -
A callable normally keeps track of the type variables it uses within its 'variables' field.
However, if the callable is from a method and that method is using a class type variable,
the callable will not keep track of that type variable since it belongs to the class.
This function will traverse the callable and find all used type vars and add them to the
variables field if it isn't already present.
The caller can then unify on all type variables whether the callable is originally from
the class or not."""
"""
if not class_type_vars:
# Fast path, nothing to update.
return typ
seen_type_vars = set()
for t in typ.arg_types + [typ.ret_type]:
seen_type_vars |= set(get_type_vars(t))
return typ.copy_modified(
variables=list(typ.variables) + [tv for tv in class_type_vars if tv in seen_type_vars]
)
return typ.copy_modified(variables=list(typ.variables) + class_type_vars)


def overload_can_never_match(signature: CallableType, other: CallableType) -> bool:
Expand Down Expand Up @@ -8388,21 +8434,24 @@ def get_property_type(t: ProperType) -> ProperType:
return t


def is_subtype_no_promote(left: Type, right: Type) -> bool:
return is_subtype(left, right, ignore_promotions=True)
def is_subset_no_promote(left: Type, right: Type) -> bool:
return is_subtype(left, right, ignore_promotions=True, always_covariant=True)


def is_overlapping_types_no_promote_no_uninhabited_no_none(left: Type, right: Type) -> bool:
# For the purpose of unsafe overload checks we consider list[Never] and list[int]
# non-overlapping. This is consistent with how we treat list[int] and list[str] as
# non-overlapping, despite [] belongs to both. Also this will prevent false positives
# for failed type inference during unification.
def is_overlapping_types_for_overload(left: Type, right: Type) -> bool:
# Note that among other effects 'overlap_for_overloads' flag will effectively
# ignore possible overlap between type variables and None. This is technically
# unsafe, but unsafety is tiny and this prevents some common use cases like:
# @overload
# def foo(x: None) -> None: ..
# @overload
# def foo(x: T) -> Foo[T]: ...
return is_overlapping_types(
left,
right,
ignore_promotions=True,
ignore_uninhabited=True,
prohibit_none_typevar_overlap=True,
overlap_for_overloads=True,
)


Expand Down
2 changes: 1 addition & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# like U -> U, should be Callable[..., Any], but if U is a self-type, we can
# allow it to leak, to be later bound to self. A bunch of existing code
# depends on this old behaviour.
and not any(tv.id.raw_id == 0 for tv in cactual.variables)
and not any(tv.id.is_self() for tv in cactual.variables)
):
# If the actual callable is generic, infer constraints in the opposite
# direction, and indicate to the solver there are extra type variables
Expand Down
2 changes: 1 addition & 1 deletion mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def visit_instance(self, t: Instance) -> Type:
def visit_type_var(self, t: TypeVarType) -> Type:
# Normally upper bounds can't contain other type variables, the only exception is
# special type variable Self`0 <: C[T, S], where C is the class where Self is used.
if t.id.raw_id == 0:
if t.id.is_self():
t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
repl = self.variables.get(t.id, t)
if isinstance(repl, ProperType) and isinstance(repl, Instance):
Expand Down
Loading

0 comments on commit 7cb733a

Please sign in to comment.