Skip to content

Commit

Permalink
Make enum type compatible with union of all enum item literals (#10388)
Browse files Browse the repository at this point in the history
For example, consider this enum:

```
class E(Enum):
    A = 1
    B = 1
```

This PR makes `E` compatible with `Literal[E.A, E.B]`.

Also fix mutation of the argument list in
`try_contracting_literals_in_union`.

This fixes some regressions introduced in #9097.
  • Loading branch information
JukkaL authored Apr 30, 2021
1 parent 7189a23 commit 322ef60
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
12 changes: 12 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,18 @@ def _is_subtype(left: Type, right: Type,
ignore_declared_variance=ignore_declared_variance,
ignore_promotions=ignore_promotions)
for item in right.items)
# Recombine rhs literal types, to make an enum type a subtype
# of a union of all enum items as literal types. Only do it if
# the previous check didn't succeed, since recombining can be
# expensive.
if not is_subtype_of_item and isinstance(left, Instance) and left.type.is_enum:
right = UnionType(mypy.typeops.try_contracting_literals_in_union(right.items))
is_subtype_of_item = any(is_subtype(orig_left, item,
ignore_type_params=ignore_type_params,
ignore_pos_arg_names=ignore_pos_arg_names,
ignore_declared_variance=ignore_declared_variance,
ignore_promotions=ignore_promotions)
for item in right.items)
# However, if 'left' is a type variable T, T might also have
# an upper bound which is itself a union. This case will be
# handled below by the SubtypeVisitor. We have to check both
Expand Down
11 changes: 6 additions & 5 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ class Status(Enum):
return typ


def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperType]:
def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]:
"""Contracts any literal types back into a sum type if possible.
Will replace the first instance of the literal with the sum type and
Expand All @@ -724,9 +724,10 @@ def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperTyp
if we call `try_contracting_union(Literal[Color.RED, Color.BLUE, Color.YELLOW])`,
this function will return Color.
"""
proper_types = [get_proper_type(typ) for typ in types]
sum_types = {} # type: Dict[str, Tuple[Set[Any], List[int]]]
marked_for_deletion = set()
for idx, typ in enumerate(types):
for idx, typ in enumerate(proper_types):
if isinstance(typ, LiteralType):
fullname = typ.fallback.type.fullname
if typ.fallback.type.is_enum:
Expand All @@ -737,10 +738,10 @@ def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperTyp
indexes.append(idx)
if not literals:
first, *rest = indexes
types[first] = typ.fallback
proper_types[first] = typ.fallback
marked_for_deletion |= set(rest)
return list(itertools.compress(types, [(i not in marked_for_deletion)
for i in range(len(types))]))
return list(itertools.compress(proper_types, [(i not in marked_for_deletion)
for i in range(len(proper_types))]))


def coerce_to_literal(typ: Type) -> Type:
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -1333,3 +1333,19 @@ def f(x: Foo):
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]"

[builtins fixtures/bool.pyi]

[case testEnumTypeCompatibleWithLiteralUnion]
from enum import Enum
from typing_extensions import Literal

class E(Enum):
A = 1
B = 2
C = 3

e: E
a: Literal[E.A, E.B, E.C] = e
b: Literal[E.A, E.B] = e # E: Incompatible types in assignment (expression has type "E", variable has type "Union[Literal[E.A], Literal[E.B]]")
c: Literal[E.A, E.C] = e # E: Incompatible types in assignment (expression has type "E", variable has type "Union[Literal[E.A], Literal[E.C]]")
b = a # E: Incompatible types in assignment (expression has type "Union[Literal[E.A], Literal[E.B], Literal[E.C]]", variable has type "Union[Literal[E.A], Literal[E.B]]")
[builtins fixtures/bool.pyi]

0 comments on commit 322ef60

Please sign in to comment.