From 15ecd0c31a1a8a59a02a877ab645d6744e84dedc Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Thu, 1 Feb 2024 15:32:43 +0000 Subject: [PATCH] core: add replacement arg_values to inline_block_before (#2061) Slightly modifies the semantics of inlining, bringing it in line with MLIR's behaviour, in my understanding. Currently, xDSL empties the block when inlining it to a different block, but does not erase it. With this PR, this behaviour is changed to erasing the block, potentially leaving an empty region. --- .../pattern_rewriter/test_pattern_rewriter.py | 67 +++++++++++++++++-- tests/test_rewriter.py | 1 - xdsl/pattern_rewriter.py | 6 +- xdsl/rewriter.py | 50 ++++++++++++-- 4 files changed, 109 insertions(+), 15 deletions(-) diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index 597ca9296b..da730ae064 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -766,9 +766,8 @@ def test_inline_block_before_matched_op(): %0 = "test.op"() : () -> !test.type<"int"> %1 = "test.op"() : () -> !test.type<"int"> %2 = "test.op"() ({ - ^0: }, { - ^1: + ^0: }) : () -> !test.type<"int"> }) : () -> () """ @@ -810,12 +809,11 @@ def test_inline_block_before(): %1 = "test.op"() ({ %2 = "test.op"() : () -> !test.type<"int"> %3 = "test.op"() ({ - ^0: }, { - ^1: + ^0: }) : () -> !test.type<"int"> }, { - ^2: + ^1: }) : () -> !test.type<"int"> }) : () -> () """ @@ -857,9 +855,8 @@ def test_inline_block_at_before_when_op_is_matched_op(): %0 = "test.op"() : () -> !test.type<"int"> %1 = "test.op"() : () -> !test.type<"int"> %2 = "test.op"() ({ - ^0: }, { - ^1: + ^0: }) : () -> !test.type<"int"> }) : () -> () """ @@ -877,6 +874,62 @@ def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter): ) +def test_inline_block_before_with_args(): + """Test the inlining of a block before an operation.""" + + prog = """\ +"builtin.module"() ({ + %0 = "test.op"() : () -> !test.type<"int"> + %1 = "test.op"() ({ + ^0(%arg0 : !test.type<"int">): + %1 = "test.op"() ({ + ^1(%arg1 : !test.type<"int">): + %1 = "test.op"(%arg1) : (!test.type<"int">) -> !test.type<"int"> + }, { + ^2: + }) : () -> !test.type<"int"> + }, { + ^3: + }) : () -> !test.type<"int"> +}) : () -> () +""" + + expected = """\ +"builtin.module"() ({ + %0 = "test.op"() : () -> !test.type<"int"> + %1 = "test.op"() ({ + ^0(%arg0 : !test.type<"int">): + %2 = "test.op"(%arg0) : (!test.type<"int">) -> !test.type<"int"> + %3 = "test.op"() ({ + }, { + ^1: + }) : () -> !test.type<"int"> + }, { + ^2: + }) : () -> !test.type<"int"> +}) : () -> () +""" + + class Rewrite(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter): + if matched_op.regs and matched_op.regs[0].blocks: + outer_block = matched_op.regs[0].blocks[0] + first_op = outer_block.first_op + + if isinstance(first_op, test.TestOp): + inner_block = first_op.regs[0].blocks[0] + rewriter.inline_block_before( + inner_block, first_op, outer_block.args + ) + + rewrite_and_compare( + prog, + expected, + PatternRewriteWalker(Rewrite(), apply_recursively=False), + ) + + def test_inline_block_after(): """Test the inlining of a block after an operation.""" diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index c667b3ffa9..e013a836e9 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -182,7 +182,6 @@ def test_inline_block_before(): %0 = "test.op"() : () -> !test.type<"int"> %1 = "test.op"() : () -> !test.type<"int"> "test.op"() ({ - ^0: }) : () -> () }) : () -> () """ diff --git a/xdsl/pattern_rewriter.py b/xdsl/pattern_rewriter.py index 1698632a9d..9f68cb8f3b 100644 --- a/xdsl/pattern_rewriter.py +++ b/xdsl/pattern_rewriter.py @@ -283,13 +283,15 @@ def inline_block_before_matched_op(self, block: Block): """ self.inline_block_before(block, self.current_operation) - def inline_block_before(self, block: Block, op: Operation): + def inline_block_before( + self, block: Block, op: Operation, arg_values: Sequence[SSAValue] = () + ): """ Move the block operations before the given operation. The block should not be a parent of the operation. """ self.has_done_action = True - Rewriter.inline_block_before(block, op) + Rewriter.inline_block_before(block, op, arg_values=arg_values) def inline_block_after_matched_op(self, block: Block): """ diff --git a/xdsl/rewriter.py b/xdsl/rewriter.py index 425f27ff44..6e89dbdabb 100644 --- a/xdsl/rewriter.py +++ b/xdsl/rewriter.py @@ -102,20 +102,60 @@ def inline_block_at_start(inlined_block: Block, extended_block: Block): Rewriter.inline_block_before(inlined_block, first_op_of_extended_block) @staticmethod - def inline_block_before(block: Block, op: Operation): + def inline_block_before( + source: Block, op: Operation, arg_values: Sequence[SSAValue] = () + ): """ Move the block operations before another operation. The block should not be a parent of the operation. - The block operations should not use the block arguments. """ - if op.parent is None: + # MLIR equivalent: + # https://github.com/llvm/llvm-project/blob/96a3d05ed923d2abd51acb52984b83b9e8044924/mlir/lib/IR/PatternMatch.cpp#L290 + assert len(arg_values) == len(source.args), ( + f"Expected {len(source.args)} replacement argument values, got " + f"{len(arg_values)}" + ) + + # The source block will be deleted, so it should not have any users (i.e., + # there should be no predecessors). + # TODO: check that the block has no predecessors + + # assert not block.predecessors, "expected 'source' to have no predecessors" + + if (dest := op.parent) is None: raise Exception("Cannot inline a block before a toplevel operation") - ops = list(block.ops) + # TODO: verify that the successors will make sense after inlining + # We currently cannot perform this check, just like the TODO above, due to lack + # of infrastructure in xDSL + # https://github.com/xdslproject/xdsl/issues/2066 + + # if dest.last_op != op: + # The source block will be inserted in the middle of the dest block, so the + # source block should have no successors. Otherwise, the remainder of the dest + # block would be unreachable. + # assert not source.successors, "expected 'source' to have no successors"); + # else: + # The source block will be inserted at the end of the dest block, so the dest + # block should have no successors. Otherwise, the inserted operations will be + # unreachable. + # assert not dest.successors, "expected 'dest' to have no successors"); + + # Replace all of the successor arguments with the provided values. + for arg, val in zip(source.args, arg_values, strict=True): + arg.replace_by(val) + + # Move operations from the source block to the dest block and erase the + # source block. + ops = list(source.ops) for block_op in ops: block_op.detach() - op.parent.insert_ops_before(ops, op) + dest.insert_ops_before(ops, op) + parent_region = source.parent + assert parent_region is not None + parent_region.detach_block(source) + source.erase() @staticmethod def inline_block_after(block: Block, op: Operation):