Skip to content

Commit

Permalink
core: Add BlockInsertPoint to simplify the builder API
Browse files Browse the repository at this point in the history
`BlockInsertPoint` acts the same as `InsertPoint`, but for blocks.
Its equivalent in MLIR is `Block::iterator`.

stack-info: PR: #3703, branch: math-fehr/stack/7
  • Loading branch information
math-fehr committed Jan 6, 2025
1 parent 4cb9ea2 commit 8075761
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
20 changes: 19 additions & 1 deletion tests/test_op_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from xdsl.dialects.builtin import IntAttr, i32, i64
from xdsl.dialects.scf import IfOp
from xdsl.ir import Block, BlockArgument, Operation, Region
from xdsl.rewriter import InsertPoint
from xdsl.rewriter import InsertPoint, BlockInsertPoint


def test_insertion_point_constructors():
Expand All @@ -24,6 +24,24 @@ def test_insertion_point_constructors():
assert InsertPoint.after(op2) == InsertPoint(target, None)


def test_block_insertion_point_constructors():
target = Region(
[
(block1 := Block()),
(block2 := Block()),
]
)

assert BlockInsertPoint.at_start(target) == BlockInsertPoint(target, block1)
assert BlockInsertPoint.at_end(target) == BlockInsertPoint(target, None)
assert BlockInsertPoint.before(block1) == BlockInsertPoint(target, block1)
assert BlockInsertPoint.after(block1) == BlockInsertPoint(target, block2)
assert BlockInsertPoint.before(block2) == BlockInsertPoint(target, block2)
assert BlockInsertPoint.after(block2) == BlockInsertPoint(target, None)

assert BlockInsertPoint.at_start(target) != BlockInsertPoint.at_end(target)


def test_builder():
target = Block(
[
Expand Down
50 changes: 50 additions & 0 deletions xdsl/rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,56 @@ def at_end(block: Block) -> InsertPoint:
return InsertPoint(block)


@dataclass(frozen=True)
class BlockInsertPoint:
"""
An insert point for a block.
It is either a point before a block, or after a block.
"""

region: Region
"""The region where the insertion point is in."""

insert_before: Block | None = field(default=None)
"""
The insertion point is right before this block.
If the block is None, the insertion point is at the end of the region.
"""

def __post_init__(self) -> None:
# Check that the insertion point is valid.
# An insertion point can only be invalid if `insert_before` is a `Block`,
# and its parent is not `region`.
if self.insert_before is not None:
if self.insert_before.parent is not self.region:
raise ValueError("Insertion point must be in the builder's `region`")

@staticmethod
def before(block: Block) -> BlockInsertPoint:
"""Gets the insertion point before a block."""
if (region := block.parent) is None:
raise ValueError("Block insertion point must have a parent region")
return BlockInsertPoint(region, block)

@staticmethod
def after(block: Block) -> BlockInsertPoint:
"""Gets the insertion point after a block."""
region = block.parent
if region is None:
raise ValueError("Block insertion point must have a parent region")
return BlockInsertPoint(region, block.next_block)

@staticmethod
def at_start(region: Region) -> BlockInsertPoint:
"""Gets the insertion point at the start of a region."""
return BlockInsertPoint(region, region.first_block)

@staticmethod
def at_end(region: Region) -> BlockInsertPoint:
"""Gets the insertion point at the end of a region."""
return BlockInsertPoint(region)


class Rewriter:
@staticmethod
def erase_op(op: Operation, safe_erase: bool = True):
Expand Down

0 comments on commit 8075761

Please sign in to comment.