From 80757619120e185c9e05524c2fb65434c5f64952 Mon Sep 17 00:00:00 2001 From: Mathieu Fehr Date: Mon, 6 Jan 2025 09:58:02 +0000 Subject: [PATCH] core: Add BlockInsertPoint to simplify the builder API `BlockInsertPoint` acts the same as `InsertPoint`, but for blocks. Its equivalent in MLIR is `Block::iterator`. stack-info: PR: https://github.com/xdslproject/xdsl/pull/3703, branch: math-fehr/stack/7 --- tests/test_op_builder.py | 20 +++++++++++++++- xdsl/rewriter.py | 50 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+), 1 deletion(-) diff --git a/tests/test_op_builder.py b/tests/test_op_builder.py index e0e3e316ea..3e8b915845 100644 --- a/tests/test_op_builder.py +++ b/tests/test_op_builder.py @@ -5,7 +5,7 @@ from xdsl.dialects.builtin import IntAttr, i32, i64 from xdsl.dialects.scf import IfOp from xdsl.ir import Block, BlockArgument, Operation, Region -from xdsl.rewriter import InsertPoint +from xdsl.rewriter import InsertPoint, BlockInsertPoint def test_insertion_point_constructors(): @@ -24,6 +24,24 @@ def test_insertion_point_constructors(): assert InsertPoint.after(op2) == InsertPoint(target, None) +def test_block_insertion_point_constructors(): + target = Region( + [ + (block1 := Block()), + (block2 := Block()), + ] + ) + + assert BlockInsertPoint.at_start(target) == BlockInsertPoint(target, block1) + assert BlockInsertPoint.at_end(target) == BlockInsertPoint(target, None) + assert BlockInsertPoint.before(block1) == BlockInsertPoint(target, block1) + assert BlockInsertPoint.after(block1) == BlockInsertPoint(target, block2) + assert BlockInsertPoint.before(block2) == BlockInsertPoint(target, block2) + assert BlockInsertPoint.after(block2) == BlockInsertPoint(target, None) + + assert BlockInsertPoint.at_start(target) != BlockInsertPoint.at_end(target) + + def test_builder(): target = Block( [ diff --git a/xdsl/rewriter.py b/xdsl/rewriter.py index b8726a782b..5154d2d934 100644 --- a/xdsl/rewriter.py +++ b/xdsl/rewriter.py @@ -58,6 +58,56 @@ def at_end(block: Block) -> InsertPoint: return InsertPoint(block) +@dataclass(frozen=True) +class BlockInsertPoint: + """ + An insert point for a block. + It is either a point before a block, or after a block. + """ + + region: Region + """The region where the insertion point is in.""" + + insert_before: Block | None = field(default=None) + """ + The insertion point is right before this block. + If the block is None, the insertion point is at the end of the region. + """ + + def __post_init__(self) -> None: + # Check that the insertion point is valid. + # An insertion point can only be invalid if `insert_before` is a `Block`, + # and its parent is not `region`. + if self.insert_before is not None: + if self.insert_before.parent is not self.region: + raise ValueError("Insertion point must be in the builder's `region`") + + @staticmethod + def before(block: Block) -> BlockInsertPoint: + """Gets the insertion point before a block.""" + if (region := block.parent) is None: + raise ValueError("Block insertion point must have a parent region") + return BlockInsertPoint(region, block) + + @staticmethod + def after(block: Block) -> BlockInsertPoint: + """Gets the insertion point after a block.""" + region = block.parent + if region is None: + raise ValueError("Block insertion point must have a parent region") + return BlockInsertPoint(region, block.next_block) + + @staticmethod + def at_start(region: Region) -> BlockInsertPoint: + """Gets the insertion point at the start of a region.""" + return BlockInsertPoint(region, region.first_block) + + @staticmethod + def at_end(region: Region) -> BlockInsertPoint: + """Gets the insertion point at the end of a region.""" + return BlockInsertPoint(region) + + class Rewriter: @staticmethod def erase_op(op: Operation, safe_erase: bool = True):