-
-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mypyc] Add
match
statement support (#13953)
Closes mypyc/mypyc#911 Like the title says, this PR adds support for compiling `match` statements in mypyc. Most of the work has been done, but there are some things which are still a WIP. A todo list of what has been done, and the (small) number of things that need to be worked out: - [x] Or patterns: `1 | 2 | 3` - [x] Value patterns: `123`, `x.y.z`, etc. - [x] Singleton patterns: `True`, `False`, and `None` - [x] Sequence patterns: - [x] Fixed length patterns `[1, 2, 3]` - [x] Starred patterns `[*prev, 4, 5, 6]`, `[1, 2, 3, *rest]`, etc: - [x] `[*rest]` is currently not working, but should be an easy fix - [x] Support any object which supports the [Sequence Protocol](https://docs.python.org/3/c-api/sequence.html) (need help with this) - [x] Mapping Pattern (`{"key": value}`): - [x] General support - [x] Starred patterns: `{"key": value, **rest}` - [x] Support any object which supports the [Mapping Protocol](https://docs.python.org/3/c-api/mapping.html) (need help with this) - [x] Class patterns: - [x] Basic class `isinstance()` check - [x] Positional args: `Class(1, 2, 3)` - [x] Keyword args: `Class(x=1, y=2, z=3)` - [x] Shortcut for built-in datatypes: `int(x)` -> `int() as x` - [x] Capture patterns: - [x] Wildcard pattern: `_` - [x] As pattern: `123 as num` - [x] Capture pattern: `x` Some features which I was unsure how to implement are: * Fix `*rest` and `**rest` star patterns name collisions. Basically, you cannot use `rest` (or any other name) twice in the same match statement if `rest` is a different type (ie, `dict` vs `list`). If it was defined as `object` instead of `dict`/`list` everything would be fine. Also some operations on native classes and primitive types could be optimized.
- Loading branch information
Showing
16 changed files
with
2,430 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,355 @@ | ||
from contextlib import contextmanager | ||
from typing import Generator, List, Optional, Tuple | ||
|
||
from mypy.nodes import MatchStmt, NameExpr, TypeInfo | ||
from mypy.patterns import ( | ||
AsPattern, | ||
ClassPattern, | ||
MappingPattern, | ||
OrPattern, | ||
Pattern, | ||
SequencePattern, | ||
SingletonPattern, | ||
StarredPattern, | ||
ValuePattern, | ||
) | ||
from mypy.traverser import TraverserVisitor | ||
from mypy.types import Instance, TupleType, get_proper_type | ||
from mypyc.ir.ops import BasicBlock, Value | ||
from mypyc.ir.rtypes import object_rprimitive | ||
from mypyc.irbuild.builder import IRBuilder | ||
from mypyc.primitives.dict_ops import ( | ||
dict_copy, | ||
dict_del_item, | ||
mapping_has_key, | ||
supports_mapping_protocol, | ||
) | ||
from mypyc.primitives.generic_ops import generic_ssize_t_len_op | ||
from mypyc.primitives.list_ops import ( | ||
sequence_get_item, | ||
sequence_get_slice, | ||
supports_sequence_protocol, | ||
) | ||
from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op | ||
|
||
# From: https://peps.python.org/pep-0634/#class-patterns | ||
MATCHABLE_BUILTINS = { | ||
"builtins.bool", | ||
"builtins.bytearray", | ||
"builtins.bytes", | ||
"builtins.dict", | ||
"builtins.float", | ||
"builtins.frozenset", | ||
"builtins.int", | ||
"builtins.list", | ||
"builtins.set", | ||
"builtins.str", | ||
"builtins.tuple", | ||
} | ||
|
||
|
||
class MatchVisitor(TraverserVisitor): | ||
builder: IRBuilder | ||
code_block: BasicBlock | ||
next_block: BasicBlock | ||
final_block: BasicBlock | ||
subject: Value | ||
match: MatchStmt | ||
|
||
as_pattern: Optional[AsPattern] = None | ||
|
||
def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None: | ||
self.builder = builder | ||
|
||
self.code_block = BasicBlock() | ||
self.next_block = BasicBlock() | ||
self.final_block = BasicBlock() | ||
|
||
self.match = match_node | ||
self.subject = builder.accept(match_node.subject) | ||
|
||
def build_match_body(self, index: int) -> None: | ||
self.builder.activate_block(self.code_block) | ||
|
||
guard = self.match.guards[index] | ||
|
||
if guard: | ||
self.code_block = BasicBlock() | ||
|
||
cond = self.builder.accept(guard) | ||
self.builder.add_bool_branch(cond, self.code_block, self.next_block) | ||
|
||
self.builder.activate_block(self.code_block) | ||
|
||
self.builder.accept(self.match.bodies[index]) | ||
self.builder.goto(self.final_block) | ||
|
||
def visit_match_stmt(self, m: MatchStmt) -> None: | ||
for i, pattern in enumerate(m.patterns): | ||
self.code_block = BasicBlock() | ||
self.next_block = BasicBlock() | ||
|
||
pattern.accept(self) | ||
|
||
self.build_match_body(i) | ||
self.builder.activate_block(self.next_block) | ||
|
||
self.builder.goto_and_activate(self.final_block) | ||
|
||
def visit_value_pattern(self, pattern: ValuePattern) -> None: | ||
value = self.builder.accept(pattern.expr) | ||
|
||
cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line) | ||
|
||
self.bind_as_pattern(value) | ||
|
||
self.builder.add_bool_branch(cond, self.code_block, self.next_block) | ||
|
||
def visit_or_pattern(self, pattern: OrPattern) -> None: | ||
backup_block = self.next_block | ||
self.next_block = BasicBlock() | ||
|
||
for p in pattern.patterns: | ||
# Hack to ensure the as pattern is bound to each pattern in the | ||
# "or" pattern, but not every subpattern | ||
backup = self.as_pattern | ||
p.accept(self) | ||
self.as_pattern = backup | ||
|
||
self.builder.activate_block(self.next_block) | ||
self.next_block = BasicBlock() | ||
|
||
self.next_block = backup_block | ||
self.builder.goto(self.next_block) | ||
|
||
def visit_class_pattern(self, pattern: ClassPattern) -> None: | ||
# TODO: use faster instance check for native classes (while still | ||
# making sure to account for inheritence) | ||
isinstance_op = ( | ||
fast_isinstance_op | ||
if self.builder.is_builtin_ref_expr(pattern.class_ref) | ||
else slow_isinstance_op | ||
) | ||
|
||
cond = self.builder.call_c( | ||
isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line | ||
) | ||
|
||
self.builder.add_bool_branch(cond, self.code_block, self.next_block) | ||
|
||
self.bind_as_pattern(self.subject, new_block=True) | ||
|
||
if pattern.positionals: | ||
if pattern.class_ref.fullname in MATCHABLE_BUILTINS: | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
pattern.positionals[0].accept(self) | ||
|
||
return | ||
|
||
node = pattern.class_ref.node | ||
assert isinstance(node, TypeInfo) | ||
|
||
ty = node.names.get("__match_args__") | ||
assert ty | ||
|
||
match_args_type = get_proper_type(ty.type) | ||
assert isinstance(match_args_type, TupleType) | ||
|
||
match_args: List[str] = [] | ||
|
||
for item in match_args_type.items: | ||
proper_item = get_proper_type(item) | ||
assert isinstance(proper_item, Instance) and proper_item.last_known_value | ||
|
||
match_arg = proper_item.last_known_value.value | ||
assert isinstance(match_arg, str) | ||
|
||
match_args.append(match_arg) | ||
|
||
for i, expr in enumerate(pattern.positionals): | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
# TODO: use faster "get_attr" method instead when calling on native or | ||
# builtin objects | ||
positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line) | ||
|
||
with self.enter_subpattern(positional): | ||
expr.accept(self) | ||
|
||
for key, value in zip(pattern.keyword_keys, pattern.keyword_values): | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
# TODO: same as above "get_attr" comment | ||
attr = self.builder.py_get_attr(self.subject, key, value.line) | ||
|
||
with self.enter_subpattern(attr): | ||
value.accept(self) | ||
|
||
def visit_as_pattern(self, pattern: AsPattern) -> None: | ||
if pattern.pattern: | ||
old_pattern = self.as_pattern | ||
self.as_pattern = pattern | ||
pattern.pattern.accept(self) | ||
self.as_pattern = old_pattern | ||
|
||
elif pattern.name: | ||
target = self.builder.get_assignment_target(pattern.name) | ||
|
||
self.builder.assign(target, self.subject, pattern.line) | ||
|
||
self.builder.goto(self.code_block) | ||
|
||
def visit_singleton_pattern(self, pattern: SingletonPattern) -> None: | ||
if pattern.value is None: | ||
obj = self.builder.none_object() | ||
elif pattern.value is True: | ||
obj = self.builder.true() | ||
else: | ||
obj = self.builder.false() | ||
|
||
cond = self.builder.binary_op(self.subject, obj, "is", pattern.line) | ||
|
||
self.builder.add_bool_branch(cond, self.code_block, self.next_block) | ||
|
||
def visit_mapping_pattern(self, pattern: MappingPattern) -> None: | ||
is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line) | ||
|
||
self.builder.add_bool_branch(is_dict, self.code_block, self.next_block) | ||
|
||
keys: List[Value] = [] | ||
|
||
for key, value in zip(pattern.keys, pattern.values): | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
key_value = self.builder.accept(key) | ||
keys.append(key_value) | ||
|
||
exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line) | ||
|
||
self.builder.add_bool_branch(exists, self.code_block, self.next_block) | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
item = self.builder.gen_method_call( | ||
self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line | ||
) | ||
|
||
with self.enter_subpattern(item): | ||
value.accept(self) | ||
|
||
if pattern.rest: | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line) | ||
|
||
target = self.builder.get_assignment_target(pattern.rest) | ||
|
||
self.builder.assign(target, rest, pattern.rest.line) | ||
|
||
for i, key_name in enumerate(keys): | ||
self.builder.call_c(dict_del_item, [rest, key_name], pattern.keys[i].line) | ||
|
||
self.builder.goto(self.code_block) | ||
|
||
def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None: | ||
star_index, capture, patterns = prep_sequence_pattern(seq_pattern) | ||
|
||
is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line) | ||
|
||
self.builder.add_bool_branch(is_list, self.code_block, self.next_block) | ||
|
||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line) | ||
min_len = len(patterns) | ||
|
||
is_long_enough = self.builder.binary_op( | ||
actual_len, | ||
self.builder.load_int(min_len), | ||
"==" if star_index is None else ">=", | ||
seq_pattern.line, | ||
) | ||
|
||
self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block) | ||
|
||
for i, pattern in enumerate(patterns): | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
if star_index is not None and i >= star_index: | ||
current = self.builder.binary_op( | ||
actual_len, self.builder.load_int(min_len - i), "-", pattern.line | ||
) | ||
|
||
else: | ||
current = self.builder.load_int(i) | ||
|
||
item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line) | ||
|
||
with self.enter_subpattern(item): | ||
pattern.accept(self) | ||
|
||
if capture and star_index is not None: | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
capture_end = self.builder.binary_op( | ||
actual_len, self.builder.load_int(min_len - star_index), "-", capture.line | ||
) | ||
|
||
rest = self.builder.call_c( | ||
sequence_get_slice, | ||
[self.subject, self.builder.load_int(star_index), capture_end], | ||
capture.line, | ||
) | ||
|
||
target = self.builder.get_assignment_target(capture) | ||
self.builder.assign(target, rest, capture.line) | ||
|
||
self.builder.goto(self.code_block) | ||
|
||
def bind_as_pattern(self, value: Value, new_block: bool = False) -> None: | ||
if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name: | ||
if new_block: | ||
self.builder.activate_block(self.code_block) | ||
self.code_block = BasicBlock() | ||
|
||
target = self.builder.get_assignment_target(self.as_pattern.name) | ||
self.builder.assign(target, value, self.as_pattern.pattern.line) | ||
|
||
self.as_pattern = None | ||
|
||
if new_block: | ||
self.builder.goto(self.code_block) | ||
|
||
@contextmanager | ||
def enter_subpattern(self, subject: Value) -> Generator[None, None, None]: | ||
old_subject = self.subject | ||
self.subject = subject | ||
yield | ||
self.subject = old_subject | ||
|
||
|
||
def prep_sequence_pattern( | ||
seq_pattern: SequencePattern, | ||
) -> Tuple[Optional[int], Optional[NameExpr], List[Pattern]]: | ||
star_index: Optional[int] = None | ||
capture: Optional[NameExpr] = None | ||
patterns: List[Pattern] = [] | ||
|
||
for i, pattern in enumerate(seq_pattern.patterns): | ||
if isinstance(pattern, StarredPattern): | ||
star_index = i | ||
capture = pattern.capture | ||
|
||
else: | ||
patterns.append(pattern) | ||
|
||
return star_index, capture, patterns |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.