Skip to content

Commit

Permalink
[mypyc] Add match statement support (#13953)
Browse files Browse the repository at this point in the history
Closes mypyc/mypyc#911

Like the title says, this PR adds support for compiling `match`
statements in mypyc. Most of the work has been done, but there are some
things which are still a WIP.

A todo list of what has been done, and the (small) number of things that
need to be worked out:

- [x] Or patterns: `1 | 2 | 3`
- [x] Value patterns: `123`, `x.y.z`, etc.
- [x] Singleton patterns: `True`, `False`, and `None`
- [x] Sequence patterns:
  - [x] Fixed length patterns `[1, 2, 3]`
  - [x] Starred patterns `[*prev, 4, 5, 6]`, `[1, 2, 3, *rest]`, etc:
    - [x] `[*rest]` is currently not working, but should be an easy fix
- [x] Support any object which supports the [Sequence
Protocol](https://docs.python.org/3/c-api/sequence.html) (need help with
this)
- [x] Mapping Pattern (`{"key": value}`):
  - [x] General support
  - [x] Starred patterns: `{"key": value, **rest}`
- [x] Support any object which supports the [Mapping
Protocol](https://docs.python.org/3/c-api/mapping.html) (need help with
this)
- [x] Class patterns:
  - [x] Basic class `isinstance()` check
  - [x] Positional args: `Class(1, 2, 3)` 
  - [x] Keyword args: `Class(x=1, y=2, z=3)`
  - [x] Shortcut for built-in datatypes: `int(x)` -> `int() as x`
- [x] Capture patterns:
  - [x] Wildcard pattern: `_`
  - [x] As pattern: `123 as num`
  - [x] Capture pattern: `x`

Some features which I was unsure how to implement are:

* Fix `*rest` and `**rest` star patterns name collisions. Basically, you
cannot use `rest` (or any other name) twice in the same match statement
if `rest` is a different type (ie, `dict` vs `list`). If it was defined
as `object` instead of `dict`/`list` everything would be fine.

Also some operations on native classes and primitive types could be 
optimized.
  • Loading branch information
dosisod authored Dec 2, 2022
1 parent 740b364 commit d5e96e3
Show file tree
Hide file tree
Showing 16 changed files with 2,430 additions and 9 deletions.
2 changes: 1 addition & 1 deletion mypyc/irbuild/classdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def find_attr_initializers(
and not isinstance(stmt.rvalue, TempNode)
):
name = stmt.lvalues[0].name
if name in ("__slots__", "__match_args__"):
if name == "__slots__":
continue

if name == "__deletable__":
Expand Down
355 changes: 355 additions & 0 deletions mypyc/irbuild/match.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,355 @@
from contextlib import contextmanager
from typing import Generator, List, Optional, Tuple

from mypy.nodes import MatchStmt, NameExpr, TypeInfo
from mypy.patterns import (
AsPattern,
ClassPattern,
MappingPattern,
OrPattern,
Pattern,
SequencePattern,
SingletonPattern,
StarredPattern,
ValuePattern,
)
from mypy.traverser import TraverserVisitor
from mypy.types import Instance, TupleType, get_proper_type
from mypyc.ir.ops import BasicBlock, Value
from mypyc.ir.rtypes import object_rprimitive
from mypyc.irbuild.builder import IRBuilder
from mypyc.primitives.dict_ops import (
dict_copy,
dict_del_item,
mapping_has_key,
supports_mapping_protocol,
)
from mypyc.primitives.generic_ops import generic_ssize_t_len_op
from mypyc.primitives.list_ops import (
sequence_get_item,
sequence_get_slice,
supports_sequence_protocol,
)
from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op

# From: https://peps.python.org/pep-0634/#class-patterns
MATCHABLE_BUILTINS = {
"builtins.bool",
"builtins.bytearray",
"builtins.bytes",
"builtins.dict",
"builtins.float",
"builtins.frozenset",
"builtins.int",
"builtins.list",
"builtins.set",
"builtins.str",
"builtins.tuple",
}


class MatchVisitor(TraverserVisitor):
builder: IRBuilder
code_block: BasicBlock
next_block: BasicBlock
final_block: BasicBlock
subject: Value
match: MatchStmt

as_pattern: Optional[AsPattern] = None

def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None:
self.builder = builder

self.code_block = BasicBlock()
self.next_block = BasicBlock()
self.final_block = BasicBlock()

self.match = match_node
self.subject = builder.accept(match_node.subject)

def build_match_body(self, index: int) -> None:
self.builder.activate_block(self.code_block)

guard = self.match.guards[index]

if guard:
self.code_block = BasicBlock()

cond = self.builder.accept(guard)
self.builder.add_bool_branch(cond, self.code_block, self.next_block)

self.builder.activate_block(self.code_block)

self.builder.accept(self.match.bodies[index])
self.builder.goto(self.final_block)

def visit_match_stmt(self, m: MatchStmt) -> None:
for i, pattern in enumerate(m.patterns):
self.code_block = BasicBlock()
self.next_block = BasicBlock()

pattern.accept(self)

self.build_match_body(i)
self.builder.activate_block(self.next_block)

self.builder.goto_and_activate(self.final_block)

def visit_value_pattern(self, pattern: ValuePattern) -> None:
value = self.builder.accept(pattern.expr)

cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line)

self.bind_as_pattern(value)

self.builder.add_bool_branch(cond, self.code_block, self.next_block)

def visit_or_pattern(self, pattern: OrPattern) -> None:
backup_block = self.next_block
self.next_block = BasicBlock()

for p in pattern.patterns:
# Hack to ensure the as pattern is bound to each pattern in the
# "or" pattern, but not every subpattern
backup = self.as_pattern
p.accept(self)
self.as_pattern = backup

self.builder.activate_block(self.next_block)
self.next_block = BasicBlock()

self.next_block = backup_block
self.builder.goto(self.next_block)

def visit_class_pattern(self, pattern: ClassPattern) -> None:
# TODO: use faster instance check for native classes (while still
# making sure to account for inheritence)
isinstance_op = (
fast_isinstance_op
if self.builder.is_builtin_ref_expr(pattern.class_ref)
else slow_isinstance_op
)

cond = self.builder.call_c(
isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line
)

self.builder.add_bool_branch(cond, self.code_block, self.next_block)

self.bind_as_pattern(self.subject, new_block=True)

if pattern.positionals:
if pattern.class_ref.fullname in MATCHABLE_BUILTINS:
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

pattern.positionals[0].accept(self)

return

node = pattern.class_ref.node
assert isinstance(node, TypeInfo)

ty = node.names.get("__match_args__")
assert ty

match_args_type = get_proper_type(ty.type)
assert isinstance(match_args_type, TupleType)

match_args: List[str] = []

for item in match_args_type.items:
proper_item = get_proper_type(item)
assert isinstance(proper_item, Instance) and proper_item.last_known_value

match_arg = proper_item.last_known_value.value
assert isinstance(match_arg, str)

match_args.append(match_arg)

for i, expr in enumerate(pattern.positionals):
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

# TODO: use faster "get_attr" method instead when calling on native or
# builtin objects
positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line)

with self.enter_subpattern(positional):
expr.accept(self)

for key, value in zip(pattern.keyword_keys, pattern.keyword_values):
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

# TODO: same as above "get_attr" comment
attr = self.builder.py_get_attr(self.subject, key, value.line)

with self.enter_subpattern(attr):
value.accept(self)

def visit_as_pattern(self, pattern: AsPattern) -> None:
if pattern.pattern:
old_pattern = self.as_pattern
self.as_pattern = pattern
pattern.pattern.accept(self)
self.as_pattern = old_pattern

elif pattern.name:
target = self.builder.get_assignment_target(pattern.name)

self.builder.assign(target, self.subject, pattern.line)

self.builder.goto(self.code_block)

def visit_singleton_pattern(self, pattern: SingletonPattern) -> None:
if pattern.value is None:
obj = self.builder.none_object()
elif pattern.value is True:
obj = self.builder.true()
else:
obj = self.builder.false()

cond = self.builder.binary_op(self.subject, obj, "is", pattern.line)

self.builder.add_bool_branch(cond, self.code_block, self.next_block)

def visit_mapping_pattern(self, pattern: MappingPattern) -> None:
is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line)

self.builder.add_bool_branch(is_dict, self.code_block, self.next_block)

keys: List[Value] = []

for key, value in zip(pattern.keys, pattern.values):
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

key_value = self.builder.accept(key)
keys.append(key_value)

exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line)

self.builder.add_bool_branch(exists, self.code_block, self.next_block)
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

item = self.builder.gen_method_call(
self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line
)

with self.enter_subpattern(item):
value.accept(self)

if pattern.rest:
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

rest = self.builder.call_c(dict_copy, [self.subject], pattern.rest.line)

target = self.builder.get_assignment_target(pattern.rest)

self.builder.assign(target, rest, pattern.rest.line)

for i, key_name in enumerate(keys):
self.builder.call_c(dict_del_item, [rest, key_name], pattern.keys[i].line)

self.builder.goto(self.code_block)

def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None:
star_index, capture, patterns = prep_sequence_pattern(seq_pattern)

is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line)

self.builder.add_bool_branch(is_list, self.code_block, self.next_block)

self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line)
min_len = len(patterns)

is_long_enough = self.builder.binary_op(
actual_len,
self.builder.load_int(min_len),
"==" if star_index is None else ">=",
seq_pattern.line,
)

self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block)

for i, pattern in enumerate(patterns):
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

if star_index is not None and i >= star_index:
current = self.builder.binary_op(
actual_len, self.builder.load_int(min_len - i), "-", pattern.line
)

else:
current = self.builder.load_int(i)

item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line)

with self.enter_subpattern(item):
pattern.accept(self)

if capture and star_index is not None:
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

capture_end = self.builder.binary_op(
actual_len, self.builder.load_int(min_len - star_index), "-", capture.line
)

rest = self.builder.call_c(
sequence_get_slice,
[self.subject, self.builder.load_int(star_index), capture_end],
capture.line,
)

target = self.builder.get_assignment_target(capture)
self.builder.assign(target, rest, capture.line)

self.builder.goto(self.code_block)

def bind_as_pattern(self, value: Value, new_block: bool = False) -> None:
if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name:
if new_block:
self.builder.activate_block(self.code_block)
self.code_block = BasicBlock()

target = self.builder.get_assignment_target(self.as_pattern.name)
self.builder.assign(target, value, self.as_pattern.pattern.line)

self.as_pattern = None

if new_block:
self.builder.goto(self.code_block)

@contextmanager
def enter_subpattern(self, subject: Value) -> Generator[None, None, None]:
old_subject = self.subject
self.subject = subject
yield
self.subject = old_subject


def prep_sequence_pattern(
seq_pattern: SequencePattern,
) -> Tuple[Optional[int], Optional[NameExpr], List[Pattern]]:
star_index: Optional[int] = None
capture: Optional[NameExpr] = None
patterns: List[Pattern] = []

for i, pattern in enumerate(seq_pattern.patterns):
if isinstance(pattern, StarredPattern):
star_index = i
capture = pattern.capture

else:
patterns.append(pattern)

return star_index, capture, patterns
6 changes: 1 addition & 5 deletions mypyc/irbuild/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,7 @@ def prepare_class_def(

if isinstance(node.node, Var):
assert node.node.type, "Class member %s missing type" % name
if not node.node.is_classvar and name not in (
"__slots__",
"__deletable__",
"__match_args__",
):
if not node.node.is_classvar and name not in ("__slots__", "__deletable__"):
ir.attributes[name] = mapper.type_to_rtype(node.node.type)
elif isinstance(node.node, (FuncDef, Decorator)):
prepare_method_def(ir, module_name, cdef, mapper, node.node)
Expand Down
Loading

0 comments on commit d5e96e3

Please sign in to comment.