Skip to content

make_simplified_union: add caching and reduce allocations #12659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 101 additions & 29 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
since these may assume that MROs are ready.
"""

from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union
from typing import (
cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union, Callable
)
from typing_extensions import Type as TypingType
import itertools
import sys
Expand Down Expand Up @@ -336,6 +338,47 @@ def is_simple_literal(t: ProperType) -> bool:
return False


def _get_flattened_proper_types(items: Sequence[Type]) -> Sequence[ProperType]:
"""Similar to types.get_proper_types, with flattening of UnionType

Optimized to avoid allocating a new list whenever possible"""
i: int = 0
base: int = 0
n: int = len(items)

# optimistic fast path
while i < n:
t = items[i]
pt = get_proper_type(t)
if id(t) != id(pt) or isinstance(pt, UnionType):
# we need to allocate, switch to slow path
break
# simplify away any number of bottom type at the start of the input
if i == base and i+1 < n and isinstance(pt, UninhabitedType):
base += 1
i += 1

# optimistic fast path reached end of input, no need to allocate
if i == n:
return cast(Sequence[ProperType], items[base:] if base > 0 else items)

all_items = list(cast(Sequence[ProperType], items[base:i]))

while i < n:
pt = get_proper_type(items[i])
if isinstance(pt, UnionType):
all_items.extend(_get_flattened_proper_types(pt.items))
else:
all_items.append(pt)
i += 1
return all_items


_simplified_union_cache: List[Dict[Tuple[ProperType, ...], ProperType]] = [
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this live in TypeState soit can be reset along with other caches?

{} for _ in range(2**3)
]


def make_simplified_union(items: Sequence[Type],
line: int = -1, column: int = -1,
*, keep_erased: bool = False,
Expand All @@ -362,17 +405,35 @@ def make_simplified_union(items: Sequence[Type],
back into a sum type. Set it to False when called by try_expanding_sum_type_
to_union().
"""
items = get_proper_types(items)

# Step 1: expand all nested unions
while any(isinstance(typ, UnionType) for typ in items):
all_items: List[ProperType] = []
for typ in items:
if isinstance(typ, UnionType):
all_items.extend(get_proper_types(typ.items))
else:
all_items.append(typ)
items = all_items
items = _get_flattened_proper_types(items)

cache_fn: Optional[Callable[[ProperType], None]] = None

# 1 or 2 elements account for the vast majority of inputs and are not worth caching:
# - they're two small for the quadratic worst-case cost of simplification to really
# manifest
# - they majority of those inputs are only triggered once
# - avoiding the extra allocations is a bigger win
if len(items) == 1:
return items[0]
elif len(items) > 2:
# NB: ideally we would use a frozenset, but that would require normalizing the
# order of entries in the simplified union, or updating the test harness to
# treat Unions as equivalent regardless of item ordering (which is particularly
# tricky when it comes to all tests using string matching on reveal_type output)
cache_key = tuple(items)
# NB: we need to maintain separate caches depending on flags that might impact
# the results of simplification
cache = _simplified_union_cache[
int(keep_erased)
| int(contract_literals) << 1
| int(state.strict_optional) << 2
]
ret = cache.get(cache_key, None)
if ret is not None:
return ret
cache_fn = lambda v: cache.__setitem__(cache_key, v) # noqa: E731

# Step 2: remove redundant unions
simplified_set = _remove_redundant_union_items(items, keep_erased)
Expand All @@ -381,13 +442,20 @@ def make_simplified_union(items: Sequence[Type],
if contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1:
simplified_set = try_contracting_literals_in_union(simplified_set)

return UnionType.make_union(simplified_set, line, column)
ret = UnionType.make_union(simplified_set, line, column)

if cache_fn:
cache_fn(ret)

return ret


def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]:
def _remove_redundant_union_items(items: Sequence[ProperType],
keep_erased: bool) -> Sequence[ProperType]:
from mypy.subtypes import is_proper_subtype

removed: Set[int] = set()
truthed: Set[int] = set()
seen: Set[Tuple[str, ...]] = set()

# NB: having a separate fast path for Union of Literal and slow path for other things
Expand All @@ -397,6 +465,7 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) ->
for i, item in enumerate(items):
if i in removed:
continue

# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
k = simple_literal_value_key(item)
if k is not None:
Expand Down Expand Up @@ -434,20 +503,34 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) ->
continue
# actual redundancy checks
if (
is_redundant_literal_instance(item, tj) # XXX?
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
isinstance(tj, UninhabitedType)
or (
(
not isinstance(item, Instance)
or item.last_known_value is None
or (
isinstance(tj, Instance)
and tj.last_known_value == item.last_known_value
)
)
and is_proper_subtype(tj, item, keep_erased_types=keep_erased)
)
):
# We found a redundant item in the union.
removed.add(j)
cbt = cbt or tj.can_be_true
cbf = cbf or tj.can_be_false

# if deleted subtypes had more general truthiness, use that
if not item.can_be_true and cbt:
items[i] = true_or_false(item)
truthed.add(i)
elif not item.can_be_false and cbf:
items[i] = true_or_false(item)
truthed.add(i)

return [items[i] for i in range(len(items)) if i not in removed]
if not removed and not truthed:
return items
return [true_or_false(items[i]) if i in truthed else items[i]
for i in range(len(items)) if i not in removed]


def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]:
Expand Down Expand Up @@ -889,17 +972,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
return False


def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool:
if not isinstance(general, Instance) or general.last_known_value is None:
return True
if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value:
return True
if isinstance(specific, UninhabitedType):
return True

return False


def separate_union_literals(t: UnionType) -> Tuple[Sequence[LiteralType], Sequence[Type]]:
"""Separate literals from other members in a union type."""
literal_items = []
Expand Down