From e5a41c6eb0987efd496bed657f1cdbea4929ecde Mon Sep 17 00:00:00 2001 From: Hugues Bruant Date: Fri, 22 Apr 2022 14:00:10 -0700 Subject: [PATCH] make_simplified_union: add caching and reduce allocations make_simplified_union is used in a lot of places and therefore accounts for a significant share to typechecking time. Based on sample metrics gathered from a large real-world codebase we can see that: 1. the majority of inputs are already as simple as they're going to get, which means we can avoid allocation extra lists and return the input unchanged 2. most of the cost of `make_simplified_union` comes from `is_proper_subtype` 3. `is_proper_subtype` has some caching going on under the hood but it only applies to `Instance`, and cache hit rate is low in this particular case because, as per 1) above, items are in fact rarely subtypes of each other To address 1, refactor `make_simplified_union` with an optimistic fast path that avoid unnecessary allocations. To address 2 & 3, introduce a cache to record the result of union simplification. These changes are observed to yield significant improvements in a real-world codebase: a roughly 10-20% overall speedup, with make_simplified_union/is_proper_subtype no longer showing up as hotspots in the py-spy profile. For #12526 --- mypy/typeops.py | 130 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 101 insertions(+), 29 deletions(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index dbfeebe42f14..653a2e4df531 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,7 +5,9 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union +from typing import ( + cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any, Union, Callable +) from typing_extensions import Type as TypingType import itertools import sys @@ -336,6 +338,47 @@ def is_simple_literal(t: ProperType) -> bool: return False +def _get_flattened_proper_types(items: Sequence[Type]) -> Sequence[ProperType]: + """Similar to types.get_proper_types, with flattening of UnionType + + Optimized to avoid allocating a new list whenever possible""" + i: int = 0 + base: int = 0 + n: int = len(items) + + # optimistic fast path + while i < n: + t = items[i] + pt = get_proper_type(t) + if id(t) != id(pt) or isinstance(pt, UnionType): + # we need to allocate, switch to slow path + break + # simplify away any number of bottom type at the start of the input + if i == base and i+1 < n and isinstance(pt, UninhabitedType): + base += 1 + i += 1 + + # optimistic fast path reached end of input, no need to allocate + if i == n: + return cast(Sequence[ProperType], items[base:] if base > 0 else items) + + all_items = list(cast(Sequence[ProperType], items[base:i])) + + while i < n: + pt = get_proper_type(items[i]) + if isinstance(pt, UnionType): + all_items.extend(_get_flattened_proper_types(pt.items)) + else: + all_items.append(pt) + i += 1 + return all_items + + +_simplified_union_cache: List[Dict[Tuple[ProperType, ...], ProperType]] = [ + {} for _ in range(2**3) +] + + def make_simplified_union(items: Sequence[Type], line: int = -1, column: int = -1, *, keep_erased: bool = False, @@ -362,17 +405,35 @@ def make_simplified_union(items: Sequence[Type], back into a sum type. Set it to False when called by try_expanding_sum_type_ to_union(). """ - items = get_proper_types(items) - # Step 1: expand all nested unions - while any(isinstance(typ, UnionType) for typ in items): - all_items: List[ProperType] = [] - for typ in items: - if isinstance(typ, UnionType): - all_items.extend(get_proper_types(typ.items)) - else: - all_items.append(typ) - items = all_items + items = _get_flattened_proper_types(items) + + cache_fn: Optional[Callable[[ProperType], None]] = None + + # 1 or 2 elements account for the vast majority of inputs and are not worth caching: + # - they're two small for the quadratic worst-case cost of simplification to really + # manifest + # - they majority of those inputs are only triggered once + # - avoiding the extra allocations is a bigger win + if len(items) == 1: + return items[0] + elif len(items) > 2: + # NB: ideally we would use a frozenset, but that would require normalizing the + # order of entries in the simplified union, or updating the test harness to + # treat Unions as equivalent regardless of item ordering (which is particularly + # tricky when it comes to all tests using string matching on reveal_type output) + cache_key = tuple(items) + # NB: we need to maintain separate caches depending on flags that might impact + # the results of simplification + cache = _simplified_union_cache[ + int(keep_erased) + | int(contract_literals) << 1 + | int(state.strict_optional) << 2 + ] + ret = cache.get(cache_key, None) + if ret is not None: + return ret + cache_fn = lambda v: cache.__setitem__(cache_key, v) # noqa: E731 # Step 2: remove redundant unions simplified_set = _remove_redundant_union_items(items, keep_erased) @@ -381,13 +442,20 @@ def make_simplified_union(items: Sequence[Type], if contract_literals and sum(isinstance(item, LiteralType) for item in simplified_set) > 1: simplified_set = try_contracting_literals_in_union(simplified_set) - return UnionType.make_union(simplified_set, line, column) + ret = UnionType.make_union(simplified_set, line, column) + + if cache_fn: + cache_fn(ret) + + return ret -def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> List[ProperType]: +def _remove_redundant_union_items(items: Sequence[ProperType], + keep_erased: bool) -> Sequence[ProperType]: from mypy.subtypes import is_proper_subtype removed: Set[int] = set() + truthed: Set[int] = set() seen: Set[Tuple[str, ...]] = set() # NB: having a separate fast path for Union of Literal and slow path for other things @@ -397,6 +465,7 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> for i, item in enumerate(items): if i in removed: continue + # Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169) k = simple_literal_value_key(item) if k is not None: @@ -434,20 +503,34 @@ def _remove_redundant_union_items(items: List[ProperType], keep_erased: bool) -> continue # actual redundancy checks if ( - is_redundant_literal_instance(item, tj) # XXX? - and is_proper_subtype(tj, item, keep_erased_types=keep_erased) + isinstance(tj, UninhabitedType) + or ( + ( + not isinstance(item, Instance) + or item.last_known_value is None + or ( + isinstance(tj, Instance) + and tj.last_known_value == item.last_known_value + ) + ) + and is_proper_subtype(tj, item, keep_erased_types=keep_erased) + ) ): # We found a redundant item in the union. removed.add(j) cbt = cbt or tj.can_be_true cbf = cbf or tj.can_be_false + # if deleted subtypes had more general truthiness, use that if not item.can_be_true and cbt: - items[i] = true_or_false(item) + truthed.add(i) elif not item.can_be_false and cbf: - items[i] = true_or_false(item) + truthed.add(i) - return [items[i] for i in range(len(items)) if i not in removed] + if not removed and not truthed: + return items + return [true_or_false(items[i]) if i in truthed else items[i] + for i in range(len(items)) if i not in removed] def _get_type_special_method_bool_ret_type(t: Type) -> Optional[Type]: @@ -889,17 +972,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool return False -def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool: - if not isinstance(general, Instance) or general.last_known_value is None: - return True - if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value: - return True - if isinstance(specific, UninhabitedType): - return True - - return False - - def separate_union_literals(t: UnionType) -> Tuple[Sequence[LiteralType], Sequence[Type]]: """Separate literals from other members in a union type.""" literal_items = []