Skip to content

Commit

Permalink
[mypyc] (Re-)Support iterating over a Union of dicts (#14713)
Browse files Browse the repository at this point in the history
An optimization that makes iterating over dict.keys(), dict.values() and
dict.items() faster caused mypyc to crash when compiling code that iterates
over a Union of dictionaries. This commit fixes the optimization helpers so
that they handle unions properly.

irbuild.Builder.get_dict_base_type() now returns a list[Instance] containing
the union items. In the common case, where there is no union, it returns a
single-element list. get_dict_key_type() and get_dict_value_type() now
build a simplified RUnion when needed.

Fixes mypyc/mypyc#965 and probably #14694.
  • Loading branch information
ichard26 authored Feb 16, 2023
1 parent 0bbeab8 commit 7237831
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 12 deletions.
5 changes: 2 additions & 3 deletions mypyc/codegen/literals.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, FrozenSet, List, Tuple, Union, cast
from typing import Any, FrozenSet, List, Tuple, Union, cast
from typing_extensions import Final

# Supported Python literal types. All tuple / frozenset items must have supported
Expand Down Expand Up @@ -151,8 +151,7 @@ def _encode_collection_values(
<length of the second collection>
...
"""
# FIXME: https://github.com/mypyc/mypyc/issues/965
value_by_index = {index: value for value, index in cast(Dict[Any, int], values).items()}
value_by_index = {index: value for value, index in values.items()}
result = []
count = len(values)
result.append(str(count))
Expand Down
32 changes: 24 additions & 8 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,23 +879,39 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType:
else:
return self.type_to_rtype(target_type.args[0])

def get_dict_base_type(self, expr: Expression) -> Instance:
def get_dict_base_type(self, expr: Expression) -> list[Instance]:
"""Find dict type of a dict-like expression.
This is useful for dict subclasses like SymbolTable.
"""
target_type = get_proper_type(self.types[expr])
assert isinstance(target_type, Instance), target_type
dict_base = next(base for base in target_type.type.mro if base.fullname == "builtins.dict")
return map_instance_to_supertype(target_type, dict_base)
if isinstance(target_type, UnionType):
types = [get_proper_type(item) for item in target_type.items]
else:
types = [target_type]

dict_types = []
for t in types:
assert isinstance(t, Instance), t
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
dict_types.append(map_instance_to_supertype(t, dict_base))
return dict_types

def get_dict_key_type(self, expr: Expression) -> RType:
dict_base_type = self.get_dict_base_type(expr)
return self.type_to_rtype(dict_base_type.args[0])
dict_base_types = self.get_dict_base_type(expr)
if len(dict_base_types) == 1:
return self.type_to_rtype(dict_base_types[0].args[0])
else:
rtypes = [self.type_to_rtype(t.args[0]) for t in dict_base_types]
return RUnion.make_simplified_union(rtypes)

def get_dict_value_type(self, expr: Expression) -> RType:
dict_base_type = self.get_dict_base_type(expr)
return self.type_to_rtype(dict_base_type.args[1])
dict_base_types = self.get_dict_base_type(expr)
if len(dict_base_types) == 1:
return self.type_to_rtype(dict_base_types[0].args[1])
else:
rtypes = [self.type_to_rtype(t.args[1]) for t in dict_base_types]
return RUnion.make_simplified_union(rtypes)

def get_dict_item_type(self, expr: Expression) -> RType:
key_type = self.get_dict_key_type(expr)
Expand Down
58 changes: 57 additions & 1 deletion mypyc/test-data/irbuild-dict.test
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,17 @@ L0:
return r2

[case testDictIterationMethods]
from typing import Dict
from typing import Dict, Union
def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None:
for v in d1.values():
if v in d2:
return
for k, v in d2.items():
d2[k] += v
def union_of_dicts(d: Union[Dict[str, int], Dict[str, str]]) -> None:
new = {}
for k, v in d.items():
new[k] = int(v)
[out]
def print_dict_methods(d1, d2):
d1, d2 :: dict
Expand Down Expand Up @@ -314,6 +318,58 @@ L11:
r34 = CPy_NoErrOccured()
L12:
return 1
def union_of_dicts(d):
d, r0, new :: dict
r1 :: short_int
r2 :: native_int
r3 :: short_int
r4 :: object
r5 :: tuple[bool, short_int, object, object]
r6 :: short_int
r7 :: bool
r8, r9 :: object
r10 :: str
r11 :: union[int, str]
k :: str
v :: union[int, str]
r12, r13 :: object
r14 :: int
r15 :: object
r16 :: int32
r17, r18, r19 :: bit
L0:
r0 = PyDict_New()
new = r0
r1 = 0
r2 = PyDict_Size(d)
r3 = r2 << 1
r4 = CPyDict_GetItemsIter(d)
L1:
r5 = CPyDict_NextItem(r4, r1)
r6 = r5[1]
r1 = r6
r7 = r5[0]
if r7 goto L2 else goto L4 :: bool
L2:
r8 = r5[2]
r9 = r5[3]
r10 = cast(str, r8)
r11 = cast(union[int, str], r9)
k = r10
v = r11
r12 = load_address PyLong_Type
r13 = PyObject_CallFunctionObjArgs(r12, v, 0)
r14 = unbox(int, r13)
r15 = box(int, r14)
r16 = CPyDict_SetItem(new, k, r15)
r17 = r16 >= 0 :: signed
L3:
r18 = CPyDict_CheckSize(d, r3)
goto L1
L4:
r19 = CPy_NoErrOccured()
L5:
return 1

[case testDictLoadAddress]
def f() -> None:
Expand Down

0 comments on commit 7237831

Please sign in to comment.