Skip to content

Commit

Permalink
Speed up make_simplified_union, remove a potential crash
Browse files Browse the repository at this point in the history
This change optimises make_simplified_union in the common case
that there are exact duplicates in the union. In this regard, it is
similar to python#15104.

To get this to work, I needed to use partial tuple fallbacks in a couple
of places (these may have been latent crashes anyway).
There were some interesting things going on with recursive type aliases
and type state assumptions.

This is about a 25% speedup on the pydantic codebase and about a 2%
speedup on self-check (measured with uncompiled mypy).
  • Loading branch information
hauntsaninja committed Apr 25, 2023
1 parent ba35026 commit 184e415
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 85 deletions.
7 changes: 5 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def visit_instance(self, left: Instance) -> bool:
# dynamic base classes correctly, see #5456.
return not isinstance(self.right, NoneType)
right = self.right
if isinstance(right, TupleType) and mypy.typeops.tuple_fallback(right).type.is_enum:
if isinstance(right, TupleType) and right.partial_fallback.type.is_enum:
return self._is_subtype(left, mypy.typeops.tuple_fallback(right))
if isinstance(right, Instance):
if type_state.is_cached_subtype_check(self._subtype_kind, left, right):
Expand Down Expand Up @@ -753,7 +753,10 @@ def visit_tuple_type(self, left: TupleType) -> bool:
# for isinstance(x, tuple), though it's unclear why.
return True
return all(self._is_subtype(li, iter_type) for li in left.items)
elif self._is_subtype(mypy.typeops.tuple_fallback(left), right):
elif (
self._is_subtype(left.partial_fallback, right)
and self._is_subtype(mypy.typeops.tuple_fallback(left), right)
):
return True
return False
elif isinstance(right, TupleType):
Expand Down
2 changes: 1 addition & 1 deletion mypy/test/testtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def test_simplified_union_with_mixed_str_literals(self) -> None:
)
self.assert_simplified_union(
[fx.lit_str1, fx.lit_str1, fx.lit_str1_inst],
UnionType([fx.lit_str1, fx.lit_str1_inst]),
fx.lit_str1,
)

def assert_simplified_union(self, original: list[Type], union: Type) -> None:
Expand Down
121 changes: 39 additions & 82 deletions mypy/typeops.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,25 +385,6 @@ def callable_corresponding_argument(
return by_name if by_name is not None else by_pos


def simple_literal_value_key(t: ProperType) -> tuple[str, ...] | None:
"""Return a hashable description of simple literal type.
Return None if not a simple literal type.
The return value can be used to simplify away duplicate types in
unions by comparing keys for equality. For now enum, string or
Instance with string last_known_value are supported.
"""
if isinstance(t, LiteralType):
if t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str":
assert isinstance(t.value, str)
return "literal", t.value, t.fallback.type.fullname
if isinstance(t, Instance):
if t.last_known_value is not None and isinstance(t.last_known_value.value, str):
return "instance", t.last_known_value.value, t.type.fullname
return None


def simple_literal_type(t: ProperType | None) -> Instance | None:
"""Extract the underlying fallback Instance type for a simple Literal"""
if isinstance(t, Instance) and t.last_known_value is not None:
Expand All @@ -414,7 +395,6 @@ def simple_literal_type(t: ProperType | None) -> Instance | None:


def is_simple_literal(t: ProperType) -> bool:
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
if isinstance(t, LiteralType):
return t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str"
if isinstance(t, Instance):
Expand Down Expand Up @@ -500,68 +480,45 @@ def make_simplified_union(
def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]:
from mypy.subtypes import is_proper_subtype

removed: set[int] = set()
seen: set[tuple[str, ...]] = set()

# NB: having a separate fast path for Union of Literal and slow path for other things
# would arguably be cleaner, however it breaks down when simplifying the Union of two
# different enum types as try_expanding_sum_type_to_union works recursively and will
# trigger intermediate simplifications that would render the fast path useless
for i, item in enumerate(items):
proper_item = get_proper_type(item)
if i in removed:
continue
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
k = simple_literal_value_key(proper_item)
if k is not None:
if k in seen:
removed.add(i)
continue

# NB: one would naively expect that it would be safe to skip the slow path
# always for literals. One would be sorely mistaken. Indeed, some simplifications
# such as that of None/Optional when strict optional is false, do require that we
# proceed with the slow path. Thankfully, all literals will have the same subtype
# relationship to non-literal types, so we only need to do that walk for the first
# literal, which keeps the fast path fast even in the presence of a mixture of
# literals and other types.
safe_skip = len(seen) > 0
seen.add(k)
if safe_skip:
continue

# Keep track of the truthiness info for deleted subtypes which can be relevant
cbt = cbf = False
for j, tj in enumerate(items):
proper_tj = get_proper_type(tj)
if (
i == j
# avoid further checks if this item was already marked redundant.
or j in removed
# if the current item is a simple literal then this simplification loop can
# safely skip all other simple literals as two literals will only ever be
# subtypes of each other if they are equal, which is already handled above.
# However, if the current item is not a literal, it might plausibly be a
# supertype of other literals in the union, so we must check them again.
# This is an important optimization as is_proper_subtype is pretty expensive.
or (k is not None and is_simple_literal(proper_tj))
):
continue
# actual redundancy checks (XXX?)
if is_redundant_literal_instance(proper_item, proper_tj) and is_proper_subtype(
tj, item, keep_erased_types=keep_erased, ignore_promotions=True
):
# We found a redundant item in the union.
removed.add(j)
cbt = cbt or tj.can_be_true
cbf = cbf or tj.can_be_false
# if deleted subtypes had more general truthiness, use that
if not item.can_be_true and cbt:
items[i] = true_or_false(item)
elif not item.can_be_false and cbf:
items[i] = true_or_false(item)

return [items[i] for i in range(len(items)) if i not in removed]
# The first pass through this loop, we check if later items are subtypes of earlier items.
# The second pass through this loop, we check if earlier items are subtypes of later items
# (by reversing the remaining items)
for _direction in range(2):
new_items: list[Type] = []
# seen is a map from a type to its index in new_items
seen: dict[ProperType, int] = {}
for ti in items:
proper_ti = get_proper_type(ti)

duplicate_index = -1
# Quickly check if we've seen this type
if proper_ti in seen:
duplicate_index = seen[proper_ti]
else:
# If not, check if we've seen a supertype of this type
for j, tj in enumerate(new_items):
tj = get_proper_type(tj)
if is_redundant_literal_instance(tj, proper_ti) and is_proper_subtype(
proper_ti, tj, keep_erased_types=keep_erased, ignore_promotions=True
):
duplicate_index = j
break
if duplicate_index != -1:
# If deleted subtypes had more general truthiness, use that
orig_item = new_items[duplicate_index]
if not orig_item.can_be_true and ti.can_be_true:
new_items[duplicate_index] = true_or_false(orig_item)
elif not orig_item.can_be_false and ti.can_be_false:
new_items[duplicate_index] = true_or_false(orig_item)
else:
# We have a non-duplicate item, add it to new_items
seen[proper_ti] = len(new_items)
new_items.append(ti)

items = new_items
items.reverse()

return items


def _get_type_special_method_bool_ret_type(t: Type) -> Type | None:
Expand Down

0 comments on commit 184e415

Please sign in to comment.