Skip to content

Commit

Permalink
add support for cloning operations with regions (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-luecke authored Apr 14, 2022
1 parent bb2bc4d commit 8f0fd88
Showing 1 changed file with 67 additions and 8 deletions.
75 changes: 67 additions & 8 deletions src/xdsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,18 +368,49 @@ def verify(self, verify_nested_ops: bool = True) -> None:
def verify_(self) -> None:
pass

def clone_without_regions(self: OperationType) -> OperationType:
def clone_without_regions(
self: OperationType,
value_mapper: Optional[Dict[SSAValue, SSAValue]] = None,
block_mapper: Optional[Dict[Block,
Block]] = None) -> OperationType:
"""Clone an operation, with empty regions instead."""
operands = self.operands
if value_mapper is None:
value_mapper = {}
if block_mapper is None:
block_mapper = {}
operands = [
(value_mapper[operand] if operand in value_mapper else operand)
for operand in self.operands
]
result_types = [res.typ for res in self.results]
attributes = self.attributes.copy()
successors = self.successors.copy()
successors = [(block_mapper[successor]
if successor in block_mapper else successor)
for successor in self.successors]
regions = [Region() for _ in self.regions]
return self.create(operands=operands,
result_types=result_types,
attributes=attributes,
successors=successors,
regions=regions)
cloned_op = self.create(operands=operands,
result_types=result_types,
attributes=attributes,
successors=successors,
regions=regions)
for idx, result in enumerate(cloned_op.results):
value_mapper[self.results[idx]] = result
return cloned_op

def clone(
self: OperationType,
value_mapper: Optional[Dict[SSAValue, SSAValue]] = None,
block_mapper: Optional[Dict[Block,
Block]] = None) -> OperationType:
"""Clone an operation with all its regions and operations in them."""
if value_mapper is None:
value_mapper = {}
if block_mapper is None:
block_mapper = {}
op = self.clone_without_regions(value_mapper, block_mapper)
for idx, region in enumerate(self.regions):
region.clone_into(op.regions[idx], 0, value_mapper, block_mapper)
return op

def erase(self, safe_erase=True, drop_references=True) -> None:
"""
Expand Down Expand Up @@ -778,6 +809,34 @@ def erase_block(self, block: Union[int, Block], safe_erase=True) -> None:
block = self.detach_block(block)
block.erase(safe_erase=safe_erase)

def clone_into(self,
dest: Region,
insert_index: Optional[int] = None,
value_mapper: Optional[Dict[SSAValue, SSAValue]] = None,
block_mapper: Optional[Dict[Block, Block]] = None):
"""
Clone all block of this region into `dest` to position `insert_index`
"""
assert (dest is not None)
assert (dest != self)
if insert_index is None:
insert_index = len(dest.blocks)
if value_mapper is None:
value_mapper = {}
if block_mapper is None:
block_mapper = {}

for block in self.blocks:
new_block = Block()
block_mapper[block] = new_block
for idx, block_arg in enumerate(block.args):
new_block.insert_arg(block_arg.typ, idx)
value_mapper[block_arg] = new_block.args[idx]
for op in block.ops:
new_block.add_op(op.clone(value_mapper, block_mapper))
dest.insert_block(new_block, insert_index)
insert_index += 1

def walk(self, fun: Callable[[Operation], None]) -> None:
"""Call a function on all operations contained in the region."""
for block in self.blocks:
Expand Down

0 comments on commit 8f0fd88

Please sign in to comment.