From 2585bbf5b82c3e4db7359ac0cd745edcbbc006e9 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 13:11:48 +0100 Subject: [PATCH 1/5] Simple core helpers for type access. --- tests/dialects/test_hw.py | 4 ++-- tests/dialects/test_linalg.py | 2 +- tests/dialects/test_pdl.py | 2 +- tests/irdl/test_operation_builder.py | 20 +++++++++---------- .../pattern_rewriter/test_pattern_rewriter.py | 2 +- tests/test_ir.py | 4 ++-- .../lowering/convert_func_to_riscv_func.py | 2 +- .../lowering/convert_riscv_scf_to_riscv_cf.py | 8 +++----- .../convert_snitch_stream_to_snitch.py | 2 +- xdsl/dialects/affine.py | 8 ++++---- xdsl/dialects/csl/csl.py | 6 ++---- xdsl/dialects/fsm.py | 6 +----- xdsl/dialects/func.py | 6 +++--- xdsl/dialects/gpu.py | 18 ++++++++--------- xdsl/dialects/hw.py | 2 +- xdsl/dialects/pdl.py | 2 +- xdsl/ir/core.py | 14 ++++++++++++- xdsl/irdl/operations.py | 2 +- xdsl/printer.py | 4 +--- .../immutable_ir/immutable_ir.py | 2 +- .../common_subexpression_elimination.py | 8 +++----- xdsl/transforms/stencil_bufferize.py | 14 ++++++------- xdsl/transforms/stencil_to_csl_stencil.py | 4 ++-- xdsl/transforms/stencil_unroll.py | 4 ++-- 24 files changed, 71 insertions(+), 75 deletions(-) diff --git a/tests/dialects/test_hw.py b/tests/dialects/test_hw.py index a4b2e5ef62..7ad8e70db5 100644 --- a/tests/dialects/test_hw.py +++ b/tests/dialects/test_hw.py @@ -364,8 +364,8 @@ def test_instance_builder(): assert inst_op.arg_names.data == (StringAttr("foo"), StringAttr("bar")) assert inst_op.result_names.data == (StringAttr("baz"), StringAttr("qux")) - assert [op.type for op in inst_op.operands] == [i32, i64] - assert [res.type for res in inst_op.results] == [i32, i64] + assert inst_op.operands_types == (i32, i64) + assert inst_op.results_types == (i32, i64) def test_hwmoduleop_hwmodulelike(): diff --git a/tests/dialects/test_linalg.py b/tests/dialects/test_linalg.py index 2d632fa8ac..12f5081ff3 100644 --- a/tests/dialects/test_linalg.py +++ b/tests/dialects/test_linalg.py @@ -43,7 +43,7 @@ def test_matmul_on_memrefs(): matmul_op = linalg.MatmulOp(inputs=(a.memref, b.memref), outputs=(c.memref,)) - assert tuple(result.type for result in matmul_op.results) == () + assert matmul_op.results_types == () def test_loop_range_methods(): diff --git a/tests/dialects/test_pdl.py b/tests/dialects/test_pdl.py index f334f1ac46..937ee58f5b 100644 --- a/tests/dialects/test_pdl.py +++ b/tests/dialects/test_pdl.py @@ -39,7 +39,7 @@ def test_build_anr(): assert anr.constraint_name == StringAttr("anr") assert anr.args == (type_val,) assert len(anr.results) == 1 - assert [r.type for r in anr.results] == [attribute_type] + assert anr.results_types == (attribute_type,) def test_build_rewrite(): diff --git a/tests/irdl/test_operation_builder.py b/tests/irdl/test_operation_builder.py index 12501a09d1..5a5fb2745a 100644 --- a/tests/irdl/test_operation_builder.py +++ b/tests/irdl/test_operation_builder.py @@ -57,7 +57,7 @@ class ResultOp(IRDLOperation): def test_result_builder(): op = ResultOp.build(result_types=[StringAttr("")]) op.verify() - assert [res.type for res in op.results] == [StringAttr("")] + assert op.results_types == (StringAttr(""),) def test_result_builder_exception(): @@ -79,7 +79,7 @@ def test_opt_result_builder(): op1.verify() op2.verify() op3.verify() - assert [res.type for res in op1.results] == [StringAttr("")] + assert op1.results_types == (StringAttr(""),) assert len(op2.results) == 0 assert len(op3.results) == 0 @@ -99,10 +99,10 @@ class VarResultOp(IRDLOperation): def test_var_result_builder(): op = VarResultOp.build(result_types=[[StringAttr("0"), StringAttr("1")]]) op.verify() - assert [res.type for res in op.results] == [ + assert op.results_types == ( StringAttr("0"), StringAttr("1"), - ] + ) @irdl_op_definition @@ -122,12 +122,12 @@ def test_two_var_result_builder(): ] ) op.verify() - assert [res.type for res in op.results] == [ + assert op.results_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), StringAttr("3"), - ] + ) assert op.attributes[ AttrSizedResultSegments.attribute_name @@ -142,12 +142,12 @@ def test_two_var_result_builder2(): ] ) op.verify() - assert [res.type for res in op.results] == [ + assert op.results_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), StringAttr("3"), - ] + ) assert op.attributes[ AttrSizedResultSegments.attribute_name ] == DenseArrayBase.from_list(i32, [1, 3]) @@ -172,13 +172,13 @@ def test_var_mixed_builder(): ] ) op.verify() - assert [res.type for res in op.results] == [ + assert op.results_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), StringAttr("3"), StringAttr("4"), - ] + ) assert op.attributes[ AttrSizedResultSegments.attribute_name diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index a32ac3fa88..e9cc19855b 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -1075,7 +1075,7 @@ def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter): old_op = next(ops_iter) assert isinstance(old_op, test.TestOp) new_region = rewriter.move_region_contents_to_new_regions(old_op.regions[0]) - res_types = [r.type for r in old_op.results] + res_types = old_op.results_types new_op = test.TestOp.create(result_types=res_types, regions=[new_region]) rewriter.insert_op(new_op, InsertPoint.after(old_op)) diff --git a/tests/test_ir.py b/tests/test_ir.py index 725bc97700..47ef25d985 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -488,8 +488,8 @@ def test_split_block_args(): new_block = old_block.split_before(op, arg_types=(i32, i64)) - arg_types = [a.type for a in new_block.args] - assert arg_types == [i32, i64] + arg_types = new_block.args_types + assert arg_types == (i32, i64) def test_region_clone_into_circular_blocks(): diff --git a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py index a6f2226c00..f2a54d5e50 100644 --- a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py +++ b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py @@ -31,7 +31,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): if (first_block := op.body.blocks.first) is not None: cast_block_args_from_a_regs(first_block, rewriter) - input_types = [arg.type for arg in first_block.args] + input_types = first_block.args_types else: input_types = tuple(a_regs_for_types(op.function_type.inputs.data)) result_types = list(a_regs_for_types(op.function_type.outputs.data)) diff --git a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py index 68e1ab3819..1ea29cbe87 100644 --- a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py +++ b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py @@ -78,15 +78,13 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /): init_block = op.parent_block() assert init_block is not None - body_args = op.body.blocks[0].args + body = op.body.blocks[0] # TODO: add method to rewriter - end_block = init_block.split_before( - op, arg_types=(arg.type for arg in body_args) - ) + end_block = init_block.split_before(op, arg_types=body.args_types) # The first argument of the loop body block is the loop counter by SCF invariant. - loop_var_reg = body_args[0].type + loop_var_reg = body.args[0].type assert isinstance(loop_var_reg, riscv.IntRegisterType) # Use the first block of the loop body as the condition block since it is the diff --git a/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py b/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py index 8fcfc5e7a2..f46b28e1cb 100644 --- a/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py +++ b/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py @@ -178,7 +178,7 @@ def match_and_rewrite( block = op.body.block rewriter.insert_op_before_matched_op( - enable_op := snitch.SsrEnable(tuple(arg.type for arg in block.args)) + enable_op := snitch.SsrEnable(block.args_types) ) for val, arg in zip(enable_op.streams, block.args): diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index 52ad520c9b..ebbc74372b 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -142,14 +142,14 @@ def verify_(self) -> None: raise VerifyException( "Expected as many upper bound operands as upper bound dimensions and symbols." ) - iter_types = [op.type for op in self.inits] - if iter_types != [res.type for res in self.results]: + iter_types = tuple(op.type for op in self.inits) + if iter_types != self.results_types: raise VerifyException( "Expected all operands and result pairs to have matching types" ) entry_block: Block = self.body.blocks[0] - block_arg_types = [IndexType()] + iter_types - arg_types = [arg.type for arg in entry_block.args] + block_arg_types = (IndexType(), *iter_types) + arg_types = entry_block.args_types if block_arg_types != arg_types: raise VerifyException( "Expected BlockArguments to have the same types as the operands" diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index e0b1de741c..22e8851151 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -151,7 +151,7 @@ def _verify(self): return entry_block: Block = self.body.blocks[0] - block_arg_types = [arg.type for arg in entry_block.args] + block_arg_types = entry_block.args_types if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " @@ -718,9 +718,7 @@ def verify_(self) -> None: func_op = self.parent_op() assert isinstance(func_op, FuncOp) or isinstance(func_op, TaskOp) - if tuple(func_op.function_type.outputs) != tuple( - val.type for val in self.operands - ): + if tuple(func_op.function_type.outputs.data) != self.operands_types: raise VerifyException( "Expected arguments to have the same types as the function output types" ) diff --git a/xdsl/dialects/fsm.py b/xdsl/dialects/fsm.py index 049f1cd206..499857bbba 100644 --- a/xdsl/dialects/fsm.py +++ b/xdsl/dialects/fsm.py @@ -205,11 +205,7 @@ def verify_(self): raise VerifyException("Transition regions should not output any value") while (parent := parent.parent_op()) is not None: if isinstance(parent, MachineOp): - if not ( - [operand.type for operand in self.operands] - == [result for result in parent.function_type.outputs] - and len(self.operands) == len(parent.function_type.outputs) - ): + if not (self.operands_types == parent.function_type.outputs.data): raise VerifyException( "OutputOp output type must be consistent with the machine " + str(parent.sym_name) diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 1bcfc8bfea..4825a78019 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -114,7 +114,7 @@ def verify_(self) -> None: # TODO: how to verify that there is a terminator? entry_block = self.body.blocks.first assert entry_block is not None - block_arg_types = [arg.type for arg in entry_block.args] + block_arg_types = entry_block.args_types if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " @@ -222,10 +222,10 @@ def update_function_type(self): not self.is_declaration ), "update_function_type does not work with function declarations!" return_op = self.get_return_op() - return_type: tuple[Attribute, ...] = self.function_type.outputs.data + return_type = self.function_type.outputs.data if return_op is not None: - return_type = tuple(arg.type for arg in return_op.operands) + return_type = return_op.operands_types self.properties["function_type"] = FunctionType.from_lists( [arg.type for arg in self.args], diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index 67d1519540..f66b88d9a0 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -245,7 +245,7 @@ def verify_(self) -> None: f"{self.operand.type}. They must be the same type for gpu.all_reduce" ) - non_empty_body = any(b.ops for b in self.body.blocks) + non_empty_body = len(self.body.blocks) > 0 op_attr = self.op is not None if non_empty_body == op_attr: if op_attr: @@ -258,9 +258,8 @@ def verify_(self) -> None: "gpu.all_reduce need either a non empty body or an op attribute." ) if non_empty_body: - region_args = self.body.blocks[0].args - args_types = [r.type for r in region_args] - if args_types != [self.result.type, self.operand.type]: + args_types = self.body.blocks[0].args_types + if args_types != (self.result.type, self.operand.type): raise VerifyException( f"Expected {[str(t) for t in [self.result.type, self.operand.type]]}, " f"got {[str(t) for t in args_types]}. A gpu.all_reduce's body must " @@ -431,7 +430,7 @@ def __init__( def verify_(self): entry_block: Block = self.body.blocks[0] function_inputs = self.function_type.inputs.data - block_arg_types = tuple(a.type for a in entry_block.args) + block_arg_types = entry_block.args_types if function_inputs != block_arg_types: raise VerifyException( "Expected first entry block arguments to have the same types as the " @@ -558,9 +557,8 @@ def __init__( def verify_(self) -> None: if not any(b.ops for b in self.body.blocks): raise VerifyException("gpu.launch requires a non-empty body.") - body_args = self.body.blocks[0].args - args_type = [a.type for a in body_args] - if args_type != [IndexType()] * 12: + args_type = self.body.blocks[0].args_types + if args_type != (IndexType(),) * 12: raise VerifyException( f"Expected [12 x {str(IndexType())}], got {[str(t) for t in args_type]}. " "gpu.launch's body arguments are 12 index arguments, with 3 block " @@ -759,8 +757,8 @@ def __init__(self, operands: Sequence[SSAValue | Operation]): def verify_(self) -> None: op = self.parent_op() if op is not None: - yield_type = [o.type for o in self.values] - result_type = [r.type for r in op.results] + yield_type = tuple(o.type for o in self.values) + result_type = op.results_types if yield_type != result_type: raise VerifyException( f"Expected {[str(t) for t in result_type]}, got {[str(t) for t in yield_type]}. The gpu.yield values " diff --git a/xdsl/dialects/hw.py b/xdsl/dialects/hw.py index 0bdb0d243b..b0f1a533a5 100644 --- a/xdsl/dialects/hw.py +++ b/xdsl/dialects/hw.py @@ -1230,7 +1230,7 @@ def print_output_port(name: str, port_type: Attribute): printer.print_list( zip( (name.data for name in self.result_names), - (result.type for result in self.results), + self.results_types, ), lambda x: print_output_port(*x), ) diff --git a/xdsl/dialects/pdl.py b/xdsl/dialects/pdl.py index 2b341649d6..02c5f18ef1 100644 --- a/xdsl/dialects/pdl.py +++ b/xdsl/dialects/pdl.py @@ -260,7 +260,7 @@ def print(self, printer: Printer) -> None: printer.print(")") if len(self.results) != 0: printer.print(" : ") - printer.print_list([res.type for res in self.results], printer.print) + printer.print_list(self.results_types, printer.print) @irdl_op_definition diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index e0607efe90..a9aff23181 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -614,6 +614,14 @@ class Operation(IRNode): def parent_node(self) -> IRNode | None: return self.parent + @property + def results_types(self) -> Sequence[Attribute]: + return tuple(r.type for r in self.results) + + @property + def operands_types(self) -> Sequence[Attribute]: + return tuple(operand.type for operand in self.operands) + def parent_op(self) -> Operation | None: if p := self.parent_region(): return p.parent @@ -896,7 +904,7 @@ def clone_without_regions( (value_mapper[operand] if operand in value_mapper else operand) for operand in self.operands ] - result_types = [res.type for res in self.results] + result_types = self.results_types attributes = self.attributes.copy() properties = self.properties.copy() successors = [ @@ -1214,6 +1222,10 @@ def __init__( self.add_ops(ops) + @property + def args_types(self) -> Sequence[Attribute]: + return tuple(arg.type for arg in self._args) + @property def parent_node(self) -> IRNode | None: return self.parent diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index 54ad8de93b..f00c2fd817 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -1390,7 +1390,7 @@ def irdl_op_verify_regions( f"{len(region.blocks)} blocks" ) if (first_block := region.blocks.first) is not None: - entry_args_types = tuple(a.type for a in first_block.args) + entry_args_types = first_block.args_types try: region_def.entry_args.verify(entry_args_types, constraint_context) except Exception as e: diff --git a/xdsl/printer.py b/xdsl/printer.py index f68a05e14f..2bd636458c 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -938,9 +938,7 @@ def print_function_type( self.print_string(")") def print_operation_type(self, op: Operation) -> None: - self.print_function_type( - (o.type for o in op.operands), (r.type for r in op.results) - ) + self.print_function_type(op.operands_types, op.results_types) if self.print_debuginfo: self.print_string(" loc(unknown)") diff --git a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py index 4203983bdb..98dde304bf 100644 --- a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py +++ b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py @@ -570,7 +570,7 @@ def from_mutable( op.name, op_type, operands, - [result.type for result in op.results], + op.results_types, properties, attributes, successors, diff --git a/xdsl/transforms/common_subexpression_elimination.py b/xdsl/transforms/common_subexpression_elimination.py index 9da5fed014..c361ef59e8 100644 --- a/xdsl/transforms/common_subexpression_elimination.py +++ b/xdsl/transforms/common_subexpression_elimination.py @@ -42,7 +42,7 @@ def __hash__(self): self.name, sum(hash(i) for i in self.op.attributes.items()), sum(hash(i) for i in self.op.properties.items()), - hash(tuple(i.type for i in self.op.results)), + hash(self.op.results_types), hash(self.op.operands), ) ) @@ -55,12 +55,10 @@ def __eq__(self, other: object): and self.op.attributes == other.op.attributes and self.op.properties == other.op.properties and self.op.operands == other.op.operands - and len(self.op.results) == len(other.op.results) - and all(r.type == o.type for r, o in zip(self.op.results, other.op.results)) - and len(self.op.regions) == len(other.op.regions) + and self.op.results_types == other.op.results_types and all( s.is_structurally_equivalent(o) - for s, o in zip(self.op.regions, other.op.regions) + for s, o in zip(self.op.regions, other.op.regions, strict=True) ) ) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 34f85dfc77..16e62ea7fd 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -1,6 +1,5 @@ from collections.abc import Generator from dataclasses import dataclass -from itertools import chain from typing import Any, TypeVar, cast from xdsl.context import MLContext @@ -286,7 +285,7 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): new_load = LoadOp.create( operands=[op.field], - result_types=[r.type for r in load.results], + result_types=load.results_types, attributes=load.attributes.copy(), properties=load.properties.copy(), ) @@ -306,14 +305,14 @@ class UpdateApplyArgs(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): - new_arg_types = [o.type for o in op.args] - if new_arg_types == [a.type for a in op.region.block.args]: + new_arg_types = tuple(o.type for o in op.args) + if new_arg_types == op.region.block.args_types: return new_block = Block(arg_types=new_arg_types) new_apply = ApplyOp.create( operands=op.operands, - result_types=[r.type for r in op.results], + result_types=op.results_types, properties=op.properties.copy(), attributes=op.attributes.copy(), regions=[Region(new_block)], @@ -407,9 +406,8 @@ def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter): new_upper = op.upper new_lowerext = op.lowerext new_upperext = op.upperext - new_results_types = [ - r.type for r in chain(op.results[:i], op.results[i + 1 :]) - ] + new_results_types = list(op.results_types) + new_results_types.pop(i) bounds = cast(StencilBoundsAttr, cast(TempType[Attribute], r.type).bounds) newub = list(bounds.ub) diff --git a/xdsl/transforms/stencil_to_csl_stencil.py b/xdsl/transforms/stencil_to_csl_stencil.py index b34d72a9e2..c84284980b 100644 --- a/xdsl/transforms/stencil_to_csl_stencil.py +++ b/xdsl/transforms/stencil_to_csl_stencil.py @@ -250,7 +250,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): ) # rebuild stencil.apply op - r_types = [r.type for r in apply_op.results] + r_types = apply_op.results_types assert isa(r_types, Sequence[stencil.TempType[Attribute]]) new_apply_op = stencil.ApplyOp.build( operands=[[*apply_op.args, prefetch_op.result], apply_op.dest], @@ -521,7 +521,7 @@ def get_prefetch_overhead(o: OpResult): chunk_reduce, post_process, ], - result_types=[r.type for r in op.results] or [[]], + result_types=[op.results_types], ) ) diff --git a/xdsl/transforms/stencil_unroll.py b/xdsl/transforms/stencil_unroll.py index a672483e9b..5418cc328b 100644 --- a/xdsl/transforms/stencil_unroll.py +++ b/xdsl/transforms/stencil_unroll.py @@ -78,8 +78,8 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): return # Enforced by verification - res_types = [r.type for r in op.results] - assert isa(res_types, list[TempType[Attribute]]) + res_types = op.results_types + assert isa(res_types, Sequence[TempType[Attribute]]) dim = res_types[0].get_num_dims() # If unroll factors list is shorter than the dim, fill with ones from the front From fb61fbb482b6ceb6ef4c00052fc7b461fc9d6e9f Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 13:34:02 +0100 Subject: [PATCH 2/5] Update xdsl/dialects/fsm.py Co-authored-by: Sasha Lopoukhine --- xdsl/dialects/fsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/fsm.py b/xdsl/dialects/fsm.py index 499857bbba..a428ff3920 100644 --- a/xdsl/dialects/fsm.py +++ b/xdsl/dialects/fsm.py @@ -205,7 +205,7 @@ def verify_(self): raise VerifyException("Transition regions should not output any value") while (parent := parent.parent_op()) is not None: if isinstance(parent, MachineOp): - if not (self.operands_types == parent.function_type.outputs.data): + if self.operands_types != parent.function_type.outputs.data: raise VerifyException( "OutputOp output type must be consistent with the machine " + str(parent.sym_name) From 481c893e22e5a1c6a4f14bd51b53480161d217c7 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 13:34:07 +0100 Subject: [PATCH 3/5] Update xdsl/dialects/gpu.py Co-authored-by: Sasha Lopoukhine --- xdsl/dialects/gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index f66b88d9a0..445fd849dc 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -245,7 +245,7 @@ def verify_(self) -> None: f"{self.operand.type}. They must be the same type for gpu.all_reduce" ) - non_empty_body = len(self.body.blocks) > 0 + non_empty_body = bool(self.body.blocks) op_attr = self.op is not None if non_empty_body == op_attr: if op_attr: From 01fb2285d777edde6dc51a09a78d2feb2c2a0dfa Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 13:35:37 +0100 Subject: [PATCH 4/5] Consistent naming. --- tests/dialects/test_hw.py | 4 ++-- tests/dialects/test_linalg.py | 2 +- tests/dialects/test_pdl.py | 2 +- tests/irdl/test_operation_builder.py | 12 ++++++------ tests/pattern_rewriter/test_pattern_rewriter.py | 2 +- tests/test_ir.py | 2 +- .../riscv/lowering/convert_func_to_riscv_func.py | 2 +- .../riscv/lowering/convert_riscv_scf_to_riscv_cf.py | 2 +- .../lowering/convert_snitch_stream_to_snitch.py | 2 +- xdsl/dialects/affine.py | 4 ++-- xdsl/dialects/csl/csl.py | 4 ++-- xdsl/dialects/fsm.py | 2 +- xdsl/dialects/func.py | 4 ++-- xdsl/dialects/gpu.py | 8 ++++---- xdsl/dialects/hw.py | 2 +- xdsl/dialects/pdl.py | 2 +- xdsl/ir/core.py | 8 ++++---- xdsl/irdl/operations.py | 2 +- xdsl/printer.py | 2 +- .../immutable_ir/immutable_ir.py | 2 +- xdsl/transforms/common_subexpression_elimination.py | 4 ++-- xdsl/transforms/stencil_bufferize.py | 8 ++++---- xdsl/transforms/stencil_to_csl_stencil.py | 4 ++-- xdsl/transforms/stencil_unroll.py | 2 +- 24 files changed, 44 insertions(+), 44 deletions(-) diff --git a/tests/dialects/test_hw.py b/tests/dialects/test_hw.py index 7ad8e70db5..37cfa56501 100644 --- a/tests/dialects/test_hw.py +++ b/tests/dialects/test_hw.py @@ -364,8 +364,8 @@ def test_instance_builder(): assert inst_op.arg_names.data == (StringAttr("foo"), StringAttr("bar")) assert inst_op.result_names.data == (StringAttr("baz"), StringAttr("qux")) - assert inst_op.operands_types == (i32, i64) - assert inst_op.results_types == (i32, i64) + assert inst_op.operand_types == (i32, i64) + assert inst_op.result_types == (i32, i64) def test_hwmoduleop_hwmodulelike(): diff --git a/tests/dialects/test_linalg.py b/tests/dialects/test_linalg.py index 12f5081ff3..adf486c2ce 100644 --- a/tests/dialects/test_linalg.py +++ b/tests/dialects/test_linalg.py @@ -43,7 +43,7 @@ def test_matmul_on_memrefs(): matmul_op = linalg.MatmulOp(inputs=(a.memref, b.memref), outputs=(c.memref,)) - assert matmul_op.results_types == () + assert matmul_op.result_types == () def test_loop_range_methods(): diff --git a/tests/dialects/test_pdl.py b/tests/dialects/test_pdl.py index 937ee58f5b..1cd1867356 100644 --- a/tests/dialects/test_pdl.py +++ b/tests/dialects/test_pdl.py @@ -39,7 +39,7 @@ def test_build_anr(): assert anr.constraint_name == StringAttr("anr") assert anr.args == (type_val,) assert len(anr.results) == 1 - assert anr.results_types == (attribute_type,) + assert anr.result_types == (attribute_type,) def test_build_rewrite(): diff --git a/tests/irdl/test_operation_builder.py b/tests/irdl/test_operation_builder.py index 5a5fb2745a..f267ee3071 100644 --- a/tests/irdl/test_operation_builder.py +++ b/tests/irdl/test_operation_builder.py @@ -57,7 +57,7 @@ class ResultOp(IRDLOperation): def test_result_builder(): op = ResultOp.build(result_types=[StringAttr("")]) op.verify() - assert op.results_types == (StringAttr(""),) + assert op.result_types == (StringAttr(""),) def test_result_builder_exception(): @@ -79,7 +79,7 @@ def test_opt_result_builder(): op1.verify() op2.verify() op3.verify() - assert op1.results_types == (StringAttr(""),) + assert op1.result_types == (StringAttr(""),) assert len(op2.results) == 0 assert len(op3.results) == 0 @@ -99,7 +99,7 @@ class VarResultOp(IRDLOperation): def test_var_result_builder(): op = VarResultOp.build(result_types=[[StringAttr("0"), StringAttr("1")]]) op.verify() - assert op.results_types == ( + assert op.result_types == ( StringAttr("0"), StringAttr("1"), ) @@ -122,7 +122,7 @@ def test_two_var_result_builder(): ] ) op.verify() - assert op.results_types == ( + assert op.result_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), @@ -142,7 +142,7 @@ def test_two_var_result_builder2(): ] ) op.verify() - assert op.results_types == ( + assert op.result_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), @@ -172,7 +172,7 @@ def test_var_mixed_builder(): ] ) op.verify() - assert op.results_types == ( + assert op.result_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index e9cc19855b..0f673b9066 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -1075,7 +1075,7 @@ def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter): old_op = next(ops_iter) assert isinstance(old_op, test.TestOp) new_region = rewriter.move_region_contents_to_new_regions(old_op.regions[0]) - res_types = old_op.results_types + res_types = old_op.result_types new_op = test.TestOp.create(result_types=res_types, regions=[new_region]) rewriter.insert_op(new_op, InsertPoint.after(old_op)) diff --git a/tests/test_ir.py b/tests/test_ir.py index 47ef25d985..8ad55cde35 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -488,7 +488,7 @@ def test_split_block_args(): new_block = old_block.split_before(op, arg_types=(i32, i64)) - arg_types = new_block.args_types + arg_types = new_block.arg_types assert arg_types == (i32, i64) diff --git a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py index f2a54d5e50..c479cb4d99 100644 --- a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py +++ b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py @@ -31,7 +31,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): if (first_block := op.body.blocks.first) is not None: cast_block_args_from_a_regs(first_block, rewriter) - input_types = first_block.args_types + input_types = first_block.arg_types else: input_types = tuple(a_regs_for_types(op.function_type.inputs.data)) result_types = list(a_regs_for_types(op.function_type.outputs.data)) diff --git a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py index 1ea29cbe87..24abc9f874 100644 --- a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py +++ b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py @@ -81,7 +81,7 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /): body = op.body.blocks[0] # TODO: add method to rewriter - end_block = init_block.split_before(op, arg_types=body.args_types) + end_block = init_block.split_before(op, arg_types=body.arg_types) # The first argument of the loop body block is the loop counter by SCF invariant. loop_var_reg = body.args[0].type diff --git a/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py b/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py index f46b28e1cb..b049934222 100644 --- a/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py +++ b/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py @@ -178,7 +178,7 @@ def match_and_rewrite( block = op.body.block rewriter.insert_op_before_matched_op( - enable_op := snitch.SsrEnable(block.args_types) + enable_op := snitch.SsrEnable(block.arg_types) ) for val, arg in zip(enable_op.streams, block.args): diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index ebbc74372b..23dc742d69 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -143,13 +143,13 @@ def verify_(self) -> None: "Expected as many upper bound operands as upper bound dimensions and symbols." ) iter_types = tuple(op.type for op in self.inits) - if iter_types != self.results_types: + if iter_types != self.result_types: raise VerifyException( "Expected all operands and result pairs to have matching types" ) entry_block: Block = self.body.blocks[0] block_arg_types = (IndexType(), *iter_types) - arg_types = entry_block.args_types + arg_types = entry_block.arg_types if block_arg_types != arg_types: raise VerifyException( "Expected BlockArguments to have the same types as the operands" diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index 22e8851151..1cce7ee47d 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -151,7 +151,7 @@ def _verify(self): return entry_block: Block = self.body.blocks[0] - block_arg_types = entry_block.args_types + block_arg_types = entry_block.arg_types if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " @@ -718,7 +718,7 @@ def verify_(self) -> None: func_op = self.parent_op() assert isinstance(func_op, FuncOp) or isinstance(func_op, TaskOp) - if tuple(func_op.function_type.outputs.data) != self.operands_types: + if tuple(func_op.function_type.outputs.data) != self.operand_types: raise VerifyException( "Expected arguments to have the same types as the function output types" ) diff --git a/xdsl/dialects/fsm.py b/xdsl/dialects/fsm.py index a428ff3920..d4736fa969 100644 --- a/xdsl/dialects/fsm.py +++ b/xdsl/dialects/fsm.py @@ -205,7 +205,7 @@ def verify_(self): raise VerifyException("Transition regions should not output any value") while (parent := parent.parent_op()) is not None: if isinstance(parent, MachineOp): - if self.operands_types != parent.function_type.outputs.data: + if self.operand_types != parent.function_type.outputs.data: raise VerifyException( "OutputOp output type must be consistent with the machine " + str(parent.sym_name) diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 4825a78019..b42daba236 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -114,7 +114,7 @@ def verify_(self) -> None: # TODO: how to verify that there is a terminator? entry_block = self.body.blocks.first assert entry_block is not None - block_arg_types = entry_block.args_types + block_arg_types = entry_block.arg_types if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " @@ -225,7 +225,7 @@ def update_function_type(self): return_type = self.function_type.outputs.data if return_op is not None: - return_type = return_op.operands_types + return_type = return_op.operand_types self.properties["function_type"] = FunctionType.from_lists( [arg.type for arg in self.args], diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index 445fd849dc..7ac61ffc32 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -258,7 +258,7 @@ def verify_(self) -> None: "gpu.all_reduce need either a non empty body or an op attribute." ) if non_empty_body: - args_types = self.body.blocks[0].args_types + args_types = self.body.blocks[0].arg_types if args_types != (self.result.type, self.operand.type): raise VerifyException( f"Expected {[str(t) for t in [self.result.type, self.operand.type]]}, " @@ -430,7 +430,7 @@ def __init__( def verify_(self): entry_block: Block = self.body.blocks[0] function_inputs = self.function_type.inputs.data - block_arg_types = entry_block.args_types + block_arg_types = entry_block.arg_types if function_inputs != block_arg_types: raise VerifyException( "Expected first entry block arguments to have the same types as the " @@ -557,7 +557,7 @@ def __init__( def verify_(self) -> None: if not any(b.ops for b in self.body.blocks): raise VerifyException("gpu.launch requires a non-empty body.") - args_type = self.body.blocks[0].args_types + args_type = self.body.blocks[0].arg_types if args_type != (IndexType(),) * 12: raise VerifyException( f"Expected [12 x {str(IndexType())}], got {[str(t) for t in args_type]}. " @@ -758,7 +758,7 @@ def verify_(self) -> None: op = self.parent_op() if op is not None: yield_type = tuple(o.type for o in self.values) - result_type = op.results_types + result_type = op.result_types if yield_type != result_type: raise VerifyException( f"Expected {[str(t) for t in result_type]}, got {[str(t) for t in yield_type]}. The gpu.yield values " diff --git a/xdsl/dialects/hw.py b/xdsl/dialects/hw.py index b0f1a533a5..8c433fdfcc 100644 --- a/xdsl/dialects/hw.py +++ b/xdsl/dialects/hw.py @@ -1230,7 +1230,7 @@ def print_output_port(name: str, port_type: Attribute): printer.print_list( zip( (name.data for name in self.result_names), - self.results_types, + self.result_types, ), lambda x: print_output_port(*x), ) diff --git a/xdsl/dialects/pdl.py b/xdsl/dialects/pdl.py index 02c5f18ef1..64b3d93f09 100644 --- a/xdsl/dialects/pdl.py +++ b/xdsl/dialects/pdl.py @@ -260,7 +260,7 @@ def print(self, printer: Printer) -> None: printer.print(")") if len(self.results) != 0: printer.print(" : ") - printer.print_list(self.results_types, printer.print) + printer.print_list(self.result_types, printer.print) @irdl_op_definition diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index a9aff23181..08ec48a47f 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -615,11 +615,11 @@ def parent_node(self) -> IRNode | None: return self.parent @property - def results_types(self) -> Sequence[Attribute]: + def result_types(self) -> Sequence[Attribute]: return tuple(r.type for r in self.results) @property - def operands_types(self) -> Sequence[Attribute]: + def operand_types(self) -> Sequence[Attribute]: return tuple(operand.type for operand in self.operands) def parent_op(self) -> Operation | None: @@ -904,7 +904,7 @@ def clone_without_regions( (value_mapper[operand] if operand in value_mapper else operand) for operand in self.operands ] - result_types = self.results_types + result_types = self.result_types attributes = self.attributes.copy() properties = self.properties.copy() successors = [ @@ -1223,7 +1223,7 @@ def __init__( self.add_ops(ops) @property - def args_types(self) -> Sequence[Attribute]: + def arg_types(self) -> Sequence[Attribute]: return tuple(arg.type for arg in self._args) @property diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index f00c2fd817..5526280494 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -1390,7 +1390,7 @@ def irdl_op_verify_regions( f"{len(region.blocks)} blocks" ) if (first_block := region.blocks.first) is not None: - entry_args_types = first_block.args_types + entry_args_types = first_block.arg_types try: region_def.entry_args.verify(entry_args_types, constraint_context) except Exception as e: diff --git a/xdsl/printer.py b/xdsl/printer.py index 2bd636458c..20aa8a5cee 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -938,7 +938,7 @@ def print_function_type( self.print_string(")") def print_operation_type(self, op: Operation) -> None: - self.print_function_type(op.operands_types, op.results_types) + self.print_function_type(op.operand_types, op.result_types) if self.print_debuginfo: self.print_string(" loc(unknown)") diff --git a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py index 98dde304bf..b6adb4a787 100644 --- a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py +++ b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py @@ -570,7 +570,7 @@ def from_mutable( op.name, op_type, operands, - op.results_types, + op.result_types, properties, attributes, successors, diff --git a/xdsl/transforms/common_subexpression_elimination.py b/xdsl/transforms/common_subexpression_elimination.py index c361ef59e8..fc05983560 100644 --- a/xdsl/transforms/common_subexpression_elimination.py +++ b/xdsl/transforms/common_subexpression_elimination.py @@ -42,7 +42,7 @@ def __hash__(self): self.name, sum(hash(i) for i in self.op.attributes.items()), sum(hash(i) for i in self.op.properties.items()), - hash(self.op.results_types), + hash(self.op.result_types), hash(self.op.operands), ) ) @@ -55,7 +55,7 @@ def __eq__(self, other: object): and self.op.attributes == other.op.attributes and self.op.properties == other.op.properties and self.op.operands == other.op.operands - and self.op.results_types == other.op.results_types + and self.op.result_types == other.op.result_types and all( s.is_structurally_equivalent(o) for s, o in zip(self.op.regions, other.op.regions, strict=True) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 16e62ea7fd..dc2ebdc165 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -285,7 +285,7 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): new_load = LoadOp.create( operands=[op.field], - result_types=load.results_types, + result_types=load.result_types, attributes=load.attributes.copy(), properties=load.properties.copy(), ) @@ -306,13 +306,13 @@ class UpdateApplyArgs(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): new_arg_types = tuple(o.type for o in op.args) - if new_arg_types == op.region.block.args_types: + if new_arg_types == op.region.block.arg_types: return new_block = Block(arg_types=new_arg_types) new_apply = ApplyOp.create( operands=op.operands, - result_types=op.results_types, + result_types=op.result_types, properties=op.properties.copy(), attributes=op.attributes.copy(), regions=[Region(new_block)], @@ -406,7 +406,7 @@ def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter): new_upper = op.upper new_lowerext = op.lowerext new_upperext = op.upperext - new_results_types = list(op.results_types) + new_results_types = list(op.result_types) new_results_types.pop(i) bounds = cast(StencilBoundsAttr, cast(TempType[Attribute], r.type).bounds) diff --git a/xdsl/transforms/stencil_to_csl_stencil.py b/xdsl/transforms/stencil_to_csl_stencil.py index c84284980b..2c123e7fda 100644 --- a/xdsl/transforms/stencil_to_csl_stencil.py +++ b/xdsl/transforms/stencil_to_csl_stencil.py @@ -250,7 +250,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): ) # rebuild stencil.apply op - r_types = apply_op.results_types + r_types = apply_op.result_types assert isa(r_types, Sequence[stencil.TempType[Attribute]]) new_apply_op = stencil.ApplyOp.build( operands=[[*apply_op.args, prefetch_op.result], apply_op.dest], @@ -521,7 +521,7 @@ def get_prefetch_overhead(o: OpResult): chunk_reduce, post_process, ], - result_types=[op.results_types], + result_types=[op.result_types], ) ) diff --git a/xdsl/transforms/stencil_unroll.py b/xdsl/transforms/stencil_unroll.py index 5418cc328b..f65d91bb90 100644 --- a/xdsl/transforms/stencil_unroll.py +++ b/xdsl/transforms/stencil_unroll.py @@ -78,7 +78,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): return # Enforced by verification - res_types = op.results_types + res_types = op.result_types assert isa(res_types, Sequence[TempType[Attribute]]) dim = res_types[0].get_num_dims() From c2f3f58ea581a6c6e20740e2417c95f0c60387e2 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 14:16:19 +0100 Subject: [PATCH 5/5] IRDL type helpers. --- tests/dialects/test_mpi_lowering.py | 2 +- .../lowering/convert_func_to_riscv_func.py | 6 ++--- .../lowering/convert_scf_to_riscv_scf.py | 6 ++--- xdsl/backend/riscv/register_allocation.py | 4 +-- xdsl/dialects/affine.py | 2 +- xdsl/dialects/comb.py | 4 +-- xdsl/dialects/csl/csl_stencil.py | 2 +- xdsl/dialects/csl/csl_wrapper.py | 4 +-- xdsl/dialects/fsm.py | 24 ++++-------------- xdsl/dialects/func.py | 2 +- xdsl/dialects/gpu.py | 2 +- xdsl/dialects/hw.py | 2 +- xdsl/dialects/linalg.py | 8 +++--- xdsl/dialects/memref.py | 2 +- xdsl/dialects/memref_stream.py | 8 +++--- xdsl/dialects/snitch_stream.py | 4 +-- xdsl/dialects/stencil.py | 8 +++--- xdsl/irdl/declarative_assembly_format.py | 8 ++---- xdsl/irdl/operations.py | 25 ++++++++++++++++--- xdsl/transforms/csl_stencil_bufferize.py | 4 +-- xdsl/transforms/lower_mpi.py | 9 ++++--- xdsl/transforms/reconcile_unrealized_casts.py | 4 +-- xdsl/transforms/stencil_bufferize.py | 10 ++------ 23 files changed, 69 insertions(+), 81 deletions(-) diff --git a/tests/dialects/test_mpi_lowering.py b/tests/dialects/test_mpi_lowering.py index c87ab8a0d1..c08398e7d4 100644 --- a/tests/dialects/test_mpi_lowering.py +++ b/tests/dialects/test_mpi_lowering.py @@ -101,7 +101,7 @@ def test_lower_mpi_wait_with_status(): assert call.callee.string_value() == "MPI_Wait" assert len(call.arguments) == 2 assert isinstance(call.arguments[1], OpResult) - assert isinstance(call.arguments[1].op, llvm.AllocaOp) + assert isinstance(call.arguments[1].owner, llvm.AllocaOp) def test_lower_mpi_comm_rank(): diff --git a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py index c479cb4d99..54a39bfdc3 100644 --- a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py +++ b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py @@ -73,7 +73,7 @@ def match_and_rewrite(self, op: func.Call, rewriter: PatternRewriter) -> None: raise ValueError("Cannot lower func.call with more than 2 results") cast_operand_ops, register_operands = cast_to_regs(op.arguments) - operand_types = tuple(arg.type for arg in op.arguments) + operand_types = op.arguments.types move_operand_ops, moved_operands = move_to_a_regs( register_operands, operand_types ) @@ -109,9 +109,7 @@ def match_and_rewrite(self, op: func.Return, rewriter: PatternRewriter): raise ValueError("Cannot lower func.return with more than 2 arguments") cast_ops, register_values = cast_to_regs(op.arguments) - move_ops, moved_values = move_to_a_regs( - register_values, tuple(arg.type for arg in op.arguments) - ) + move_ops, moved_values = move_to_a_regs(register_values, op.arguments.types) rewriter.insert_op_before_matched_op(cast_ops) rewriter.insert_op_before_matched_op(move_ops) diff --git a/xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py b/xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py index 4ce1f1aec9..2de053b4ad 100644 --- a/xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py +++ b/xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py @@ -22,14 +22,12 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter) -> None: lb, ub, step, *args = cast_operands_to_regs(rewriter) new_region = rewriter.move_region_contents_to_new_regions(op.body) cast_block_args_to_regs(new_region.block, rewriter) - mv_ops, values = move_to_unallocated_regs( - args, tuple(arg.type for arg in op.iter_args) - ) + mv_ops, values = move_to_unallocated_regs(args, op.iter_args.types) rewriter.insert_op_before_matched_op(mv_ops) cast_matched_op_results(rewriter) new_op = riscv_scf.ForOp(lb, ub, step, values, new_region) mv_res_ops, res_values = move_to_unallocated_regs( - new_op.results, tuple(arg.type for arg in op.iter_args) + new_op.results, op.iter_args.types ) rewriter.replace_matched_op((new_op, *mv_res_ops), res_values) diff --git a/xdsl/backend/riscv/register_allocation.py b/xdsl/backend/riscv/register_allocation.py index 75f34a7759..bb0151f297 100644 --- a/xdsl/backend/riscv/register_allocation.py +++ b/xdsl/backend/riscv/register_allocation.py @@ -270,7 +270,7 @@ def allocate_for_loop(self, loop: riscv_scf.ForOp) -> None: self.allocate(loop.step) # Reserve the loop carried variables for allocation within the body - regs = tuple(arg.type for arg in loop.iter_args) + regs = loop.iter_args.types assert all(isinstance(reg, IntRegisterType | FloatRegisterType) for reg in regs) regs = cast(tuple[IntRegisterType | FloatRegisterType], regs) with self.available_registers.reserve_registers(regs): @@ -310,7 +310,7 @@ def allocate_frep_loop(self, loop: riscv_snitch.FRepOperation) -> None: self.allocate(loop.max_rep) # Reserve the loop carried variables for allocation within the body - regs = tuple(arg.type for arg in loop.iter_args) + regs = loop.iter_args.types assert all(isinstance(reg, IntRegisterType | FloatRegisterType) for reg in regs) regs = cast(tuple[IntRegisterType | FloatRegisterType], regs) with self.available_registers.reserve_registers(regs): diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index 23dc742d69..8e75e798bc 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -142,7 +142,7 @@ def verify_(self) -> None: raise VerifyException( "Expected as many upper bound operands as upper bound dimensions and symbols." ) - iter_types = tuple(op.type for op in self.inits) + iter_types = self.inits.types if iter_types != self.result_types: raise VerifyException( "Expected all operands and result pairs to have matching types" diff --git a/xdsl/dialects/comb.py b/xdsl/dialects/comb.py index e4c6c811e3..3d8c58e3df 100644 --- a/xdsl/dialects/comb.py +++ b/xdsl/dialects/comb.py @@ -490,7 +490,7 @@ def from_int_values(inputs: Sequence[SSAValue]) -> "ConcatOp | None": return ConcatOp(inputs, IntegerType(sum_of_width)) def verify_(self) -> None: - sum_of_width = _get_sum_of_int_width([inp.type for inp in self.inputs]) + sum_of_width = _get_sum_of_int_width(self.inputs.types) assert sum_of_width is not None assert isinstance(self.result.type, IntegerType) if sum_of_width != self.result.type.width.data: @@ -519,7 +519,7 @@ def print(self, printer: Printer): printer.print(" ") printer.print_list(self.inputs, printer.print_ssa_value) printer.print(" : ") - printer.print_list([inp.type for inp in self.inputs], printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) @irdl_op_definition diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index 0f0d1224b3..b2952ed72d 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -251,7 +251,7 @@ def print_arg(arg: SSAValue): printer.print_list(self.dest, print_arg) else: printer.print(") -> (") - printer.print_list((r.type for r in self.res), printer.print_attribute) + printer.print_list(self.res.types, printer.print_attribute) printer.print(") ") printer.print("<") diff --git a/xdsl/dialects/csl/csl_wrapper.py b/xdsl/dialects/csl/csl_wrapper.py index 57ce331b7e..544ee254e1 100644 --- a/xdsl/dialects/csl/csl_wrapper.py +++ b/xdsl/dialects/csl/csl_wrapper.py @@ -256,7 +256,7 @@ def verify_(self): # verify that block args are of the right type for the provided params for arg, param in zip( - [a.type for a in self.layout_module.block.args[4:]], + self.layout_module.block.arg_types[4:], self.params, strict=True, ): @@ -275,7 +275,7 @@ def verify_(self): # verify that params and yielded arguments are typed correctly # these may be followed by input-output symbols which we cannot verify, therefore setting `strict=False` for got, (name, exp) in zip( - [a.type for a in self.program_module.block.args[2:]], + self.program_module.block.arg_types[2:], itertools.chain( ( (param.key.data, cast(Attribute, param.type)) diff --git a/xdsl/dialects/fsm.py b/xdsl/dialects/fsm.py index d4736fa969..edb83a0d60 100644 --- a/xdsl/dialects/fsm.py +++ b/xdsl/dialects/fsm.py @@ -438,21 +438,13 @@ def verify_(self): if m is None: raise VerifyException("Machine definition does not exist.") - if not ( - [operand.type for operand in self.inputs] - == [result for result in m.function_type.inputs] - and len(self.inputs) == len(m.function_type.inputs) - ): + if self.inputs.types != tuple(result for result in m.function_type.inputs): raise VerifyException( "TriggerOp input types must be consistent with the machine " + str(m.sym_name) ) - if not ( - [operand.type for operand in self.outputs] - == [result for result in m.function_type.outputs] - and len(self.outputs) == len(m.function_type.outputs) - ): + if self.outputs.types != tuple(result for result in m.function_type.outputs): raise VerifyException( "TriggerOp output types must be consistent with the machine " + str(m.sym_name) @@ -506,21 +498,15 @@ def __init__( def verify_(self): m = SymbolTable.lookup_symbol(self, self.machine) if isinstance(m, MachineOp): - if not ( - [operand.type for operand in self.inputs] - == [result for result in m.function_type.inputs] - and len(self.inputs) == len(m.function_type.inputs) - ): + if self.inputs.types != tuple(result for result in m.function_type.inputs): raise VerifyException( "HWInstanceOp " + str(self.sym_name) + " input type must be consistent with the machine " + str(m.sym_name) ) - if not ( - [operand.type for operand in self.outputs] - == [result for result in m.function_type.outputs] - and len(self.outputs) == len(m.function_type.outputs) + if self.outputs.types != tuple( + result for result in m.function_type.outputs ): raise VerifyException( "HWInstanceOp " diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index b42daba236..a21f02b249 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -326,7 +326,7 @@ def verify_(self) -> None: assert isinstance(func_op, FuncOp) function_return_types = func_op.function_type.outputs.data - return_types = tuple(arg.type for arg in self.arguments) + return_types = self.arguments.types if function_return_types != return_types: raise VerifyException( "Expected arguments to have the same types as the function output types" diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index 7ac61ffc32..c4622d1ce9 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -757,7 +757,7 @@ def __init__(self, operands: Sequence[SSAValue | Operation]): def verify_(self) -> None: op = self.parent_op() if op is not None: - yield_type = tuple(o.type for o in self.values) + yield_type = self.values.types result_type = op.result_types if yield_type != result_type: raise VerifyException( diff --git a/xdsl/dialects/hw.py b/xdsl/dialects/hw.py index 8c433fdfcc..341e38134e 100644 --- a/xdsl/dialects/hw.py +++ b/xdsl/dialects/hw.py @@ -1301,7 +1301,7 @@ def print(self, printer: Printer): printer.print(" ") printer.print_list(self.inputs, printer.print_operand) printer.print(" : ") - printer.print_list((x.type for x in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) HW = Dialect( diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 1391a92c08..bc0eedbd03 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -221,14 +221,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") extra_attrs = self.attributes.copy() @@ -503,14 +503,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") if extra_attrs and not self.PRINT_ATTRS_IN_FRONT: diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 2c121523f3..8721f40c7f 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -297,7 +297,7 @@ class AllocaScopeReturnOp(IRDLOperation): def verify_(self) -> None: parent = cast(AllocaScopeOp, self.parent_op()) - if any(op.type != res.type for op, res in zip(self.ops, parent.results)): + if self.ops.types != parent.result_types: raise VerifyException( "Expected operand types to match parent's return types." ) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 92658b8b1d..f799bf93a0 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -265,14 +265,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") if self.attributes: @@ -532,14 +532,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") if self.inits: diff --git a/xdsl/dialects/snitch_stream.py b/xdsl/dialects/snitch_stream.py index 668f3dcfa7..d89db93078 100644 --- a/xdsl/dialects/snitch_stream.py +++ b/xdsl/dialects/snitch_stream.py @@ -292,14 +292,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") if self.attributes: diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index 5aa54d2bc5..43c5f9a023 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -485,7 +485,7 @@ def print_destination_operand(dest: SSAValue): printer.print("(") printer.print_list( - zip(self.region.block.args, self.args, (a.type for a in self.args)), + zip(self.region.block.args, self.args, self.args.types), print_assign_argument, ) if self.dest: @@ -493,7 +493,7 @@ def print_destination_operand(dest: SSAValue): printer.print_list(self.dest, print_destination_operand) else: printer.print(") -> (") - printer.print_list((r.type for r in self.res), printer.print_attribute) + printer.print_list(self.res.types, printer.print_attribute) printer.print(") ") printer.print_op_attributes(self.attributes, print_keyword=True) printer.print_region(self.region, print_entry_block_args=False) @@ -1433,9 +1433,7 @@ def get(res: Sequence[SSAValue | Operation]): def verify_(self) -> None: unroll_factor = self.unroll_factor - types = [ - o.type.elem if isinstance(o.type, ResultType) else o.type for o in self.arg - ] + types = [ot.elem if isinstance(ot, ResultType) else ot for ot in self.arg.types] apply = cast(ApplyOp, self.parent_op()) if len(apply.res) > 0: res_types = [ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 6a0255fea4..c7f70d6db7 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -552,9 +552,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - printer.print_list( - (o.type for o in getattr(op, self.name)), printer.print_attribute - ) + printer.print_list(getattr(op, self.name).types, printer.print_attribute) state.last_was_punctuation = False state.should_emit_space = True @@ -694,9 +692,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - printer.print_list( - (r.type for r in getattr(op, self.name)), printer.print_attribute - ) + printer.print_list(getattr(op, self.name).types, printer.print_attribute) state.last_was_punctuation = False state.should_emit_space = True diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index 5526280494..a98cd59f99 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -305,7 +305,10 @@ def __init__( self.constr = range_constr_coercion(attr) -VarOperand: TypeAlias = tuple[SSAValue, ...] +class VarOperand(tuple[Operand, ...]): + @property + def types(self): + return tuple(o.type for o in self) @dataclass(init=False) @@ -341,7 +344,10 @@ def __init__( self.constr = range_constr_coercion(attr) -VarOpResult: TypeAlias = tuple[OpResult, ...] +class VarOpResult(tuple[OpResult, ...]): + @property + def types(self): + return tuple(r.type for r in self) @dataclass(init=False) @@ -1339,7 +1345,12 @@ def get_operand_result_or_region( construct: VarIRConstruct, ) -> ( None - | SSAValue + | Operand + | VarOperand + | OptOperand + | OpResult + | VarOpResult + | OptOpResult | Sequence[SSAValue] | Sequence[OpResult] | Region @@ -1374,7 +1385,13 @@ def get_operand_result_or_region( return args[begin_arg] if isinstance(defs[arg_def_idx][1], VariadicDef): arg_size = variadic_sizes[previous_var_args] - return args[begin_arg : begin_arg + arg_size] + values = args[begin_arg : begin_arg + arg_size] + if isinstance(defs[arg_def_idx][1], OperandDef): + return VarOperand(cast(Sequence[Operand], values)) + elif isinstance(defs[arg_def_idx][1], ResultDef): + return VarOpResult(cast(Sequence[OpResult], values)) + else: + return values else: return args[begin_arg] diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py index cdc6e0f990..2d0b2c8960 100644 --- a/xdsl/transforms/csl_stencil_bufferize.py +++ b/xdsl/transforms/csl_stencil_bufferize.py @@ -93,7 +93,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, # create new op buf_apply_op = csl_stencil.ApplyOp( operands=[op.communicated_stencil, buf_iter_arg.memref, op.args, op.dest], - result_types=[t.type for t in op.res] or [[]], + result_types=op.res.types or [[]], regions=[ self._get_empty_bufferized_region(op.chunk_reduce.block.args), self._get_empty_bufferized_region(op.post_process.block.args), @@ -267,7 +267,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op( func.FuncOp.build( operands=op.operands, - result_types=[r.type for r in op.results], + result_types=op.result_types, regions=[op.detach_region(op.body)], properties={**op.properties, "function_type": function_type}, attributes=op.attributes.copy(), diff --git a/xdsl/transforms/lower_mpi.py b/xdsl/transforms/lower_mpi.py index d103846ad1..362fde7a11 100644 --- a/xdsl/transforms/lower_mpi.py +++ b/xdsl/transforms/lower_mpi.py @@ -1,4 +1,5 @@ from abc import ABC +from collections.abc import Sequence from dataclasses import dataclass from math import prod from typing import TypeVar, cast @@ -753,7 +754,9 @@ class MpiAddExternalFuncDefs(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, /): # collect all func calls to MPI functions - funcs_to_emit: dict[str, tuple[list[Attribute], list[Attribute]]] = dict() + funcs_to_emit: dict[str, tuple[Sequence[Attribute], Sequence[Attribute]]] = ( + dict() + ) for op in module.walk(): if not isinstance(op, func.Call): @@ -761,8 +764,8 @@ def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, if op.callee.string_value() not in self.mpi_func_call_names: continue funcs_to_emit[op.callee.string_value()] = ( - [arg.type for arg in op.arguments], - [res.type for res in op.results], + op.arguments.types, + op.result_types, ) # for each func found, add a FuncOp to the top of the module. diff --git a/xdsl/transforms/reconcile_unrealized_casts.py b/xdsl/transforms/reconcile_unrealized_casts.py index 050c2523fb..b77afd8ec1 100644 --- a/xdsl/transforms/reconcile_unrealized_casts.py +++ b/xdsl/transforms/reconcile_unrealized_casts.py @@ -69,9 +69,7 @@ def gen_all_uses_cast( # because types are homogeneous (e.g. {A -> B, B -> A}) # otherwise it means the cast is not unifiable with its uses assert len(cast.results) == len(op.inputs) - has_trivial_cycle = all( - r.type == i.type for r, i in zip(cast.results, op.inputs) - ) + has_trivial_cycle = cast.result_types == op.inputs.types if is_live and not has_trivial_cycle: if warn_on_failure: warn( diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index c5a663bc77..72bd54db43 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -272,16 +272,10 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): properties=apply.properties.copy(), attributes=apply.attributes.copy(), regions=[ - Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])), + apply.detach_region(0), ], ) - rewriter.inline_block( - apply.region.block, - InsertPoint.at_start(new_apply.region.block), - new_apply.region.block.args, - ) - new_load = LoadOp.create( operands=[op.field], result_types=load.result_types, @@ -304,7 +298,7 @@ class UpdateApplyArgs(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): - new_arg_types = tuple(o.type for o in op.args) + new_arg_types = op.args.types if new_arg_types == op.region.block.arg_types: return