diff --git a/mypy/join.py b/mypy/join.py index 865dd073d081..867ee636997b 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -54,7 +54,8 @@ class InstanceJoiner: - def __init__(self) -> None: + def __init__(self, prefer_union_over_supertype: bool = False) -> None: + self.prefer_union_over_supertype: bool = prefer_union_over_supertype self.seen_instances: list[tuple[Instance, Instance]] = [] def join_instances(self, t: Instance, s: Instance) -> ProperType: @@ -164,6 +165,9 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: if is_subtype(p, t): return join_types(t, p, self) + if self.prefer_union_over_supertype: + return mypy.typeops.make_simplified_union([t, s]) + # Compute the "best" supertype of t when joined with s. # The definition of "best" may evolve; for now it is the one with # the longest MRO. Ties are broken by using the earlier base. @@ -224,7 +228,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType: if isinstance(s, UnionType) and not isinstance(t, UnionType): s, t = t, s - value = t.accept(TypeJoinVisitor(s)) + value = t.accept( + TypeJoinVisitor(s, instance_joiner=InstanceJoiner(prefer_union_over_supertype=True)) + ) if declaration is None or is_subtype(value, declaration): return value @@ -601,12 +607,17 @@ def visit_tuple_type(self, t: TupleType) -> ProperType: # * Joining with any Sequence also returns a Sequence: # Tuple[int, bool] + List[bool] becomes Sequence[int] if isinstance(self.s, TupleType): + if self.instance_joiner is None: self.instance_joiner = InstanceJoiner() + prefer_union = self.instance_joiner.prefer_union_over_supertype + self.instance_joiner.prefer_union_over_supertype = False fallback = self.instance_joiner.join_instances( mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t) ) assert isinstance(fallback, Instance) + self.instance_joiner.prefer_union_over_supertype = prefer_union + items = self.join_tuples(self.s, t) if items is not None: return TupleType(items, fallback) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 285d56ff7e50..6a8aa5282a49 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2352,3 +2352,55 @@ def fn_while(arg: T) -> None: return None return None [builtins fixtures/primitives.pyi] + +[case testNarrowingInstancesCreatesUnion] + +class A: ... +class B(A): y: int +class C(A): y: int +class D(C): ... +class E(C): ... +class F(C): ... + +def f(x: A): + if isinstance(x, B): ... + elif isinstance(x, D): ... + elif isinstance(x, E): ... + elif isinstance(x, F): ... + else: return + reveal_type(x) # N: Revealed type is "Union[__main__.B, __main__.D, __main__.E, __main__.F]" + reveal_type(x.y) # N: Revealed type is "builtins.int" + +[builtins fixtures/isinstance.pyi] + +[case testNarrowingDoNotNarrowNamedTupleFallbacksToUnions] + +from typing import List, NamedTuple, Union + +class A(NamedTuple): + x: int +class B(NamedTuple): + x: int + y: int +class C(NamedTuple): + x: int + y: int + z: int + +def f() -> bool: ... + +def g() -> None: + l: List[Union[A, B, C]] + if f(): + assert isinstance(l[0], A) + reveal_type(l[0]) # N: Revealed type is "Tuple[builtins.int, fallback=__main__.A]" + elif f(): + assert isinstance(l[0], B) + reveal_type(l[0]) # N: Revealed type is "Tuple[builtins.int, builtins.int, fallback=__main__.B]" + else: + assert False + reveal_type(l[0]) # N: Revealed type is "Union[Tuple[builtins.int, fallback=__main__.A], \ + Tuple[builtins.int, builtins.int, fallback=__main__.B], \ + Tuple[builtins.int, builtins.int, builtins.int, fallback=__main__.C]]" + +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-redefine.test b/test-data/unit/check-redefine.test index b7642d30efc8..e162bb73a206 100644 --- a/test-data/unit/check-redefine.test +++ b/test-data/unit/check-redefine.test @@ -321,7 +321,7 @@ def f() -> None: x = 1 if int(): x = '' - reveal_type(x) # N: Revealed type is "builtins.object" + reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.int]" x = '' reveal_type(x) # N: Revealed type is "builtins.str" if int():