From 322ef60ba7231bbd4f735b273e88e67afb57ad17 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Fri, 30 Apr 2021 13:54:53 +0100 Subject: [PATCH] Make enum type compatible with union of all enum item literals (#10388) For example, consider this enum: ``` class E(Enum): A = 1 B = 1 ``` This PR makes `E` compatible with `Literal[E.A, E.B]`. Also fix mutation of the argument list in `try_contracting_literals_in_union`. This fixes some regressions introduced in #9097. --- mypy/subtypes.py | 12 ++++++++++++ mypy/typeops.py | 11 ++++++----- test-data/unit/check-enum.test | 16 ++++++++++++++++ 3 files changed, 34 insertions(+), 5 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 9bd4bfa0f6e4..c3b8b82a3c2c 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -121,6 +121,18 @@ def _is_subtype(left: Type, right: Type, ignore_declared_variance=ignore_declared_variance, ignore_promotions=ignore_promotions) for item in right.items) + # Recombine rhs literal types, to make an enum type a subtype + # of a union of all enum items as literal types. Only do it if + # the previous check didn't succeed, since recombining can be + # expensive. + if not is_subtype_of_item and isinstance(left, Instance) and left.type.is_enum: + right = UnionType(mypy.typeops.try_contracting_literals_in_union(right.items)) + is_subtype_of_item = any(is_subtype(orig_left, item, + ignore_type_params=ignore_type_params, + ignore_pos_arg_names=ignore_pos_arg_names, + ignore_declared_variance=ignore_declared_variance, + ignore_promotions=ignore_promotions) + for item in right.items) # However, if 'left' is a type variable T, T might also have # an upper bound which is itself a union. This case will be # handled below by the SubtypeVisitor. We have to check both diff --git a/mypy/typeops.py b/mypy/typeops.py index 1760e9c00503..56a6002d1e40 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -715,7 +715,7 @@ class Status(Enum): return typ -def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperType]: +def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]: """Contracts any literal types back into a sum type if possible. Will replace the first instance of the literal with the sum type and @@ -724,9 +724,10 @@ def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperTyp if we call `try_contracting_union(Literal[Color.RED, Color.BLUE, Color.YELLOW])`, this function will return Color. """ + proper_types = [get_proper_type(typ) for typ in types] sum_types = {} # type: Dict[str, Tuple[Set[Any], List[int]]] marked_for_deletion = set() - for idx, typ in enumerate(types): + for idx, typ in enumerate(proper_types): if isinstance(typ, LiteralType): fullname = typ.fallback.type.fullname if typ.fallback.type.is_enum: @@ -737,10 +738,10 @@ def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperTyp indexes.append(idx) if not literals: first, *rest = indexes - types[first] = typ.fallback + proper_types[first] = typ.fallback marked_for_deletion |= set(rest) - return list(itertools.compress(types, [(i not in marked_for_deletion) - for i in range(len(types))])) + return list(itertools.compress(proper_types, [(i not in marked_for_deletion) + for i in range(len(proper_types))])) def coerce_to_literal(typ: Type) -> Type: diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index c306d058278d..06f79776f77f 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -1333,3 +1333,19 @@ def f(x: Foo): reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]" [builtins fixtures/bool.pyi] + +[case testEnumTypeCompatibleWithLiteralUnion] +from enum import Enum +from typing_extensions import Literal + +class E(Enum): + A = 1 + B = 2 + C = 3 + +e: E +a: Literal[E.A, E.B, E.C] = e +b: Literal[E.A, E.B] = e # E: Incompatible types in assignment (expression has type "E", variable has type "Union[Literal[E.A], Literal[E.B]]") +c: Literal[E.A, E.C] = e # E: Incompatible types in assignment (expression has type "E", variable has type "Union[Literal[E.A], Literal[E.C]]") +b = a # E: Incompatible types in assignment (expression has type "Union[Literal[E.A], Literal[E.B], Literal[E.C]]", variable has type "Union[Literal[E.A], Literal[E.B]]") +[builtins fixtures/bool.pyi]