Skip to content

Commit

Permalink
core: add replacement arg_values to inline_block_before (#2061)
Browse files Browse the repository at this point in the history
Slightly modifies the semantics of inlining, bringing it in line with
MLIR's behaviour, in my understanding. Currently, xDSL empties the block
when inlining it to a different block, but does not erase it. With this
PR, this behaviour is changed to erasing the block, potentially leaving
an empty region.
  • Loading branch information
superlopuh authored Feb 1, 2024
1 parent 9a4b28c commit 15ecd0c
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 15 deletions.
67 changes: 60 additions & 7 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,9 +766,8 @@ def test_inline_block_before_matched_op():
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() : () -> !test.type<"int">
%2 = "test.op"() ({
^0:
}, {
^1:
^0:
}) : () -> !test.type<"int">
}) : () -> ()
"""
Expand Down Expand Up @@ -810,12 +809,11 @@ def test_inline_block_before():
%1 = "test.op"() ({
%2 = "test.op"() : () -> !test.type<"int">
%3 = "test.op"() ({
^0:
}, {
^1:
^0:
}) : () -> !test.type<"int">
}, {
^2:
^1:
}) : () -> !test.type<"int">
}) : () -> ()
"""
Expand Down Expand Up @@ -857,9 +855,8 @@ def test_inline_block_at_before_when_op_is_matched_op():
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() : () -> !test.type<"int">
%2 = "test.op"() ({
^0:
}, {
^1:
^0:
}) : () -> !test.type<"int">
}) : () -> ()
"""
Expand All @@ -877,6 +874,62 @@ def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter):
)


def test_inline_block_before_with_args():
"""Test the inlining of a block before an operation."""

prog = """\
"builtin.module"() ({
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() ({
^0(%arg0 : !test.type<"int">):
%1 = "test.op"() ({
^1(%arg1 : !test.type<"int">):
%1 = "test.op"(%arg1) : (!test.type<"int">) -> !test.type<"int">
}, {
^2:
}) : () -> !test.type<"int">
}, {
^3:
}) : () -> !test.type<"int">
}) : () -> ()
"""

expected = """\
"builtin.module"() ({
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() ({
^0(%arg0 : !test.type<"int">):
%2 = "test.op"(%arg0) : (!test.type<"int">) -> !test.type<"int">
%3 = "test.op"() ({
}, {
^1:
}) : () -> !test.type<"int">
}, {
^2:
}) : () -> !test.type<"int">
}) : () -> ()
"""

class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter):
if matched_op.regs and matched_op.regs[0].blocks:
outer_block = matched_op.regs[0].blocks[0]
first_op = outer_block.first_op

if isinstance(first_op, test.TestOp):
inner_block = first_op.regs[0].blocks[0]
rewriter.inline_block_before(
inner_block, first_op, outer_block.args
)

rewrite_and_compare(
prog,
expected,
PatternRewriteWalker(Rewrite(), apply_recursively=False),
)


def test_inline_block_after():
"""Test the inlining of a block after an operation."""

Expand Down
1 change: 0 additions & 1 deletion tests/test_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def test_inline_block_before():
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() : () -> !test.type<"int">
"test.op"() ({
^0:
}) : () -> ()
}) : () -> ()
"""
Expand Down
6 changes: 4 additions & 2 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,15 @@ def inline_block_before_matched_op(self, block: Block):
"""
self.inline_block_before(block, self.current_operation)

def inline_block_before(self, block: Block, op: Operation):
def inline_block_before(
self, block: Block, op: Operation, arg_values: Sequence[SSAValue] = ()
):
"""
Move the block operations before the given operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
Rewriter.inline_block_before(block, op)
Rewriter.inline_block_before(block, op, arg_values=arg_values)

def inline_block_after_matched_op(self, block: Block):
"""
Expand Down
50 changes: 45 additions & 5 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,60 @@ def inline_block_at_start(inlined_block: Block, extended_block: Block):
Rewriter.inline_block_before(inlined_block, first_op_of_extended_block)

@staticmethod
def inline_block_before(block: Block, op: Operation):
def inline_block_before(
source: Block, op: Operation, arg_values: Sequence[SSAValue] = ()
):
"""
Move the block operations before another operation.
The block should not be a parent of the operation.
The block operations should not use the block arguments.
"""
if op.parent is None:
# MLIR equivalent:
# https://github.com/llvm/llvm-project/blob/96a3d05ed923d2abd51acb52984b83b9e8044924/mlir/lib/IR/PatternMatch.cpp#L290
assert len(arg_values) == len(source.args), (
f"Expected {len(source.args)} replacement argument values, got "
f"{len(arg_values)}"
)

# The source block will be deleted, so it should not have any users (i.e.,
# there should be no predecessors).
# TODO: check that the block has no predecessors

# assert not block.predecessors, "expected 'source' to have no predecessors"

if (dest := op.parent) is None:
raise Exception("Cannot inline a block before a toplevel operation")

ops = list(block.ops)
# TODO: verify that the successors will make sense after inlining
# We currently cannot perform this check, just like the TODO above, due to lack
# of infrastructure in xDSL
# https://github.com/xdslproject/xdsl/issues/2066

# if dest.last_op != op:
# The source block will be inserted in the middle of the dest block, so the
# source block should have no successors. Otherwise, the remainder of the dest
# block would be unreachable.
# assert not source.successors, "expected 'source' to have no successors");
# else:
# The source block will be inserted at the end of the dest block, so the dest
# block should have no successors. Otherwise, the inserted operations will be
# unreachable.
# assert not dest.successors, "expected 'dest' to have no successors");

# Replace all of the successor arguments with the provided values.
for arg, val in zip(source.args, arg_values, strict=True):
arg.replace_by(val)

# Move operations from the source block to the dest block and erase the
# source block.
ops = list(source.ops)
for block_op in ops:
block_op.detach()

op.parent.insert_ops_before(ops, op)
dest.insert_ops_before(ops, op)
parent_region = source.parent
assert parent_region is not None
parent_region.detach_block(source)
source.erase()

@staticmethod
def inline_block_after(block: Block, op: Operation):
Expand Down

0 comments on commit 15ecd0c

Please sign in to comment.