From 11c082e6e402ed6ddf38edda5fba01523e7f2087 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 10 Aug 2023 23:32:26 +0100 Subject: [PATCH 1/6] Allow None vs TypeVar overlap for overloads --- mypy/checker.py | 24 ++++++++++++++++++++---- mypy/subtypes.py | 15 +++++++++++++-- test-data/unit/check-overloading.test | 19 ++++++++++++++----- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index b786155079e57..3bd9c494a8905 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -7216,22 +7216,32 @@ def is_unsafe_overlapping_overload_signatures( # # This discrepancy is unfortunately difficult to get rid of, so we repeat the # checks twice in both directions for now. + # + # Note that we ignore possible overlap between type variables and None. This + # is technically unsafe, but unsafety is tiny and this prevents some common + # use cases like: + # @overload + # def foo(x: None) -> None: .. + # @overload + # def foo(x: T) -> Foo[T]: ... return is_callable_compatible( signature, other, - is_compat=is_overlapping_types_no_promote_no_uninhabited, + is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none, is_compat_return=lambda l, r: not is_subtype_no_promote(l, r), ignore_return=False, check_args_covariantly=True, allow_partial_overlap=True, + no_unify_none=True, ) or is_callable_compatible( other, signature, - is_compat=is_overlapping_types_no_promote_no_uninhabited, + is_compat=is_overlapping_types_no_promote_no_uninhabited_no_none, is_compat_return=lambda l, r: not is_subtype_no_promote(r, l), ignore_return=False, check_args_covariantly=False, allow_partial_overlap=True, + no_unify_none=True, ) @@ -7717,12 +7727,18 @@ def is_subtype_no_promote(left: Type, right: Type) -> bool: return is_subtype(left, right, ignore_promotions=True) -def is_overlapping_types_no_promote_no_uninhabited(left: Type, right: Type) -> bool: +def is_overlapping_types_no_promote_no_uninhabited_no_none(left: Type, right: Type) -> bool: # For the purpose of unsafe overload checks we consider list[] and list[int] # non-overlapping. This is consistent with how we treat list[int] and list[str] as # non-overlapping, despite [] belongs to both. Also this will prevent false positives # for failed type inference during unification. - return is_overlapping_types(left, right, ignore_promotions=True, ignore_uninhabited=True) + return is_overlapping_types( + left, + right, + ignore_promotions=True, + ignore_uninhabited=True, + prohibit_none_typevar_overlap=True, + ) def is_private(node_name: str) -> bool: diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 5712d7375e50f..26c8816561a79 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1299,6 +1299,7 @@ def is_callable_compatible( check_args_covariantly: bool = False, allow_partial_overlap: bool = False, strict_concatenate: bool = False, + no_unify_none: bool = False, ) -> bool: """Is the left compatible with the right, using the provided compatibility check? @@ -1415,7 +1416,9 @@ def g(x: int) -> int: ... # (below) treats type variables on the two sides as independent. if left.variables: # Apply generic type variables away in left via type inference. - unified = unify_generic_callable(left, right, ignore_return=ignore_return) + unified = unify_generic_callable( + left, right, ignore_return=ignore_return, no_unify_none=True + ) if unified is None: return False left = unified @@ -1427,7 +1430,9 @@ def g(x: int) -> int: ... # So, we repeat the above checks in the opposite direction. This also # lets us preserve the 'symmetry' property of allow_partial_overlap. if allow_partial_overlap and right.variables: - unified = unify_generic_callable(right, left, ignore_return=ignore_return) + unified = unify_generic_callable( + right, left, ignore_return=ignore_return, no_unify_none=True + ) if unified is not None: right = unified @@ -1687,6 +1692,8 @@ def unify_generic_callable( target: NormalizedCallableType, ignore_return: bool, return_constraint_direction: int | None = None, + *, + no_unify_none: bool = False, ) -> NormalizedCallableType | None: """Try to unify a generic callable type with another callable type. @@ -1708,6 +1715,10 @@ def unify_generic_callable( type.ret_type, target.ret_type, return_constraint_direction ) constraints.extend(c) + if no_unify_none: + constraints = [ + c for c in constraints if not isinstance(get_proper_type(c.target), NoneType) + ] inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints) if None in inferred_vars: return None diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 50acd7d77c8cd..212d30690068d 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -2188,33 +2188,42 @@ def bar2(*x: int) -> int: ... from typing import overload, TypeVar, Generic T = TypeVar('T') +# The examples below are unsafe, but it is a quite common pattern +# so we ignore the possibility of type variables taking value `None` +# for the purpose of overload overlap checks. @overload -def foo(x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def foo(x: None, y: None) -> str: ... @overload def foo(x: T, y: T) -> int: ... def foo(x): ... # What if 'T' is 'object'? @overload -def bar(x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def bar(x: None, y: int) -> str: ... @overload def bar(x: T, y: T) -> int: ... def bar(x, y): ... class Wrapper(Generic[T]): @overload - def foo(self, x: None, y: None) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types + def foo(self, x: None, y: None) -> str: ... @overload def foo(self, x: T, y: None) -> int: ... def foo(self, x): ... @overload - def bar(self, x: None, y: int) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types + def bar(self, x: None, y: int) -> str: ... @overload def bar(self, x: T, y: T) -> int: ... def bar(self, x, y): ... +@overload +def baz(x: str, y: str) -> str: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def baz(x: T, y: T) -> int: ... +def baz(x): ... + [case testOverloadFlagsPossibleMatches] from wrapper import * [file wrapper.pyi] @@ -3996,7 +4005,7 @@ T = TypeVar('T') class FakeAttribute(Generic[T]): @overload - def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types + def dummy(self, instance: None, owner: Type[T]) -> 'FakeAttribute[T]': ... @overload def dummy(self, instance: T, owner: Type[T]) -> int: ... def dummy(self, instance: Optional[T], owner: Type[T]) -> Union['FakeAttribute[T]', int]: ... From 494016a6ab8da83cff462680365ceb6f1f7dc9f1 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 11 Aug 2023 00:13:11 +0100 Subject: [PATCH 2/6] Fix typo --- mypy/subtypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 26c8816561a79..da92f7398d4e1 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1417,7 +1417,7 @@ def g(x: int) -> int: ... if left.variables: # Apply generic type variables away in left via type inference. unified = unify_generic_callable( - left, right, ignore_return=ignore_return, no_unify_none=True + left, right, ignore_return=ignore_return, no_unify_none=no_unify_none ) if unified is None: return False @@ -1431,7 +1431,7 @@ def g(x: int) -> int: ... # lets us preserve the 'symmetry' property of allow_partial_overlap. if allow_partial_overlap and right.variables: unified = unify_generic_callable( - right, left, ignore_return=ignore_return, no_unify_none=True + right, left, ignore_return=ignore_return, no_unify_none=no_unify_none ) if unified is not None: right = unified From 05eb6ff265f3c867f3ca4b8836cac08e37f500d9 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 11 Aug 2023 17:52:53 +0100 Subject: [PATCH 3/6] Add heuristic to enforce union math for None vs TypeVar overlap --- mypy/checkexpr.py | 81 +++++++++++++++++++++------ test-data/unit/check-overloading.test | 20 ++++++- 2 files changed, 84 insertions(+), 17 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9e46d9ee39cb8..d2a4487e82825 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2400,6 +2400,11 @@ def check_overload_call( # typevar. See https://github.com/python/mypy/issues/4063 for related discussion. erased_targets: list[CallableType] | None = None unioned_result: tuple[Type, Type] | None = None + + # Determine whether we need to encourage union math. This should be generally safe, + # as union math infers better results in the vast majority of cases, but it is very + # computationally intensive. + none_type_var_overlap = self.possible_none_type_var_overlap(arg_types, plausible_targets) union_interrupted = False # did we try all union combinations? if any(self.real_union(arg) for arg in arg_types): try: @@ -2412,6 +2417,7 @@ def check_overload_call( arg_names, callable_name, object_type, + none_type_var_overlap, context, ) except TooManyUnions: @@ -2444,8 +2450,10 @@ def check_overload_call( # If any of checks succeed, stop early. if inferred_result is not None and unioned_result is not None: # Both unioned and direct checks succeeded, choose the more precise type. - if is_subtype(inferred_result[0], unioned_result[0]) and not isinstance( - get_proper_type(inferred_result[0]), AnyType + if ( + is_subtype(inferred_result[0], unioned_result[0]) + and not isinstance(get_proper_type(inferred_result[0]), AnyType) + and not none_type_var_overlap ): return inferred_result return unioned_result @@ -2650,6 +2658,42 @@ def overload_erased_call_targets( matches.append(typ) return matches + def possible_none_type_var_overlap( + self, arg_types: list[Type], plausible_targets: list[CallableType] + ) -> bool: + """Heuristic to determine whether we need to try forcing union math. + + This is needed to avoid greedy type variable match in situations like this: + @overload + def foo(x: None) -> None: ... + @overload + def foo(x: T) -> list[T]: ... + + x: int | None + foo(x) + we want this call to infer list[int] | None, not list[int | None]. + """ + has_optional_arg = False + for arg_type in get_proper_types(arg_types): + if not isinstance(arg_type, UnionType): + continue + for item in get_proper_types(arg_type.items): + if isinstance(item, NoneType): + has_optional_arg = True + break + if not has_optional_arg: + return False + + min_prefix = min(len(c.arg_types) for c in plausible_targets) + for i in range(min_prefix): + if any( + isinstance(get_proper_type(c.arg_types[i]), NoneType) for c in plausible_targets + ) and any( + isinstance(get_proper_type(c.arg_types[i]), TypeVarType) for c in plausible_targets + ): + return True + return False + def union_overload_result( self, plausible_targets: list[CallableType], @@ -2659,6 +2703,7 @@ def union_overload_result( arg_names: Sequence[str | None] | None, callable_name: str | None, object_type: Type | None, + none_type_var_overlap: bool, context: Context, level: int = 0, ) -> list[tuple[Type, Type]] | None: @@ -2698,20 +2743,23 @@ def union_overload_result( # Step 3: Try a direct match before splitting to avoid unnecessary union splits # and save performance. - with self.type_overrides_set(args, arg_types): - direct = self.infer_overload_return_type( - plausible_targets, - args, - arg_types, - arg_kinds, - arg_names, - callable_name, - object_type, - context, - ) - if direct is not None and not isinstance(get_proper_type(direct[0]), (UnionType, AnyType)): - # We only return non-unions soon, to avoid greedy match. - return [direct] + if not none_type_var_overlap: + with self.type_overrides_set(args, arg_types): + direct = self.infer_overload_return_type( + plausible_targets, + args, + arg_types, + arg_kinds, + arg_names, + callable_name, + object_type, + context, + ) + if direct is not None and not isinstance( + get_proper_type(direct[0]), (UnionType, AnyType) + ): + # We only return non-unions soon, to avoid greedy match. + return [direct] # Step 4: Split the first remaining union type in arguments into items and # try to match each item individually (recursive). @@ -2729,6 +2777,7 @@ def union_overload_result( arg_names, callable_name, object_type, + none_type_var_overlap, context, level + 1, ) diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 212d30690068d..4910dfe05d318 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -2185,7 +2185,8 @@ def bar2(*x: int) -> int: ... [builtins fixtures/tuple.pyi] [case testOverloadDetectsPossibleMatchesWithGenerics] -from typing import overload, TypeVar, Generic +# flags: --strict-optional +from typing import overload, TypeVar, Generic, Optional, List T = TypeVar('T') # The examples below are unsafe, but it is a quite common pattern @@ -2198,6 +2199,22 @@ def foo(x: None, y: None) -> str: ... def foo(x: T, y: T) -> int: ... def foo(x): ... +oi: Optional[int] +reveal_type(foo(None, None)) # N: Revealed type is "builtins.str" +reveal_type(foo(None, 42)) # N: Revealed type is "builtins.int" +reveal_type(foo(42, 42)) # N: Revealed type is "builtins.int" +reveal_type(foo(oi, None)) # N: Revealed type is "Union[builtins.int, builtins.str]" +reveal_type(foo(oi, 42)) # N: Revealed type is "builtins.int" +reveal_type(foo(oi, oi)) # N: Revealed type is "Union[builtins.int, builtins.str]" + +@overload +def foo_list(x: None) -> None: ... +@overload +def foo_list(x: T) -> List[T]: ... +def foo_list(x): ... + +reveal_type(foo_list(oi)) # N: Revealed type is "Union[builtins.list[builtins.int], None]" + # What if 'T' is 'object'? @overload def bar(x: None, y: int) -> str: ... @@ -2223,6 +2240,7 @@ def baz(x: str, y: str) -> str: ... # E: Overloaded function signatures 1 and 2 @overload def baz(x: T, y: T) -> int: ... def baz(x): ... +[builtins fixtures/tuple.pyi] [case testOverloadFlagsPossibleMatches] from wrapper import * From f4e78eafde11f5a0b0c3251bd0c169aeb914be76 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 11 Aug 2023 17:57:59 +0100 Subject: [PATCH 4/6] Awoid spurious errors --- mypy/checkexpr.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index d2a4487e82825..5f0a53f69ed7c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2503,7 +2503,8 @@ def check_overload_call( callable_name=callable_name, object_type=object_type, ) - if union_interrupted: + # Do not show the extra error is the union math was forced. + if union_interrupted and not none_type_var_overlap: self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context) return result From 123fea4700140b2903bf948e58b4206b7bd3ed2a Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 11 Aug 2023 18:00:59 +0100 Subject: [PATCH 5/6] Update mypy/checkexpr.py Co-authored-by: Alex Waygood --- mypy/checkexpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 5f0a53f69ed7c..3ac20681a8f06 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2503,7 +2503,7 @@ def check_overload_call( callable_name=callable_name, object_type=object_type, ) - # Do not show the extra error is the union math was forced. + # Do not show the extra error if the union math was forced. if union_interrupted and not none_type_var_overlap: self.chk.fail(message_registry.TOO_MANY_UNION_COMBINATIONS, context) return result From 76ec63f5d0f0362379da23acb10759219408155d Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 11 Aug 2023 18:57:58 +0100 Subject: [PATCH 6/6] Fix crash on non-matching overload --- mypy/checkexpr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3ac20681a8f06..3282ee338b874 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2674,6 +2674,8 @@ def foo(x: T) -> list[T]: ... foo(x) we want this call to infer list[int] | None, not list[int | None]. """ + if not plausible_targets or not arg_types: + return False has_optional_arg = False for arg_type in get_proper_types(arg_types): if not isinstance(arg_type, UnionType):