Skip to content

Improve unification for redundant unions and multiple inheritance. #1402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def visit_union_type(self, t: UnionType) -> Type:
if is_subtype(self.s, t):
return t
else:
return UnionType(t.items + [self.s])
return UnionType.make_simplified_union([self.s, t])

def visit_error_type(self, t: ErrorType) -> Type:
return t
Expand Down Expand Up @@ -235,7 +235,6 @@ def join_instances(t: Instance, s: Instance) -> Type:

Return ErrorType if the result is ambiguous.
"""

if t.type == s.type:
# Simplest case: join two types with the same base type (but
# potentially different arguments).
Expand Down Expand Up @@ -264,16 +263,29 @@ def join_instances_via_supertype(t: Instance, s: Instance) -> Type:
return join_types(t.type._promote, s)
elif s.type._promote and is_subtype(s.type._promote, t):
return join_types(t, s.type._promote)
res = s
mapped = map_instance_to_supertype(t, t.type.bases[0].type)
join = join_instances(mapped, res)
# If the join failed, fail. This is a defensive measure (this might
# never happen).
if isinstance(join, ErrorType):
return join
# Now the result must be an Instance, so the cast below cannot fail.
res = cast(Instance, join)
return res
# Compute the "best" supertype of t when joined with s.
# The definition of "best" may evolve; for now it is the one with
# the longest MRO. Ties are broken by using the earlier base.
best = None # type: Type
for base in t.type.bases:
mapped = map_instance_to_supertype(t, base.type)
res = join_instances(mapped, s)
if best is None or is_better(res, best):
best = res
assert best is not None
return best


def is_better(t: Type, s: Type) -> bool:
# Given two possible results from join_instances_via_supertype(),
# indicate whether t is the better one.
if isinstance(t, Instance):
if not isinstance(s, Instance):
return True
# Use len(mro) as a proxy for the better choice.
if len(t.type.mro) > len(s.type.mro):
return True
return False


def is_similar_callables(t: CallableType, s: CallableType) -> bool:
Expand Down
54 changes: 50 additions & 4 deletions mypy/test/data/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -599,9 +599,9 @@ g(a)
b = f(A(), B())
g(b)
c = f(A(), D())
g(c) # E: Argument 1 to "g" has incompatible type "object"; expected "I"
g(c) # E: Argument 1 to "g" has incompatible type "J"; expected "I"
d = f(D(), A())
g(d) # E: Argument 1 to "g" has incompatible type "object"; expected "I"
g(d) # E: Argument 1 to "g" has incompatible type "J"; expected "I"
e = f(D(), C())
g(e) # E: Argument 1 to "g" has incompatible type "object"; expected "I"

Expand Down Expand Up @@ -646,9 +646,9 @@ def f(a: T, b: T) -> T: pass
def g(x: K) -> None: pass

a = f(B(), C())
g(a) # E: Argument 1 to "g" has incompatible type "object"; expected "K"
g(a) # E: Argument 1 to "g" has incompatible type "J"; expected "K"
b = f(A(), C())
g(b) # E: Argument 1 to "g" has incompatible type "object"; expected "K"
g(b) # E: Argument 1 to "g" has incompatible type "J"; expected "K"
c = f(A(), B())
g(c)

Expand Down Expand Up @@ -1593,3 +1593,49 @@ tmp/m.py: note: In function "g":
tmp/m.py:2: error: "int" not callable
main: note: In function "f":
main:3: error: "int" not callable


-- Tests for special cases of unification
-- --------------------------------------

[case testUnificationRedundantUnion]
from typing import Union
a = None # type: Union[int, str]
b = None # type: Union[str, tuple]
def f(): pass
def g(x: Union[int, str]): pass
c = a if f() else b
g(c) # E: Argument 1 to "g" has incompatible type "Union[int, str, tuple]"; expected "Union[int, str]"

[case testUnificationMultipleInheritance]
class A: pass
class B:
def foo(self): pass
class C(A, B): pass
def f(): pass
a1 = B() if f() else C()
a1.foo()
a2 = C() if f() else B()
a2.foo()

[case testUnificationMultipleInheritanceAmbiguous]
# Show that join_instances_via_supertype() breakes ties using the first base class.
class A1: pass
class B1:
def foo1(self): pass
class C1(A1, B1): pass

class A2: pass
class B2:
def foo2(self): pass
class C2(A2, B2): pass

class D1(C1, C2): pass
class D2(C2, C1): pass

def f(): pass

a1 = D1() if f() else D2()
a1.foo1()
a2 = D2() if f() else D1()
a2.foo2()