Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recombine complete union of enum literals into original type (#9063) #9097

Merged
merged 1 commit into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,9 @@ def visit_literal_type(self, t: LiteralType) -> ProperType:
if isinstance(self.s, LiteralType):
if t == self.s:
return t
else:
return join_types(self.s.fallback, t.fallback)
if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
return mypy.typeops.make_simplified_union([self.s, t])
return join_types(self.s.fallback, t.fallback)
else:
return join_types(self.s, t.fallback)

Expand Down
43 changes: 39 additions & 4 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
since these may assume that MROs are ready.
"""

from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar
from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any
from typing_extensions import Type as TypingType
import itertools
import sys

from mypy.types import (
Expand Down Expand Up @@ -315,7 +316,8 @@ def callable_corresponding_argument(typ: CallableType,

def make_simplified_union(items: Sequence[Type],
line: int = -1, column: int = -1,
*, keep_erased: bool = False) -> ProperType:
*, keep_erased: bool = False,
contract_literals: bool = True) -> ProperType:
"""Build union type with redundant union items removed.

If only a single item remains, this may return a non-union type.
Expand Down Expand Up @@ -377,6 +379,11 @@ def make_simplified_union(items: Sequence[Type],
items[i] = true_or_false(ti)

simplified_set = [items[i] for i in range(len(items)) if i not in removed]

# If more than one literal exists in the union, try to simplify
if (contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1):
simplified_set = try_contracting_literals_in_union(simplified_set)

return UnionType.make_union(simplified_set, line, column)


Expand Down Expand Up @@ -684,7 +691,7 @@ class Status(Enum):

if isinstance(typ, UnionType):
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
return make_simplified_union(items)
return make_simplified_union(items, contract_literals=False)
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname:
new_items = []
for name, symbol in typ.type.names.items():
Expand All @@ -702,11 +709,39 @@ class Status(Enum):
# only using CPython, but we might as well for the sake of full correctness.
if sys.version_info < (3, 7):
new_items.sort(key=lambda lit: lit.value)
return make_simplified_union(new_items)
return make_simplified_union(new_items, contract_literals=False)
else:
return typ


def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperType]:
"""Contracts any literal types back into a sum type if possible.

Will replace the first instance of the literal with the sum type and
remove all others.

if we call `try_contracting_union(Literal[Color.RED, Color.BLUE, Color.YELLOW])`,
this function will return Color.
"""
sum_types = {} # type: Dict[str, Tuple[Set[Any], List[int]]]
marked_for_deletion = set()
for idx, typ in enumerate(types):
if isinstance(typ, LiteralType):
fullname = typ.fallback.type.fullname
if typ.fallback.type.is_enum:
if fullname not in sum_types:
sum_types[fullname] = (set(get_enum_values(typ.fallback)), [])
literals, indexes = sum_types[fullname]
literals.discard(typ.value)
indexes.append(idx)
if not literals:
first, *rest = indexes
types[first] = typ.fallback
marked_for_deletion |= set(rest)
return list(itertools.compress(types, [(i not in marked_for_deletion)
for i in range(len(types))]))


def coerce_to_literal(typ: Type) -> Type:
"""Recursively converts any Instances that have a last_known_value or are
instances of enum types with a single value into the corresponding LiteralType.
Expand Down
34 changes: 34 additions & 0 deletions test-data/unit/check-enum.test
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ elif x is Foo.C:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable
reveal_type(x) # N: Revealed type is '__main__.Foo'

if Foo.A is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
Expand All @@ -722,6 +723,7 @@ elif Foo.C is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(x) # No output here: this branch is unreachable
reveal_type(x) # N: Revealed type is '__main__.Foo'

y: Foo
if y is Foo.A:
Expand All @@ -732,6 +734,7 @@ elif y is Foo.C:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable
reveal_type(y) # N: Revealed type is '__main__.Foo'

if Foo.A is y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
Expand All @@ -741,6 +744,7 @@ elif Foo.C is y:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
else:
reveal_type(y) # No output here: this branch is unreachable
reveal_type(y) # N: Revealed type is '__main__.Foo'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityChecksWithOrdering]
Expand Down Expand Up @@ -815,12 +819,14 @@ if x is y:
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x) # N: Revealed type is '__main__.Foo'
if y is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(x) # N: Revealed type is '__main__.Foo'

if x is z:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
Expand All @@ -830,6 +836,7 @@ else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
accepts_foo_a(z)
reveal_type(x) # N: Revealed type is '__main__.Foo'
if z is x:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
Expand All @@ -838,6 +845,7 @@ else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
accepts_foo_a(z)
reveal_type(x) # N: Revealed type is '__main__.Foo'

if y is z:
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
Expand Down Expand Up @@ -909,6 +917,7 @@ if x is Foo.A:
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'
[builtins fixtures/bool.pyi]

[case testEnumReachabilityWithMultipleEnums]
Expand All @@ -928,18 +937,21 @@ if x1 is Foo.A:
reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]'
else:
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
reveal_type(x1) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'

x2: Union[Foo, Bar]
if x2 is Bar.A:
reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]'
else:
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'

x3: Union[Foo, Bar]
if x3 is Foo.A or x3 is Bar.A:
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
else:
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'

[builtins fixtures/bool.pyi]

Expand Down Expand Up @@ -1299,3 +1311,25 @@ reveal_type(a._value_) # N: Revealed type is 'Any'
[builtins fixtures/__new__.pyi]
[builtins fixtures/primitives.pyi]
[typing fixtures/typing-medium.pyi]

[case testEnumNarrowedToTwoLiterals]
# Regression test: two literals of an enum would be joined
# as the full type, regardless of the amount of elements
# the enum contains.
from enum import Enum
from typing import Union
from typing_extensions import Literal

class Foo(Enum):
A = 1
B = 2
C = 3

def f(x: Foo):
if x is Foo.A:
return x
if x is Foo.B:
pass
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'

[builtins fixtures/bool.pyi]