Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: add replacement arg_values to inline_block_before #2061

Merged
merged 3 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 60 additions & 7 deletions tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,9 +766,8 @@ def test_inline_block_before_matched_op():
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() : () -> !test.type<"int">
%2 = "test.op"() ({
^0:
}, {
^1:
^0:
}) : () -> !test.type<"int">
}) : () -> ()
"""
Expand Down Expand Up @@ -810,12 +809,11 @@ def test_inline_block_before():
%1 = "test.op"() ({
%2 = "test.op"() : () -> !test.type<"int">
%3 = "test.op"() ({
^0:
}, {
^1:
^0:
}) : () -> !test.type<"int">
}, {
^2:
^1:
}) : () -> !test.type<"int">
}) : () -> ()
"""
Expand Down Expand Up @@ -857,9 +855,8 @@ def test_inline_block_at_before_when_op_is_matched_op():
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() : () -> !test.type<"int">
%2 = "test.op"() ({
^0:
}, {
^1:
^0:
}) : () -> !test.type<"int">
}) : () -> ()
"""
Expand All @@ -877,6 +874,62 @@ def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter):
)


def test_inline_block_before_with_args():
"""Test the inlining of a block before an operation."""

prog = """\
"builtin.module"() ({
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() ({
^0(%arg0 : !test.type<"int">):
%1 = "test.op"() ({
^1(%arg1 : !test.type<"int">):
%1 = "test.op"(%arg1) : (!test.type<"int">) -> !test.type<"int">
}, {
^2:
}) : () -> !test.type<"int">
}, {
^3:
}) : () -> !test.type<"int">
}) : () -> ()
"""

expected = """\
"builtin.module"() ({
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() ({
^0(%arg0 : !test.type<"int">):
%2 = "test.op"(%arg0) : (!test.type<"int">) -> !test.type<"int">
%3 = "test.op"() ({
}, {
^1:
}) : () -> !test.type<"int">
}, {
^2:
}) : () -> !test.type<"int">
}) : () -> ()
"""

class Rewrite(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, matched_op: test.TestOp, rewriter: PatternRewriter):
if matched_op.regs and matched_op.regs[0].blocks:
outer_block = matched_op.regs[0].blocks[0]
first_op = outer_block.first_op

if isinstance(first_op, test.TestOp):
inner_block = first_op.regs[0].blocks[0]
rewriter.inline_block_before(
inner_block, first_op, outer_block.args
)

rewrite_and_compare(
prog,
expected,
PatternRewriteWalker(Rewrite(), apply_recursively=False),
)


def test_inline_block_after():
"""Test the inlining of a block after an operation."""

Expand Down
1 change: 0 additions & 1 deletion tests/test_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,6 @@ def test_inline_block_before():
%0 = "test.op"() : () -> !test.type<"int">
%1 = "test.op"() : () -> !test.type<"int">
"test.op"() ({
^0:
}) : () -> ()
}) : () -> ()
"""
Expand Down
6 changes: 4 additions & 2 deletions xdsl/pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,15 @@ def inline_block_before_matched_op(self, block: Block):
"""
self.inline_block_before(block, self.current_operation)

def inline_block_before(self, block: Block, op: Operation):
def inline_block_before(
self, block: Block, op: Operation, arg_values: Sequence[SSAValue] = ()
):
"""
Move the block operations before the given operation.
The block should not be a parent of the operation.
"""
self.has_done_action = True
Rewriter.inline_block_before(block, op)
Rewriter.inline_block_before(block, op, arg_values=arg_values)

def inline_block_after_matched_op(self, block: Block):
"""
Expand Down
50 changes: 45 additions & 5 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,20 +102,60 @@ def inline_block_at_start(inlined_block: Block, extended_block: Block):
Rewriter.inline_block_before(inlined_block, first_op_of_extended_block)

@staticmethod
def inline_block_before(block: Block, op: Operation):
def inline_block_before(
source: Block, op: Operation, arg_values: Sequence[SSAValue] = ()
):
"""
Move the block operations before another operation.
The block should not be a parent of the operation.
The block operations should not use the block arguments.
"""
if op.parent is None:
# MLIR equivalent:
# https://github.com/llvm/llvm-project/blob/96a3d05ed923d2abd51acb52984b83b9e8044924/mlir/lib/IR/PatternMatch.cpp#L290
assert len(arg_values) == len(source.args), (
f"Expected {len(source.args)} replacement argument values, got "
f"{len(arg_values)}"
)

# The source block will be deleted, so it should not have any users (i.e.,
# there should be no predecessors).
# TODO: check that the block has no predecessors

# assert not block.predecessors, "expected 'source' to have no predecessors"

if (dest := op.parent) is None:
raise Exception("Cannot inline a block before a toplevel operation")

ops = list(block.ops)
# TODO: verify that the successors will make sense after inlining
# We currently cannot perform this check, just like the TODO above, due to lack
# of infrastructure in xDSL
# https://github.com/xdslproject/xdsl/issues/2066

# if dest.last_op != op:
# The source block will be inserted in the middle of the dest block, so the
# source block should have no successors. Otherwise, the remainder of the dest
# block would be unreachable.
# assert not source.successors, "expected 'source' to have no successors");
# else:
# The source block will be inserted at the end of the dest block, so the dest
# block should have no successors. Otherwise, the inserted operations will be
# unreachable.
# assert not dest.successors, "expected 'dest' to have no successors");

# Replace all of the successor arguments with the provided values.
for arg, val in zip(source.args, arg_values, strict=True):
arg.replace_by(val)

# Move operations from the source block to the dest block and erase the
# source block.
ops = list(source.ops)
for block_op in ops:
block_op.detach()

op.parent.insert_ops_before(ops, op)
dest.insert_ops_before(ops, op)
parent_region = source.parent
assert parent_region is not None
parent_region.detach_block(source)
source.erase()

@staticmethod
def inline_block_after(block: Block, op: Operation):
Expand Down
Loading