From 8f0fd8878799fece63ac97e4f1df115f3b70ec6e Mon Sep 17 00:00:00 2001 From: martin-luecke Date: Thu, 14 Apr 2022 21:07:05 +0200 Subject: [PATCH] add support for cloning operations with regions (#99) --- src/xdsl/ir.py | 75 ++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 67 insertions(+), 8 deletions(-) diff --git a/src/xdsl/ir.py b/src/xdsl/ir.py index 0dca9cf2cc..dc76620e9d 100644 --- a/src/xdsl/ir.py +++ b/src/xdsl/ir.py @@ -368,18 +368,49 @@ def verify(self, verify_nested_ops: bool = True) -> None: def verify_(self) -> None: pass - def clone_without_regions(self: OperationType) -> OperationType: + def clone_without_regions( + self: OperationType, + value_mapper: Optional[Dict[SSAValue, SSAValue]] = None, + block_mapper: Optional[Dict[Block, + Block]] = None) -> OperationType: """Clone an operation, with empty regions instead.""" - operands = self.operands + if value_mapper is None: + value_mapper = {} + if block_mapper is None: + block_mapper = {} + operands = [ + (value_mapper[operand] if operand in value_mapper else operand) + for operand in self.operands + ] result_types = [res.typ for res in self.results] attributes = self.attributes.copy() - successors = self.successors.copy() + successors = [(block_mapper[successor] + if successor in block_mapper else successor) + for successor in self.successors] regions = [Region() for _ in self.regions] - return self.create(operands=operands, - result_types=result_types, - attributes=attributes, - successors=successors, - regions=regions) + cloned_op = self.create(operands=operands, + result_types=result_types, + attributes=attributes, + successors=successors, + regions=regions) + for idx, result in enumerate(cloned_op.results): + value_mapper[self.results[idx]] = result + return cloned_op + + def clone( + self: OperationType, + value_mapper: Optional[Dict[SSAValue, SSAValue]] = None, + block_mapper: Optional[Dict[Block, + Block]] = None) -> OperationType: + """Clone an operation with all its regions and operations in them.""" + if value_mapper is None: + value_mapper = {} + if block_mapper is None: + block_mapper = {} + op = self.clone_without_regions(value_mapper, block_mapper) + for idx, region in enumerate(self.regions): + region.clone_into(op.regions[idx], 0, value_mapper, block_mapper) + return op def erase(self, safe_erase=True, drop_references=True) -> None: """ @@ -778,6 +809,34 @@ def erase_block(self, block: Union[int, Block], safe_erase=True) -> None: block = self.detach_block(block) block.erase(safe_erase=safe_erase) + def clone_into(self, + dest: Region, + insert_index: Optional[int] = None, + value_mapper: Optional[Dict[SSAValue, SSAValue]] = None, + block_mapper: Optional[Dict[Block, Block]] = None): + """ + Clone all block of this region into `dest` to position `insert_index` + """ + assert (dest is not None) + assert (dest != self) + if insert_index is None: + insert_index = len(dest.blocks) + if value_mapper is None: + value_mapper = {} + if block_mapper is None: + block_mapper = {} + + for block in self.blocks: + new_block = Block() + block_mapper[block] = new_block + for idx, block_arg in enumerate(block.args): + new_block.insert_arg(block_arg.typ, idx) + value_mapper[block_arg] = new_block.args[idx] + for op in block.ops: + new_block.add_op(op.clone(value_mapper, block_mapper)) + dest.insert_block(new_block, insert_index) + insert_index += 1 + def walk(self, fun: Callable[[Operation], None]) -> None: """Call a function on all operations contained in the region.""" for block in self.blocks: