diff --git a/mypy/join.py b/mypy/join.py index 7489fec2572e..3adb9531e5a2 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -87,7 +87,7 @@ def visit_union_type(self, t: UnionType) -> Type: if is_subtype(self.s, t): return t else: - return UnionType(t.items + [self.s]) + return UnionType.make_simplified_union([self.s, t]) def visit_error_type(self, t: ErrorType) -> Type: return t @@ -235,7 +235,6 @@ def join_instances(t: Instance, s: Instance) -> Type: Return ErrorType if the result is ambiguous. """ - if t.type == s.type: # Simplest case: join two types with the same base type (but # potentially different arguments). @@ -264,16 +263,29 @@ def join_instances_via_supertype(t: Instance, s: Instance) -> Type: return join_types(t.type._promote, s) elif s.type._promote and is_subtype(s.type._promote, t): return join_types(t, s.type._promote) - res = s - mapped = map_instance_to_supertype(t, t.type.bases[0].type) - join = join_instances(mapped, res) - # If the join failed, fail. This is a defensive measure (this might - # never happen). - if isinstance(join, ErrorType): - return join - # Now the result must be an Instance, so the cast below cannot fail. - res = cast(Instance, join) - return res + # Compute the "best" supertype of t when joined with s. + # The definition of "best" may evolve; for now it is the one with + # the longest MRO. Ties are broken by using the earlier base. + best = None # type: Type + for base in t.type.bases: + mapped = map_instance_to_supertype(t, base.type) + res = join_instances(mapped, s) + if best is None or is_better(res, best): + best = res + assert best is not None + return best + + +def is_better(t: Type, s: Type) -> bool: + # Given two possible results from join_instances_via_supertype(), + # indicate whether t is the better one. + if isinstance(t, Instance): + if not isinstance(s, Instance): + return True + # Use len(mro) as a proxy for the better choice. + if len(t.type.mro) > len(s.type.mro): + return True + return False def is_similar_callables(t: CallableType, s: CallableType) -> bool: diff --git a/mypy/test/data/check-inference.test b/mypy/test/data/check-inference.test index 999b1cb8e054..50df29d3a024 100644 --- a/mypy/test/data/check-inference.test +++ b/mypy/test/data/check-inference.test @@ -599,9 +599,9 @@ g(a) b = f(A(), B()) g(b) c = f(A(), D()) -g(c) # E: Argument 1 to "g" has incompatible type "object"; expected "I" +g(c) # E: Argument 1 to "g" has incompatible type "J"; expected "I" d = f(D(), A()) -g(d) # E: Argument 1 to "g" has incompatible type "object"; expected "I" +g(d) # E: Argument 1 to "g" has incompatible type "J"; expected "I" e = f(D(), C()) g(e) # E: Argument 1 to "g" has incompatible type "object"; expected "I" @@ -646,9 +646,9 @@ def f(a: T, b: T) -> T: pass def g(x: K) -> None: pass a = f(B(), C()) -g(a) # E: Argument 1 to "g" has incompatible type "object"; expected "K" +g(a) # E: Argument 1 to "g" has incompatible type "J"; expected "K" b = f(A(), C()) -g(b) # E: Argument 1 to "g" has incompatible type "object"; expected "K" +g(b) # E: Argument 1 to "g" has incompatible type "J"; expected "K" c = f(A(), B()) g(c) @@ -1593,3 +1593,49 @@ tmp/m.py: note: In function "g": tmp/m.py:2: error: "int" not callable main: note: In function "f": main:3: error: "int" not callable + + +-- Tests for special cases of unification +-- -------------------------------------- + +[case testUnificationRedundantUnion] +from typing import Union +a = None # type: Union[int, str] +b = None # type: Union[str, tuple] +def f(): pass +def g(x: Union[int, str]): pass +c = a if f() else b +g(c) # E: Argument 1 to "g" has incompatible type "Union[int, str, tuple]"; expected "Union[int, str]" + +[case testUnificationMultipleInheritance] +class A: pass +class B: + def foo(self): pass +class C(A, B): pass +def f(): pass +a1 = B() if f() else C() +a1.foo() +a2 = C() if f() else B() +a2.foo() + +[case testUnificationMultipleInheritanceAmbiguous] +# Show that join_instances_via_supertype() breakes ties using the first base class. +class A1: pass +class B1: + def foo1(self): pass +class C1(A1, B1): pass + +class A2: pass +class B2: + def foo2(self): pass +class C2(A2, B2): pass + +class D1(C1, C2): pass +class D2(C2, C1): pass + +def f(): pass + +a1 = D1() if f() else D2() +a1.foo1() +a2 = D2() if f() else D1() +a2.foo2()