From 184e415502add001425ef93fb74464618973879f Mon Sep 17 00:00:00 2001 From: hauntsaninja Date: Tue, 25 Apr 2023 02:09:07 -0600 Subject: [PATCH] Speed up make_simplified_union, remove a potential crash The following code optimises make_simplified_union in the common case that there are exact duplicates in the union. In this regard, this is similar to #15104 To get this to work, I needed to use partial tuple fallbacks in a couple places (these maybe had the potential to be latent crashes anyway?) There were some interesting things going on with recursive type aliases and type state assumptions This is about a 25% speedup on the pydantic codebase and about a 2% speedup on self check (measured with uncompiled mypy) --- mypy/subtypes.py | 7 ++- mypy/test/testtypes.py | 2 +- mypy/typeops.py | 121 +++++++++++++---------------------------- 3 files changed, 45 insertions(+), 85 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 59919456ab5c..88e4c6929aaf 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -439,7 +439,7 @@ def visit_instance(self, left: Instance) -> bool: # dynamic base classes correctly, see #5456. return not isinstance(self.right, NoneType) right = self.right - if isinstance(right, TupleType) and mypy.typeops.tuple_fallback(right).type.is_enum: + if isinstance(right, TupleType) and right.partial_fallback.type.is_enum: return self._is_subtype(left, mypy.typeops.tuple_fallback(right)) if isinstance(right, Instance): if type_state.is_cached_subtype_check(self._subtype_kind, left, right): @@ -753,7 +753,10 @@ def visit_tuple_type(self, left: TupleType) -> bool: # for isinstance(x, tuple), though it's unclear why. return True return all(self._is_subtype(li, iter_type) for li in left.items) - elif self._is_subtype(mypy.typeops.tuple_fallback(left), right): + elif ( + self._is_subtype(left.partial_fallback, right) + and self._is_subtype(mypy.typeops.tuple_fallback(left), right) + ): return True return False elif isinstance(right, TupleType): diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 601cdf27466e..3ac91e078b1c 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -613,7 +613,7 @@ def test_simplified_union_with_mixed_str_literals(self) -> None: ) self.assert_simplified_union( [fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], - UnionType([fx.lit_str1, fx.lit_str1_inst]), + fx.lit_str1, ) def assert_simplified_union(self, original: list[Type], union: Type) -> None: diff --git a/mypy/typeops.py b/mypy/typeops.py index 8ed59b6fbe55..fb00d4363a54 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -385,25 +385,6 @@ def callable_corresponding_argument( return by_name if by_name is not None else by_pos -def simple_literal_value_key(t: ProperType) -> tuple[str, ...] | None: - """Return a hashable description of simple literal type. - - Return None if not a simple literal type. - - The return value can be used to simplify away duplicate types in - unions by comparing keys for equality. For now enum, string or - Instance with string last_known_value are supported. - """ - if isinstance(t, LiteralType): - if t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str": - assert isinstance(t.value, str) - return "literal", t.value, t.fallback.type.fullname - if isinstance(t, Instance): - if t.last_known_value is not None and isinstance(t.last_known_value.value, str): - return "instance", t.last_known_value.value, t.type.fullname - return None - - def simple_literal_type(t: ProperType | None) -> Instance | None: """Extract the underlying fallback Instance type for a simple Literal""" if isinstance(t, Instance) and t.last_known_value is not None: @@ -414,7 +395,6 @@ def simple_literal_type(t: ProperType | None) -> Instance | None: def is_simple_literal(t: ProperType) -> bool: - """Fast way to check if simple_literal_value_key() would return a non-None value.""" if isinstance(t, LiteralType): return t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str" if isinstance(t, Instance): @@ -500,68 +480,45 @@ def make_simplified_union( def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]: from mypy.subtypes import is_proper_subtype - removed: set[int] = set() - seen: set[tuple[str, ...]] = set() - - # NB: having a separate fast path for Union of Literal and slow path for other things - # would arguably be cleaner, however it breaks down when simplifying the Union of two - # different enum types as try_expanding_sum_type_to_union works recursively and will - # trigger intermediate simplifications that would render the fast path useless - for i, item in enumerate(items): - proper_item = get_proper_type(item) - if i in removed: - continue - # Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169) - k = simple_literal_value_key(proper_item) - if k is not None: - if k in seen: - removed.add(i) - continue - - # NB: one would naively expect that it would be safe to skip the slow path - # always for literals. One would be sorely mistaken. Indeed, some simplifications - # such as that of None/Optional when strict optional is false, do require that we - # proceed with the slow path. Thankfully, all literals will have the same subtype - # relationship to non-literal types, so we only need to do that walk for the first - # literal, which keeps the fast path fast even in the presence of a mixture of - # literals and other types. - safe_skip = len(seen) > 0 - seen.add(k) - if safe_skip: - continue - - # Keep track of the truthiness info for deleted subtypes which can be relevant - cbt = cbf = False - for j, tj in enumerate(items): - proper_tj = get_proper_type(tj) - if ( - i == j - # avoid further checks if this item was already marked redundant. - or j in removed - # if the current item is a simple literal then this simplification loop can - # safely skip all other simple literals as two literals will only ever be - # subtypes of each other if they are equal, which is already handled above. - # However, if the current item is not a literal, it might plausibly be a - # supertype of other literals in the union, so we must check them again. - # This is an important optimization as is_proper_subtype is pretty expensive. - or (k is not None and is_simple_literal(proper_tj)) - ): - continue - # actual redundancy checks (XXX?) - if is_redundant_literal_instance(proper_item, proper_tj) and is_proper_subtype( - tj, item, keep_erased_types=keep_erased, ignore_promotions=True - ): - # We found a redundant item in the union. - removed.add(j) - cbt = cbt or tj.can_be_true - cbf = cbf or tj.can_be_false - # if deleted subtypes had more general truthiness, use that - if not item.can_be_true and cbt: - items[i] = true_or_false(item) - elif not item.can_be_false and cbf: - items[i] = true_or_false(item) - - return [items[i] for i in range(len(items)) if i not in removed] + # The first pass through this loop, we check if later items are subtypes of earlier items. + # The second pass through this loop, we check if earlier items are subtypes of later items + # (by reversing the remaining items) + for _direction in range(2): + new_items: list[Type] = [] + # seen is a map from a type to its index in new_items + seen: dict[ProperType, int] = {} + for ti in items: + proper_ti = get_proper_type(ti) + + duplicate_index = -1 + # Quickly check if we've seen this type + if proper_ti in seen: + duplicate_index = seen[proper_ti] + else: + # If not, check if we've seen a supertype of this type + for j, tj in enumerate(new_items): + tj = get_proper_type(tj) + if is_redundant_literal_instance(tj, proper_ti) and is_proper_subtype( + proper_ti, tj, keep_erased_types=keep_erased, ignore_promotions=True + ): + duplicate_index = j + break + if duplicate_index != -1: + # If deleted subtypes had more general truthiness, use that + orig_item = new_items[duplicate_index] + if not orig_item.can_be_true and ti.can_be_true: + new_items[duplicate_index] = true_or_false(orig_item) + elif not orig_item.can_be_false and ti.can_be_false: + new_items[duplicate_index] = true_or_false(orig_item) + else: + # We have a non-duplicate item, add it to new_items + seen[proper_ti] = len(new_items) + new_items.append(ti) + + items = new_items + items.reverse() + + return items def _get_type_special_method_bool_ret_type(t: Type) -> Type | None: