diff --git a/mypy/join.py b/mypy/join.py index 4cd0da163e13..956c0a932e50 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -310,8 +310,9 @@ def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType): if t == self.s: return t - else: - return join_types(self.s.fallback, t.fallback) + if self.s.fallback.type.is_enum and t.fallback.type.is_enum: + return mypy.typeops.make_simplified_union([self.s, t]) + return join_types(self.s.fallback, t.fallback) else: return join_types(self.s, t.fallback) diff --git a/mypy/typeops.py b/mypy/typeops.py index 0f5c44b748fa..3b5ca73f8713 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,8 +5,9 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar +from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any from typing_extensions import Type as TypingType +import itertools import sys from mypy.types import ( @@ -315,7 +316,8 @@ def callable_corresponding_argument(typ: CallableType, def make_simplified_union(items: Sequence[Type], line: int = -1, column: int = -1, - *, keep_erased: bool = False) -> ProperType: + *, keep_erased: bool = False, + contract_literals: bool = True) -> ProperType: """Build union type with redundant union items removed. If only a single item remains, this may return a non-union type. @@ -377,6 +379,11 @@ def make_simplified_union(items: Sequence[Type], items[i] = true_or_false(ti) simplified_set = [items[i] for i in range(len(items)) if i not in removed] + + # If more than one literal exists in the union, try to simplify + if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1): + simplified_set = try_contracting_literals_in_union(simplified_set) + return UnionType.make_union(simplified_set, line, column) @@ -684,7 +691,7 @@ class Status(Enum): if isinstance(typ, UnionType): items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items] - return make_simplified_union(items) + return make_simplified_union(items, contract_literals=False) elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname: new_items = [] for name, symbol in typ.type.names.items(): @@ -702,11 +709,39 @@ class Status(Enum): # only using CPython, but we might as well for the sake of full correctness. if sys.version_info < (3, 7): new_items.sort(key=lambda lit: lit.value) - return make_simplified_union(new_items) + return make_simplified_union(new_items, contract_literals=False) else: return typ +def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperType]: + """Contracts any literal types back into a sum type if possible. + + Will replace the first instance of the literal with the sum type and + remove all others. + + if we call `try_contracting_union(Literal[Color.RED, Color.BLUE, Color.YELLOW])`, + this function will return Color. + """ + sum_types = {} # type: Dict[str, Tuple[Set[Any], List[int]]] + marked_for_deletion = set() + for idx, typ in enumerate(types): + if isinstance(typ, LiteralType): + fullname = typ.fallback.type.fullname + if typ.fallback.type.is_enum: + if fullname not in sum_types: + sum_types[fullname] = (set(get_enum_values(typ.fallback)), []) + literals, indexes = sum_types[fullname] + literals.discard(typ.value) + indexes.append(idx) + if not literals: + first, *rest = indexes + types[first] = typ.fallback + marked_for_deletion |= set(rest) + return list(itertools.compress(types, [(i not in marked_for_deletion) + for i in range(len(types))])) + + def coerce_to_literal(typ: Type) -> Type: """Recursively converts any Instances that have a last_known_value or are instances of enum types with a single value into the corresponding LiteralType. diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 27742a782d90..c9635f6ecc86 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -713,6 +713,7 @@ elif x is Foo.C: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(x) # No output here: this branch is unreachable +reveal_type(x) # N: Revealed type is '__main__.Foo' if Foo.A is x: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -722,6 +723,7 @@ elif Foo.C is x: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(x) # No output here: this branch is unreachable +reveal_type(x) # N: Revealed type is '__main__.Foo' y: Foo if y is Foo.A: @@ -732,6 +734,7 @@ elif y is Foo.C: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(y) # No output here: this branch is unreachable +reveal_type(y) # N: Revealed type is '__main__.Foo' if Foo.A is y: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -741,6 +744,7 @@ elif Foo.C is y: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(y) # No output here: this branch is unreachable +reveal_type(y) # N: Revealed type is '__main__.Foo' [builtins fixtures/bool.pyi] [case testEnumReachabilityChecksWithOrdering] @@ -815,12 +819,14 @@ if x is y: else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +reveal_type(x) # N: Revealed type is '__main__.Foo' if y is x: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +reveal_type(x) # N: Revealed type is '__main__.Foo' if x is z: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -830,6 +836,7 @@ else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?' accepts_foo_a(z) +reveal_type(x) # N: Revealed type is '__main__.Foo' if z is x: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?' @@ -838,6 +845,7 @@ else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?' accepts_foo_a(z) +reveal_type(x) # N: Revealed type is '__main__.Foo' if y is z: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -909,6 +917,7 @@ if x is Foo.A: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]' +reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' [builtins fixtures/bool.pyi] [case testEnumReachabilityWithMultipleEnums] @@ -928,18 +937,21 @@ if x1 is Foo.A: reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]' else: reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]' +reveal_type(x1) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' x2: Union[Foo, Bar] if x2 is Bar.A: reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]' else: reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]' +reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' x3: Union[Foo, Bar] if x3 is Foo.A or x3 is Bar.A: reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]' else: reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]' +reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' [builtins fixtures/bool.pyi] @@ -1299,3 +1311,25 @@ reveal_type(a._value_) # N: Revealed type is 'Any' [builtins fixtures/__new__.pyi] [builtins fixtures/primitives.pyi] [typing fixtures/typing-medium.pyi] + +[case testEnumNarrowedToTwoLiterals] +# Regression test: two literals of an enum would be joined +# as the full type, regardless of the amount of elements +# the enum contains. +from enum import Enum +from typing import Union +from typing_extensions import Literal + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +def f(x: Foo): + if x is Foo.A: + return x + if x is Foo.B: + pass + reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' + +[builtins fixtures/bool.pyi]