Skip to content

Commit d5e96e3

Browse files
authored
[mypyc] Add match statement support (#13953)
Closes mypyc/mypyc#911

Like the title says, this PR adds support for compiling `match` statements in mypyc. Most of the work has been done, but there are some things which are still a WIP. A todo list of what has been done, and the (small) number of things that need to be worked out:

- [x] Or patterns: `1 | 2 | 3`
- [x] Value patterns: `123`, `x.y.z`, etc.
- [x] Singleton patterns: `True`, `False`, and `None`
- [x] Sequence patterns:
  - [x] Fixed length patterns `[1, 2, 3]`
  - [x] Starred patterns `[*prev, 4, 5, 6]`, `[1, 2, 3, *rest]`, etc:
    - [x] `[*rest]` is currently not working, but should be an easy fix
  - [x] Support any object which supports the [Sequence Protocol](https://docs.python.org/3/c-api/sequence.html) (need help with this)
- [x] Mapping Pattern (`{"key": value}`):
  - [x] General support
  - [x] Starred patterns: `{"key": value, **rest}`
  - [x] Support any object which supports the [Mapping Protocol](https://docs.python.org/3/c-api/mapping.html) (need help with this)
- [x] Class patterns:
  - [x] Basic class `isinstance()` check
  - [x] Positional args: `Class(1, 2, 3)`
  - [x] Keyword args: `Class(x=1, y=2, z=3)`
  - [x] Shortcut for built-in datatypes: `int(x)` -> `int() as x`
- [x] Capture patterns:
  - [x] Wildcard pattern: `_`
  - [x] As pattern: `123 as num`
  - [x] Capture pattern: `x`

Some features which I was unsure how to implement are:

* Fix name collisions between `*rest` and `**rest` star patterns. Basically, you cannot use `rest` (or any other name) twice in the same match statement if `rest` is a different type in each place (i.e., `dict` vs `list`). If it was defined as `object` instead of `dict`/`list`, everything would be fine.

Also, some operations on native classes and primitive types could be optimized.
1 parent 740b364 commit d5e96e3

File tree

16 files changed

+2430
-9
lines changed

16 files changed

+2430
-9
lines changed

mypyc/irbuild/classdef.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ def find_attr_initializers(
629629
and not isinstance(stmt.rvalue, TempNode)
630630
):
631631
name = stmt.lvalues[0].name
632-
if name in ("__slots__", "__match_args__"):
632+
if name == "__slots__":
633633
continue
634634

635635
if name == "__deletable__":

mypyc/irbuild/match.py

Lines changed: 355 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
from contextlib import contextmanager
2+
from typing import Generator, List, Optional, Tuple
3+
4+
from mypy.nodes import MatchStmt, NameExpr, TypeInfo
5+
from mypy.patterns import (
6+
AsPattern,
7+
ClassPattern,
8+
MappingPattern,
9+
OrPattern,
10+
Pattern,
11+
SequencePattern,
12+
SingletonPattern,
13+
StarredPattern,
14+
ValuePattern,
15+
)
16+
from mypy.traverser import TraverserVisitor
17+
from mypy.types import Instance, TupleType, get_proper_type
18+
from mypyc.ir.ops import BasicBlock, Value
19+
from mypyc.ir.rtypes import object_rprimitive
20+
from mypyc.irbuild.builder import IRBuilder
21+
from mypyc.primitives.dict_ops import (
22+
dict_copy,
23+
dict_del_item,
24+
mapping_has_key,
25+
supports_mapping_protocol,
26+
)
27+
from mypyc.primitives.generic_ops import generic_ssize_t_len_op
28+
from mypyc.primitives.list_ops import (
29+
sequence_get_item,
30+
sequence_get_slice,
31+
supports_sequence_protocol,
32+
)
33+
from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op
34+
35+
# Fully qualified names of the builtin classes listed in the class-pattern
# section of PEP 634: https://peps.python.org/pep-0634/#class-patterns
# visit_class_pattern checks membership in this set to decide whether a single
# positional sub-pattern is matched against the subject itself.
MATCHABLE_BUILTINS = {
    # Scalars
    "builtins.bool",
    "builtins.float",
    "builtins.int",
    # Text and binary data
    "builtins.bytearray",
    "builtins.bytes",
    "builtins.str",
    # Containers
    "builtins.dict",
    "builtins.frozenset",
    "builtins.list",
    "builtins.set",
    "builtins.tuple",
}
49+
50+
51+
class MatchVisitor(TraverserVisitor):
    """Build IR for a ``match`` statement by visiting each case pattern.

    The visitor keeps three basic blocks as mutable state. Every pattern
    visitor emits a test on ``subject`` that branches to ``code_block`` on
    success and to ``next_block`` on failure; nested patterns chain further
    tests by activating ``code_block`` and replacing it with a fresh block.
    """

    builder: IRBuilder
    # Target taken when the pattern test emitted so far succeeds.
    code_block: BasicBlock
    # Target taken when the current case fails to match.
    next_block: BasicBlock
    # Block every case body jumps to when it finishes.
    final_block: BasicBlock
    # Value currently being matched; swapped temporarily for sub-patterns.
    subject: Value
    match: MatchStmt

    # Enclosing "<pattern> as <name>" that has not been bound yet, if any.
    as_pattern: Optional[AsPattern] = None

    def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None:
        """Create the initial blocks and evaluate the match subject once."""
        self.builder = builder

        self.code_block = BasicBlock()
        self.next_block = BasicBlock()
        self.final_block = BasicBlock()

        self.match = match_node
        self.subject = builder.accept(match_node.subject)

    def build_match_body(self, index: int) -> None:
        """Emit the guard (if present) and the body of case number ``index``."""
        self.builder.activate_block(self.code_block)

        guard = self.match.guards[index]

        if guard:
            # A failing guard falls through to the next case, exactly like a
            # failing pattern test.
            self.code_block = BasicBlock()

            cond = self.builder.accept(guard)
            self.builder.add_bool_branch(cond, self.code_block, self.next_block)

            self.builder.activate_block(self.code_block)

        self.builder.accept(self.match.bodies[index])
        self.builder.goto(self.final_block)

    def visit_match_stmt(self, m: MatchStmt) -> None:
        """Emit each case in order, giving every case fresh success/failure blocks."""
        for i, pattern in enumerate(m.patterns):
            self.code_block = BasicBlock()
            self.next_block = BasicBlock()

            pattern.accept(self)

            self.build_match_body(i)
            self.builder.activate_block(self.next_block)

        self.builder.goto_and_activate(self.final_block)

    def visit_value_pattern(self, pattern: ValuePattern) -> None:
        """Match when the subject compares equal (``==``) to the pattern expression."""
        value = self.builder.accept(pattern.expr)

        cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line)

        # NOTE(review): this binds the pattern's value (not self.subject) and
        # does so before branching on `cond` -- confirm this is intended.
        self.bind_as_pattern(value)

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

    def visit_or_pattern(self, pattern: OrPattern) -> None:
        """Try each alternative in turn; the first one that matches wins."""
        backup_block = self.next_block
        self.next_block = BasicBlock()

        for p in pattern.patterns:
            # Hack to ensure the as pattern is bound to each pattern in the
            # "or" pattern, but not every subpattern
            backup = self.as_pattern
            p.accept(self)
            self.as_pattern = backup

            # A failed alternative continues with the next alternative.
            self.builder.activate_block(self.next_block)
            self.next_block = BasicBlock()

        # Every alternative failed: fail the whole "or" pattern.
        self.next_block = backup_block
        self.builder.goto(self.next_block)

    def visit_class_pattern(self, pattern: ClassPattern) -> None:
        """Emit an isinstance check followed by positional and keyword sub-patterns."""
        # TODO: use faster instance check for native classes (while still
        # making sure to account for inheritance)
        isinstance_op = (
            fast_isinstance_op
            if self.builder.is_builtin_ref_expr(pattern.class_ref)
            else slow_isinstance_op
        )

        cond = self.builder.call_c(
            isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line
        )

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

        self.bind_as_pattern(self.subject, new_block=True)

        if pattern.positionals:
            if pattern.class_ref.fullname in MATCHABLE_BUILTINS:
                # For the builtins listed in MATCHABLE_BUILTINS the single
                # positional sub-pattern is matched against the subject itself.
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()

                pattern.positionals[0].accept(self)

                # NOTE(review): returning here skips the keyword sub-pattern
                # loop below -- confirm that is intended for builtin classes.
                return

            node = pattern.class_ref.node
            assert isinstance(node, TypeInfo)

            ty = node.names.get("__match_args__")
            assert ty

            match_args_type = get_proper_type(ty.type)
            assert isinstance(match_args_type, TupleType)

            # Attribute names from the class's __match_args__ tuple type, read
            # from the last known literal value of each tuple item.
            match_args: List[str] = []

            for item in match_args_type.items:
                proper_item = get_proper_type(item)
                assert isinstance(proper_item, Instance) and proper_item.last_known_value

                match_arg = proper_item.last_known_value.value
                assert isinstance(match_arg, str)

                match_args.append(match_arg)

            for i, expr in enumerate(pattern.positionals):
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()

                # TODO: use faster "get_attr" method instead when calling on native or
                # builtin objects
                positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line)

                with self.enter_subpattern(positional):
                    expr.accept(self)

        for key, value in zip(pattern.keyword_keys, pattern.keyword_values):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            # TODO: same as above "get_attr" comment
            attr = self.builder.py_get_attr(self.subject, key, value.line)

            with self.enter_subpattern(attr):
                value.accept(self)

    def visit_as_pattern(self, pattern: AsPattern) -> None:
        """Handle ``<pattern> as <name>``, a bare capture name, and the wildcard."""
        if pattern.pattern:
            # Let the inner pattern perform the binding via bind_as_pattern().
            old_pattern = self.as_pattern
            self.as_pattern = pattern
            pattern.pattern.accept(self)
            self.as_pattern = old_pattern

        elif pattern.name:
            # Name without an inner pattern: assign the subject to the name.
            target = self.builder.get_assignment_target(pattern.name)

            self.builder.assign(target, self.subject, pattern.line)

        self.builder.goto(self.code_block)

    def visit_singleton_pattern(self, pattern: SingletonPattern) -> None:
        """Match ``None``, ``True`` or ``False`` using an identity (``is``) test."""
        if pattern.value is None:
            obj = self.builder.none_object()
        elif pattern.value is True:
            obj = self.builder.true()
        else:
            obj = self.builder.false()

        cond = self.builder.binary_op(self.subject, obj, "is", pattern.line)

        self.builder.add_bool_branch(cond, self.code_block, self.next_block)

    def visit_mapping_pattern(self, pattern: MappingPattern) -> None:
        """Match a mapping: check each key exists, then match its value sub-pattern."""
        is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line)

        self.builder.add_bool_branch(is_dict, self.code_block, self.next_block)

        # Evaluated key values, kept so they can be removed from the ``**rest`` copy.
        keys: List[Value] = []

        for key, value in zip(pattern.keys, pattern.values):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            key_value = self.builder.accept(key)
            keys.append(key_value)

            exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line)

            self.builder.add_bool_branch(exists, self.code_block, self.next_block)
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            item = self.builder.gen_method_call(
                self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line
            )

            with self.enter_subpattern(item):
                value.accept(self)

        if pattern.rest:
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            # ``**rest`` receives a copy of the subject with the matched keys deleted.
            rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line)

            target = self.builder.get_assignment_target(pattern.rest)

            self.builder.assign(target, rest, pattern.rest.line)

            for i, key_name in enumerate(keys):
                self.builder.call_c(dict_del_item, [rest, key_name], pattern.keys[i].line)

            self.builder.goto(self.code_block)

    def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None:
        """Match a sequence: protocol check, length check, then each item sub-pattern."""
        star_index, capture, patterns = prep_sequence_pattern(seq_pattern)

        is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line)

        self.builder.add_bool_branch(is_list, self.code_block, self.next_block)

        self.builder.activate_block(self.code_block)
        self.code_block = BasicBlock()

        actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line)
        min_len = len(patterns)

        # Without a star the length must match exactly; with a star the
        # non-starred patterns give a lower bound.
        is_long_enough = self.builder.binary_op(
            actual_len,
            self.builder.load_int(min_len),
            "==" if star_index is None else ">=",
            seq_pattern.line,
        )

        self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block)

        for i, pattern in enumerate(patterns):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            if star_index is not None and i >= star_index:
                # Patterns after the star are indexed relative to the end of
                # the subject: actual_len - (min_len - i).
                current = self.builder.binary_op(
                    actual_len, self.builder.load_int(min_len - i), "-", pattern.line
                )

            else:
                current = self.builder.load_int(i)

            item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line)

            with self.enter_subpattern(item):
                pattern.accept(self)

        if capture and star_index is not None:
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()

            # The starred capture takes the slice from star_index up to the
            # first item matched by a pattern that follows the star.
            capture_end = self.builder.binary_op(
                actual_len, self.builder.load_int(min_len - star_index), "-", capture.line
            )

            rest = self.builder.call_c(
                sequence_get_slice,
                [self.subject, self.builder.load_int(star_index), capture_end],
                capture.line,
            )

            target = self.builder.get_assignment_target(capture)
            self.builder.assign(target, rest, capture.line)

            self.builder.goto(self.code_block)

    def bind_as_pattern(self, value: Value, new_block: bool = False) -> None:
        """Assign ``value`` to the pending ``as`` name, if there is one.

        Does nothing unless ``self.as_pattern`` is set and has both an inner
        pattern and a name. With ``new_block`` the assignment is emitted at the
        start of the current ``code_block``, which is then replaced by a fresh
        block that the assignment jumps to.
        """
        if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name:
            if new_block:
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()

            target = self.builder.get_assignment_target(self.as_pattern.name)
            self.builder.assign(target, value, self.as_pattern.pattern.line)

            # Clear the pending pattern so nested sub-patterns do not bind it again.
            self.as_pattern = None

            if new_block:
                self.builder.goto(self.code_block)

    @contextmanager
    def enter_subpattern(self, subject: Value) -> Generator[None, None, None]:
        """Temporarily make ``subject`` the value that nested patterns match.

        The previous subject is restored when the ``with`` block exits, even
        if visiting the sub-pattern raises.
        """
        old_subject = self.subject
        self.subject = subject
        try:
            yield
        finally:
            self.subject = old_subject
338+
339+
340+
def prep_sequence_pattern(
    seq_pattern: SequencePattern,
) -> Tuple[Optional[int], Optional[NameExpr], List[Pattern]]:
    """Split a sequence pattern into its starred part and its other sub-patterns.

    Returns a ``(star_index, capture, patterns)`` tuple:

    * ``star_index`` -- position of the starred sub-pattern within
      ``seq_pattern.patterns``, or ``None`` when there is none.
    * ``capture`` -- the starred sub-pattern's ``capture`` attribute, or
      ``None`` when there is no starred sub-pattern.
    * ``patterns`` -- the non-starred sub-patterns, in their original order.
    """
    star_position: Optional[int] = None
    star_capture: Optional[NameExpr] = None
    plain_patterns: List[Pattern] = []

    for position, sub_pattern in enumerate(seq_pattern.patterns):
        if not isinstance(sub_pattern, StarredPattern):
            plain_patterns.append(sub_pattern)
            continue

        # Remember where the star sits and what it captures.
        star_position = position
        star_capture = sub_pattern.capture

    return star_position, star_capture, plain_patterns

mypyc/irbuild/prepare.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,11 +231,7 @@ def prepare_class_def(
231231

232232
if isinstance(node.node, Var):
233233
assert node.node.type, "Class member %s missing type" % name
234-
if not node.node.is_classvar and name not in (
235-
"__slots__",
236-
"__deletable__",
237-
"__match_args__",
238-
):
234+
if not node.node.is_classvar and name not in ("__slots__", "__deletable__"):
239235
ir.attributes[name] = mapper.type_to_rtype(node.node.type)
240236
elif isinstance(node.node, (FuncDef, Decorator)):
241237
prepare_method_def(ir, module_name, cdef, mapper, node.node)

0 commit comments

Comments
 (0)