From 2585bbf5b82c3e4db7359ac0cd745edcbbc006e9 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 13:11:48 +0100 Subject: [PATCH 1/4] Simple core helpers for type access. --- tests/dialects/test_hw.py | 4 ++-- tests/dialects/test_linalg.py | 2 +- tests/dialects/test_pdl.py | 2 +- tests/irdl/test_operation_builder.py | 20 +++++++++---------- .../pattern_rewriter/test_pattern_rewriter.py | 2 +- tests/test_ir.py | 4 ++-- .../lowering/convert_func_to_riscv_func.py | 2 +- .../lowering/convert_riscv_scf_to_riscv_cf.py | 8 +++----- .../convert_snitch_stream_to_snitch.py | 2 +- xdsl/dialects/affine.py | 8 ++++---- xdsl/dialects/csl/csl.py | 6 ++---- xdsl/dialects/fsm.py | 6 +----- xdsl/dialects/func.py | 6 +++--- xdsl/dialects/gpu.py | 18 ++++++++--------- xdsl/dialects/hw.py | 2 +- xdsl/dialects/pdl.py | 2 +- xdsl/ir/core.py | 14 ++++++++++++- xdsl/irdl/operations.py | 2 +- xdsl/printer.py | 4 +--- .../immutable_ir/immutable_ir.py | 2 +- .../common_subexpression_elimination.py | 8 +++----- xdsl/transforms/stencil_bufferize.py | 14 ++++++------- xdsl/transforms/stencil_to_csl_stencil.py | 4 ++-- xdsl/transforms/stencil_unroll.py | 4 ++-- 24 files changed, 71 insertions(+), 75 deletions(-) diff --git a/tests/dialects/test_hw.py b/tests/dialects/test_hw.py index a4b2e5ef62..7ad8e70db5 100644 --- a/tests/dialects/test_hw.py +++ b/tests/dialects/test_hw.py @@ -364,8 +364,8 @@ def test_instance_builder(): assert inst_op.arg_names.data == (StringAttr("foo"), StringAttr("bar")) assert inst_op.result_names.data == (StringAttr("baz"), StringAttr("qux")) - assert [op.type for op in inst_op.operands] == [i32, i64] - assert [res.type for res in inst_op.results] == [i32, i64] + assert inst_op.operands_types == (i32, i64) + assert inst_op.results_types == (i32, i64) def test_hwmoduleop_hwmodulelike(): diff --git a/tests/dialects/test_linalg.py b/tests/dialects/test_linalg.py index 2d632fa8ac..12f5081ff3 100644 --- a/tests/dialects/test_linalg.py +++ b/tests/dialects/test_linalg.py @@ -43,7 +43,7 @@ def test_matmul_on_memrefs(): matmul_op = linalg.MatmulOp(inputs=(a.memref, b.memref), outputs=(c.memref,)) - assert tuple(result.type for result in matmul_op.results) == () + assert matmul_op.results_types == () def test_loop_range_methods(): diff --git a/tests/dialects/test_pdl.py b/tests/dialects/test_pdl.py index f334f1ac46..937ee58f5b 100644 --- a/tests/dialects/test_pdl.py +++ b/tests/dialects/test_pdl.py @@ -39,7 +39,7 @@ def test_build_anr(): assert anr.constraint_name == StringAttr("anr") assert anr.args == (type_val,) assert len(anr.results) == 1 - assert [r.type for r in anr.results] == [attribute_type] + assert anr.results_types == (attribute_type,) def test_build_rewrite(): diff --git a/tests/irdl/test_operation_builder.py b/tests/irdl/test_operation_builder.py index 12501a09d1..5a5fb2745a 100644 --- a/tests/irdl/test_operation_builder.py +++ b/tests/irdl/test_operation_builder.py @@ -57,7 +57,7 @@ class ResultOp(IRDLOperation): def test_result_builder(): op = ResultOp.build(result_types=[StringAttr("")]) op.verify() - assert [res.type for res in op.results] == [StringAttr("")] + assert op.results_types == (StringAttr(""),) def test_result_builder_exception(): @@ -79,7 +79,7 @@ def test_opt_result_builder(): op1.verify() op2.verify() op3.verify() - assert [res.type for res in op1.results] == [StringAttr("")] + assert op1.results_types == (StringAttr(""),) assert len(op2.results) == 0 assert len(op3.results) == 0 @@ -99,10 +99,10 @@ class VarResultOp(IRDLOperation): def test_var_result_builder(): op = VarResultOp.build(result_types=[[StringAttr("0"), StringAttr("1")]]) op.verify() - assert [res.type for res in op.results] == [ + assert op.results_types == ( StringAttr("0"), StringAttr("1"), - ] + ) @irdl_op_definition @@ -122,12 +122,12 @@ def test_two_var_result_builder(): ] ) op.verify() - assert [res.type for res in op.results] == [ + assert op.results_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), StringAttr("3"), - ] + ) assert op.attributes[ AttrSizedResultSegments.attribute_name @@ -142,12 +142,12 @@ def test_two_var_result_builder2(): ] ) op.verify() - assert [res.type for res in op.results] == [ + assert op.results_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), StringAttr("3"), - ] + ) assert op.attributes[ AttrSizedResultSegments.attribute_name ] == DenseArrayBase.from_list(i32, [1, 3]) @@ -172,13 +172,13 @@ def test_var_mixed_builder(): ] ) op.verify() - assert [res.type for res in op.results] == [ + assert op.results_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), StringAttr("3"), StringAttr("4"), - ] + ) assert op.attributes[ AttrSizedResultSegments.attribute_name diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index a32ac3fa88..e9cc19855b 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -1075,7 +1075,7 @@ def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter): old_op = next(ops_iter) assert isinstance(old_op, test.TestOp) new_region = rewriter.move_region_contents_to_new_regions(old_op.regions[0]) - res_types = [r.type for r in old_op.results] + res_types = old_op.results_types new_op = test.TestOp.create(result_types=res_types, regions=[new_region]) rewriter.insert_op(new_op, InsertPoint.after(old_op)) diff --git a/tests/test_ir.py b/tests/test_ir.py index 725bc97700..47ef25d985 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -488,8 +488,8 @@ def test_split_block_args(): new_block = old_block.split_before(op, arg_types=(i32, i64)) - arg_types = [a.type for a in new_block.args] - assert arg_types == [i32, i64] + arg_types = new_block.args_types + assert arg_types == (i32, i64) def test_region_clone_into_circular_blocks(): diff --git a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py index a6f2226c00..f2a54d5e50 100644 --- a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py +++ b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py @@ -31,7 +31,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): if (first_block := op.body.blocks.first) is not None: cast_block_args_from_a_regs(first_block, rewriter) - input_types = [arg.type for arg in first_block.args] + input_types = first_block.args_types else: input_types = tuple(a_regs_for_types(op.function_type.inputs.data)) result_types = list(a_regs_for_types(op.function_type.outputs.data)) diff --git a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py index 68e1ab3819..1ea29cbe87 100644 --- a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py +++ b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py @@ -78,15 +78,13 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /): init_block = op.parent_block() assert init_block is not None - body_args = op.body.blocks[0].args + body = op.body.blocks[0] # TODO: add method to rewriter - end_block = init_block.split_before( - op, arg_types=(arg.type for arg in body_args) - ) + end_block = init_block.split_before(op, arg_types=body.args_types) # The first argument of the loop body block is the loop counter by SCF invariant. - loop_var_reg = body_args[0].type + loop_var_reg = body.args[0].type assert isinstance(loop_var_reg, riscv.IntRegisterType) # Use the first block of the loop body as the condition block since it is the diff --git a/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py b/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py index 8fcfc5e7a2..f46b28e1cb 100644 --- a/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py +++ b/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py @@ -178,7 +178,7 @@ def match_and_rewrite( block = op.body.block rewriter.insert_op_before_matched_op( - enable_op := snitch.SsrEnable(tuple(arg.type for arg in block.args)) + enable_op := snitch.SsrEnable(block.args_types) ) for val, arg in zip(enable_op.streams, block.args): diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index 52ad520c9b..ebbc74372b 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -142,14 +142,14 @@ def verify_(self) -> None: raise VerifyException( "Expected as many upper bound operands as upper bound dimensions and symbols." ) - iter_types = [op.type for op in self.inits] - if iter_types != [res.type for res in self.results]: + iter_types = tuple(op.type for op in self.inits) + if iter_types != self.results_types: raise VerifyException( "Expected all operands and result pairs to have matching types" ) entry_block: Block = self.body.blocks[0] - block_arg_types = [IndexType()] + iter_types - arg_types = [arg.type for arg in entry_block.args] + block_arg_types = (IndexType(), *iter_types) + arg_types = entry_block.args_types if block_arg_types != arg_types: raise VerifyException( "Expected BlockArguments to have the same types as the operands" diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index e0b1de741c..22e8851151 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -151,7 +151,7 @@ def _verify(self): return entry_block: Block = self.body.blocks[0] - block_arg_types = [arg.type for arg in entry_block.args] + block_arg_types = entry_block.args_types if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " @@ -718,9 +718,7 @@ def verify_(self) -> None: func_op = self.parent_op() assert isinstance(func_op, FuncOp) or isinstance(func_op, TaskOp) - if tuple(func_op.function_type.outputs) != tuple( - val.type for val in self.operands - ): + if tuple(func_op.function_type.outputs.data) != self.operands_types: raise VerifyException( "Expected arguments to have the same types as the function output types" ) diff --git a/xdsl/dialects/fsm.py b/xdsl/dialects/fsm.py index 049f1cd206..499857bbba 100644 --- a/xdsl/dialects/fsm.py +++ b/xdsl/dialects/fsm.py @@ -205,11 +205,7 @@ def verify_(self): raise VerifyException("Transition regions should not output any value") while (parent := parent.parent_op()) is not None: if isinstance(parent, MachineOp): - if not ( - [operand.type for operand in self.operands] - == [result for result in parent.function_type.outputs] - and len(self.operands) == len(parent.function_type.outputs) - ): + if not (self.operands_types == parent.function_type.outputs.data): raise VerifyException( "OutputOp output type must be consistent with the machine " + str(parent.sym_name) diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 1bcfc8bfea..4825a78019 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -114,7 +114,7 @@ def verify_(self) -> None: # TODO: how to verify that there is a terminator? entry_block = self.body.blocks.first assert entry_block is not None - block_arg_types = [arg.type for arg in entry_block.args] + block_arg_types = entry_block.args_types if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " @@ -222,10 +222,10 @@ def update_function_type(self): not self.is_declaration ), "update_function_type does not work with function declarations!" return_op = self.get_return_op() - return_type: tuple[Attribute, ...] = self.function_type.outputs.data + return_type = self.function_type.outputs.data if return_op is not None: - return_type = tuple(arg.type for arg in return_op.operands) + return_type = return_op.operands_types self.properties["function_type"] = FunctionType.from_lists( [arg.type for arg in self.args], diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index 67d1519540..f66b88d9a0 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -245,7 +245,7 @@ def verify_(self) -> None: f"{self.operand.type}. They must be the same type for gpu.all_reduce" ) - non_empty_body = any(b.ops for b in self.body.blocks) + non_empty_body = len(self.body.blocks) > 0 op_attr = self.op is not None if non_empty_body == op_attr: if op_attr: @@ -258,9 +258,8 @@ def verify_(self) -> None: "gpu.all_reduce need either a non empty body or an op attribute." ) if non_empty_body: - region_args = self.body.blocks[0].args - args_types = [r.type for r in region_args] - if args_types != [self.result.type, self.operand.type]: + args_types = self.body.blocks[0].args_types + if args_types != (self.result.type, self.operand.type): raise VerifyException( f"Expected {[str(t) for t in [self.result.type, self.operand.type]]}, " f"got {[str(t) for t in args_types]}. A gpu.all_reduce's body must " @@ -431,7 +430,7 @@ def __init__( def verify_(self): entry_block: Block = self.body.blocks[0] function_inputs = self.function_type.inputs.data - block_arg_types = tuple(a.type for a in entry_block.args) + block_arg_types = entry_block.args_types if function_inputs != block_arg_types: raise VerifyException( "Expected first entry block arguments to have the same types as the " @@ -558,9 +557,8 @@ def __init__( def verify_(self) -> None: if not any(b.ops for b in self.body.blocks): raise VerifyException("gpu.launch requires a non-empty body.") - body_args = self.body.blocks[0].args - args_type = [a.type for a in body_args] - if args_type != [IndexType()] * 12: + args_type = self.body.blocks[0].args_types + if args_type != (IndexType(),) * 12: raise VerifyException( f"Expected [12 x {str(IndexType())}], got {[str(t) for t in args_type]}. " "gpu.launch's body arguments are 12 index arguments, with 3 block " @@ -759,8 +757,8 @@ def __init__(self, operands: Sequence[SSAValue | Operation]): def verify_(self) -> None: op = self.parent_op() if op is not None: - yield_type = [o.type for o in self.values] - result_type = [r.type for r in op.results] + yield_type = tuple(o.type for o in self.values) + result_type = op.results_types if yield_type != result_type: raise VerifyException( f"Expected {[str(t) for t in result_type]}, got {[str(t) for t in yield_type]}. The gpu.yield values " diff --git a/xdsl/dialects/hw.py b/xdsl/dialects/hw.py index 0bdb0d243b..b0f1a533a5 100644 --- a/xdsl/dialects/hw.py +++ b/xdsl/dialects/hw.py @@ -1230,7 +1230,7 @@ def print_output_port(name: str, port_type: Attribute): printer.print_list( zip( (name.data for name in self.result_names), - (result.type for result in self.results), + self.results_types, ), lambda x: print_output_port(*x), ) diff --git a/xdsl/dialects/pdl.py b/xdsl/dialects/pdl.py index 2b341649d6..02c5f18ef1 100644 --- a/xdsl/dialects/pdl.py +++ b/xdsl/dialects/pdl.py @@ -260,7 +260,7 @@ def print(self, printer: Printer) -> None: printer.print(")") if len(self.results) != 0: printer.print(" : ") - printer.print_list([res.type for res in self.results], printer.print) + printer.print_list(self.results_types, printer.print) @irdl_op_definition diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index e0607efe90..a9aff23181 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -614,6 +614,14 @@ class Operation(IRNode): def parent_node(self) -> IRNode | None: return self.parent + @property + def results_types(self) -> Sequence[Attribute]: + return tuple(r.type for r in self.results) + + @property + def operands_types(self) -> Sequence[Attribute]: + return tuple(operand.type for operand in self.operands) + def parent_op(self) -> Operation | None: if p := self.parent_region(): return p.parent @@ -896,7 +904,7 @@ def clone_without_regions( (value_mapper[operand] if operand in value_mapper else operand) for operand in self.operands ] - result_types = [res.type for res in self.results] + result_types = self.results_types attributes = self.attributes.copy() properties = self.properties.copy() successors = [ @@ -1214,6 +1222,10 @@ def __init__( self.add_ops(ops) + @property + def args_types(self) -> Sequence[Attribute]: + return tuple(arg.type for arg in self._args) + @property def parent_node(self) -> IRNode | None: return self.parent diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index 54ad8de93b..f00c2fd817 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -1390,7 +1390,7 @@ def irdl_op_verify_regions( f"{len(region.blocks)} blocks" ) if (first_block := region.blocks.first) is not None: - entry_args_types = tuple(a.type for a in first_block.args) + entry_args_types = first_block.args_types try: region_def.entry_args.verify(entry_args_types, constraint_context) except Exception as e: diff --git a/xdsl/printer.py b/xdsl/printer.py index f68a05e14f..2bd636458c 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -938,9 +938,7 @@ def print_function_type( self.print_string(")") def print_operation_type(self, op: Operation) -> None: - self.print_function_type( - (o.type for o in op.operands), (r.type for r in op.results) - ) + self.print_function_type(op.operands_types, op.results_types) if self.print_debuginfo: self.print_string(" loc(unknown)") diff --git a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py index 4203983bdb..98dde304bf 100644 --- a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py +++ b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py @@ -570,7 +570,7 @@ def from_mutable( op.name, op_type, operands, - [result.type for result in op.results], + op.results_types, properties, attributes, successors, diff --git a/xdsl/transforms/common_subexpression_elimination.py b/xdsl/transforms/common_subexpression_elimination.py index 9da5fed014..c361ef59e8 100644 --- a/xdsl/transforms/common_subexpression_elimination.py +++ b/xdsl/transforms/common_subexpression_elimination.py @@ -42,7 +42,7 @@ def __hash__(self): self.name, sum(hash(i) for i in self.op.attributes.items()), sum(hash(i) for i in self.op.properties.items()), - hash(tuple(i.type for i in self.op.results)), + hash(self.op.results_types), hash(self.op.operands), ) ) @@ -55,12 +55,10 @@ def __eq__(self, other: object): and self.op.attributes == other.op.attributes and self.op.properties == other.op.properties and self.op.operands == other.op.operands - and len(self.op.results) == len(other.op.results) - and all(r.type == o.type for r, o in zip(self.op.results, other.op.results)) - and len(self.op.regions) == len(other.op.regions) + and self.op.results_types == other.op.results_types and all( s.is_structurally_equivalent(o) - for s, o in zip(self.op.regions, other.op.regions) + for s, o in zip(self.op.regions, other.op.regions, strict=True) ) ) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 34f85dfc77..16e62ea7fd 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -1,6 +1,5 @@ from collections.abc import Generator from dataclasses import dataclass -from itertools import chain from typing import Any, TypeVar, cast from xdsl.context import MLContext @@ -286,7 +285,7 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): new_load = LoadOp.create( operands=[op.field], - result_types=[r.type for r in load.results], + result_types=load.results_types, attributes=load.attributes.copy(), properties=load.properties.copy(), ) @@ -306,14 +305,14 @@ class UpdateApplyArgs(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): - new_arg_types = [o.type for o in op.args] - if new_arg_types == [a.type for a in op.region.block.args]: + new_arg_types = tuple(o.type for o in op.args) + if new_arg_types == op.region.block.args_types: return new_block = Block(arg_types=new_arg_types) new_apply = ApplyOp.create( operands=op.operands, - result_types=[r.type for r in op.results], + result_types=op.results_types, properties=op.properties.copy(), attributes=op.attributes.copy(), regions=[Region(new_block)], @@ -407,9 +406,8 @@ def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter): new_upper = op.upper new_lowerext = op.lowerext new_upperext = op.upperext - new_results_types = [ - r.type for r in chain(op.results[:i], op.results[i + 1 :]) - ] + new_results_types = list(op.results_types) + new_results_types.pop(i) bounds = cast(StencilBoundsAttr, cast(TempType[Attribute], r.type).bounds) newub = list(bounds.ub) diff --git a/xdsl/transforms/stencil_to_csl_stencil.py b/xdsl/transforms/stencil_to_csl_stencil.py index b34d72a9e2..c84284980b 100644 --- a/xdsl/transforms/stencil_to_csl_stencil.py +++ b/xdsl/transforms/stencil_to_csl_stencil.py @@ -250,7 +250,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): ) # rebuild stencil.apply op - r_types = [r.type for r in apply_op.results] + r_types = apply_op.results_types assert isa(r_types, Sequence[stencil.TempType[Attribute]]) new_apply_op = stencil.ApplyOp.build( operands=[[*apply_op.args, prefetch_op.result], apply_op.dest], @@ -521,7 +521,7 @@ def get_prefetch_overhead(o: OpResult): chunk_reduce, post_process, ], - result_types=[r.type for r in op.results] or [[]], + result_types=[op.results_types], ) ) diff --git a/xdsl/transforms/stencil_unroll.py b/xdsl/transforms/stencil_unroll.py index a672483e9b..5418cc328b 100644 --- a/xdsl/transforms/stencil_unroll.py +++ b/xdsl/transforms/stencil_unroll.py @@ -78,8 +78,8 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): return # Enforced by verification - res_types = [r.type for r in op.results] - assert isa(res_types, list[TempType[Attribute]]) + res_types = op.results_types + assert isa(res_types, Sequence[TempType[Attribute]]) dim = res_types[0].get_num_dims() # If unroll factors list is shorter than the dim, fill with ones from the front From fb61fbb482b6ceb6ef4c00052fc7b461fc9d6e9f Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 13:34:02 +0100 Subject: [PATCH 2/4] Update xdsl/dialects/fsm.py Co-authored-by: Sasha Lopoukhine --- xdsl/dialects/fsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/fsm.py b/xdsl/dialects/fsm.py index 499857bbba..a428ff3920 100644 --- a/xdsl/dialects/fsm.py +++ b/xdsl/dialects/fsm.py @@ -205,7 +205,7 @@ def verify_(self): raise VerifyException("Transition regions should not output any value") while (parent := parent.parent_op()) is not None: if isinstance(parent, MachineOp): - if not (self.operands_types == parent.function_type.outputs.data): + if self.operands_types != parent.function_type.outputs.data: raise VerifyException( "OutputOp output type must be consistent with the machine " + str(parent.sym_name) From 481c893e22e5a1c6a4f14bd51b53480161d217c7 Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 13:34:07 +0100 Subject: [PATCH 3/4] Update xdsl/dialects/gpu.py Co-authored-by: Sasha Lopoukhine --- xdsl/dialects/gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index f66b88d9a0..445fd849dc 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -245,7 +245,7 @@ def verify_(self) -> None: f"{self.operand.type}. They must be the same type for gpu.all_reduce" ) - non_empty_body = len(self.body.blocks) > 0 + non_empty_body = bool(self.body.blocks) op_attr = self.op is not None if non_empty_body == op_attr: if op_attr: From 01fb2285d777edde6dc51a09a78d2feb2c2a0dfa Mon Sep 17 00:00:00 2001 From: Emilien Bauer Date: Tue, 13 Aug 2024 13:35:37 +0100 Subject: [PATCH 4/4] Consistent naming. --- tests/dialects/test_hw.py | 4 ++-- tests/dialects/test_linalg.py | 2 +- tests/dialects/test_pdl.py | 2 +- tests/irdl/test_operation_builder.py | 12 ++++++------ tests/pattern_rewriter/test_pattern_rewriter.py | 2 +- tests/test_ir.py | 2 +- .../riscv/lowering/convert_func_to_riscv_func.py | 2 +- .../riscv/lowering/convert_riscv_scf_to_riscv_cf.py | 2 +- .../lowering/convert_snitch_stream_to_snitch.py | 2 +- xdsl/dialects/affine.py | 4 ++-- xdsl/dialects/csl/csl.py | 4 ++-- xdsl/dialects/fsm.py | 2 +- xdsl/dialects/func.py | 4 ++-- xdsl/dialects/gpu.py | 8 ++++---- xdsl/dialects/hw.py | 2 +- xdsl/dialects/pdl.py | 2 +- xdsl/ir/core.py | 8 ++++---- xdsl/irdl/operations.py | 2 +- xdsl/printer.py | 2 +- .../immutable_ir/immutable_ir.py | 2 +- xdsl/transforms/common_subexpression_elimination.py | 4 ++-- xdsl/transforms/stencil_bufferize.py | 8 ++++---- xdsl/transforms/stencil_to_csl_stencil.py | 4 ++-- xdsl/transforms/stencil_unroll.py | 2 +- 24 files changed, 44 insertions(+), 44 deletions(-) diff --git a/tests/dialects/test_hw.py b/tests/dialects/test_hw.py index 7ad8e70db5..37cfa56501 100644 --- a/tests/dialects/test_hw.py +++ b/tests/dialects/test_hw.py @@ -364,8 +364,8 @@ def test_instance_builder(): assert inst_op.arg_names.data == (StringAttr("foo"), StringAttr("bar")) assert inst_op.result_names.data == (StringAttr("baz"), StringAttr("qux")) - assert inst_op.operands_types == (i32, i64) - assert inst_op.results_types == (i32, i64) + assert inst_op.operand_types == (i32, i64) + assert inst_op.result_types == (i32, i64) def test_hwmoduleop_hwmodulelike(): diff --git a/tests/dialects/test_linalg.py b/tests/dialects/test_linalg.py index 12f5081ff3..adf486c2ce 100644 --- a/tests/dialects/test_linalg.py +++ b/tests/dialects/test_linalg.py @@ -43,7 +43,7 @@ def test_matmul_on_memrefs(): matmul_op = linalg.MatmulOp(inputs=(a.memref, b.memref), outputs=(c.memref,)) - assert matmul_op.results_types == () + assert matmul_op.result_types == () def test_loop_range_methods(): diff --git a/tests/dialects/test_pdl.py b/tests/dialects/test_pdl.py index 937ee58f5b..1cd1867356 100644 --- a/tests/dialects/test_pdl.py +++ b/tests/dialects/test_pdl.py @@ -39,7 +39,7 @@ def test_build_anr(): assert anr.constraint_name == StringAttr("anr") assert anr.args == (type_val,) assert len(anr.results) == 1 - assert anr.results_types == (attribute_type,) + assert anr.result_types == (attribute_type,) def test_build_rewrite(): diff --git a/tests/irdl/test_operation_builder.py b/tests/irdl/test_operation_builder.py index 5a5fb2745a..f267ee3071 100644 --- a/tests/irdl/test_operation_builder.py +++ b/tests/irdl/test_operation_builder.py @@ -57,7 +57,7 @@ class ResultOp(IRDLOperation): def test_result_builder(): op = ResultOp.build(result_types=[StringAttr("")]) op.verify() - assert op.results_types == (StringAttr(""),) + assert op.result_types == (StringAttr(""),) def test_result_builder_exception(): @@ -79,7 +79,7 @@ def test_opt_result_builder(): op1.verify() op2.verify() op3.verify() - assert op1.results_types == (StringAttr(""),) + assert op1.result_types == (StringAttr(""),) assert len(op2.results) == 0 assert len(op3.results) == 0 @@ -99,7 +99,7 @@ class VarResultOp(IRDLOperation): def test_var_result_builder(): op = VarResultOp.build(result_types=[[StringAttr("0"), StringAttr("1")]]) op.verify() - assert op.results_types == ( + assert op.result_types == ( StringAttr("0"), StringAttr("1"), ) @@ -122,7 +122,7 @@ def test_two_var_result_builder(): ] ) op.verify() - assert op.results_types == ( + assert op.result_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), @@ -142,7 +142,7 @@ def test_two_var_result_builder2(): ] ) op.verify() - assert op.results_types == ( + assert op.result_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), @@ -172,7 +172,7 @@ def test_var_mixed_builder(): ] ) op.verify() - assert op.results_types == ( + assert op.result_types == ( StringAttr("0"), StringAttr("1"), StringAttr("2"), diff --git a/tests/pattern_rewriter/test_pattern_rewriter.py b/tests/pattern_rewriter/test_pattern_rewriter.py index e9cc19855b..0f673b9066 100644 --- a/tests/pattern_rewriter/test_pattern_rewriter.py +++ b/tests/pattern_rewriter/test_pattern_rewriter.py @@ -1075,7 +1075,7 @@ def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter): old_op = next(ops_iter) assert isinstance(old_op, test.TestOp) new_region = rewriter.move_region_contents_to_new_regions(old_op.regions[0]) - res_types = old_op.results_types + res_types = old_op.result_types new_op = test.TestOp.create(result_types=res_types, regions=[new_region]) rewriter.insert_op(new_op, InsertPoint.after(old_op)) diff --git a/tests/test_ir.py b/tests/test_ir.py index 47ef25d985..8ad55cde35 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -488,7 +488,7 @@ def test_split_block_args(): new_block = old_block.split_before(op, arg_types=(i32, i64)) - arg_types = new_block.args_types + arg_types = new_block.arg_types assert arg_types == (i32, i64) diff --git a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py index f2a54d5e50..c479cb4d99 100644 --- a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py +++ b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py @@ -31,7 +31,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter): if (first_block := op.body.blocks.first) is not None: cast_block_args_from_a_regs(first_block, rewriter) - input_types = first_block.args_types + input_types = first_block.arg_types else: input_types = tuple(a_regs_for_types(op.function_type.inputs.data)) result_types = list(a_regs_for_types(op.function_type.outputs.data)) diff --git a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py index 1ea29cbe87..24abc9f874 100644 --- a/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py +++ b/xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py @@ -81,7 +81,7 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /): body = op.body.blocks[0] # TODO: add method to rewriter - end_block = init_block.split_before(op, arg_types=body.args_types) + end_block = init_block.split_before(op, arg_types=body.arg_types) # The first argument of the loop body block is the loop counter by SCF invariant. loop_var_reg = body.args[0].type diff --git a/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py b/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py index f46b28e1cb..b049934222 100644 --- a/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py +++ b/xdsl/backend/riscv/lowering/convert_snitch_stream_to_snitch.py @@ -178,7 +178,7 @@ def match_and_rewrite( block = op.body.block rewriter.insert_op_before_matched_op( - enable_op := snitch.SsrEnable(block.args_types) + enable_op := snitch.SsrEnable(block.arg_types) ) for val, arg in zip(enable_op.streams, block.args): diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index ebbc74372b..23dc742d69 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -143,13 +143,13 @@ def verify_(self) -> None: "Expected as many upper bound operands as upper bound dimensions and symbols." ) iter_types = tuple(op.type for op in self.inits) - if iter_types != self.results_types: + if iter_types != self.result_types: raise VerifyException( "Expected all operands and result pairs to have matching types" ) entry_block: Block = self.body.blocks[0] block_arg_types = (IndexType(), *iter_types) - arg_types = entry_block.args_types + arg_types = entry_block.arg_types if block_arg_types != arg_types: raise VerifyException( "Expected BlockArguments to have the same types as the operands" diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index 22e8851151..1cce7ee47d 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -151,7 +151,7 @@ def _verify(self): return entry_block: Block = self.body.blocks[0] - block_arg_types = entry_block.args_types + block_arg_types = entry_block.arg_types if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " @@ -718,7 +718,7 @@ def verify_(self) -> None: func_op = self.parent_op() assert isinstance(func_op, FuncOp) or isinstance(func_op, TaskOp) - if tuple(func_op.function_type.outputs.data) != self.operands_types: + if tuple(func_op.function_type.outputs.data) != self.operand_types: raise VerifyException( "Expected arguments to have the same types as the function output types" ) diff --git a/xdsl/dialects/fsm.py b/xdsl/dialects/fsm.py index a428ff3920..d4736fa969 100644 --- a/xdsl/dialects/fsm.py +++ b/xdsl/dialects/fsm.py @@ -205,7 +205,7 @@ def verify_(self): raise VerifyException("Transition regions should not output any value") while (parent := parent.parent_op()) is not None: if isinstance(parent, MachineOp): - if self.operands_types != parent.function_type.outputs.data: + if self.operand_types != parent.function_type.outputs.data: raise VerifyException( "OutputOp output type must be consistent with the machine " + str(parent.sym_name) diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index 4825a78019..b42daba236 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -114,7 +114,7 @@ def verify_(self) -> None: # TODO: how to verify that there is a terminator? entry_block = self.body.blocks.first assert entry_block is not None - block_arg_types = entry_block.args_types + block_arg_types = entry_block.arg_types if self.function_type.inputs.data != tuple(block_arg_types): raise VerifyException( "Expected entry block arguments to have the same types as the function " @@ -225,7 +225,7 @@ def update_function_type(self): return_type = self.function_type.outputs.data if return_op is not None: - return_type = return_op.operands_types + return_type = return_op.operand_types self.properties["function_type"] = FunctionType.from_lists( [arg.type for arg in self.args], diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index 445fd849dc..7ac61ffc32 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -258,7 +258,7 @@ def verify_(self) -> None: "gpu.all_reduce need either a non empty body or an op attribute." ) if non_empty_body: - args_types = self.body.blocks[0].args_types + args_types = self.body.blocks[0].arg_types if args_types != (self.result.type, self.operand.type): raise VerifyException( f"Expected {[str(t) for t in [self.result.type, self.operand.type]]}, " @@ -430,7 +430,7 @@ def __init__( def verify_(self): entry_block: Block = self.body.blocks[0] function_inputs = self.function_type.inputs.data - block_arg_types = entry_block.args_types + block_arg_types = entry_block.arg_types if function_inputs != block_arg_types: raise VerifyException( "Expected first entry block arguments to have the same types as the " @@ -557,7 +557,7 @@ def __init__( def verify_(self) -> None: if not any(b.ops for b in self.body.blocks): raise VerifyException("gpu.launch requires a non-empty body.") - args_type = self.body.blocks[0].args_types + args_type = self.body.blocks[0].arg_types if args_type != (IndexType(),) * 12: raise VerifyException( f"Expected [12 x {str(IndexType())}], got {[str(t) for t in args_type]}. " @@ -758,7 +758,7 @@ def verify_(self) -> None: op = self.parent_op() if op is not None: yield_type = tuple(o.type for o in self.values) - result_type = op.results_types + result_type = op.result_types if yield_type != result_type: raise VerifyException( f"Expected {[str(t) for t in result_type]}, got {[str(t) for t in yield_type]}. The gpu.yield values " diff --git a/xdsl/dialects/hw.py b/xdsl/dialects/hw.py index b0f1a533a5..8c433fdfcc 100644 --- a/xdsl/dialects/hw.py +++ b/xdsl/dialects/hw.py @@ -1230,7 +1230,7 @@ def print_output_port(name: str, port_type: Attribute): printer.print_list( zip( (name.data for name in self.result_names), - self.results_types, + self.result_types, ), lambda x: print_output_port(*x), ) diff --git a/xdsl/dialects/pdl.py b/xdsl/dialects/pdl.py index 02c5f18ef1..64b3d93f09 100644 --- a/xdsl/dialects/pdl.py +++ b/xdsl/dialects/pdl.py @@ -260,7 +260,7 @@ def print(self, printer: Printer) -> None: printer.print(")") if len(self.results) != 0: printer.print(" : ") - printer.print_list(self.results_types, printer.print) + printer.print_list(self.result_types, printer.print) @irdl_op_definition diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index a9aff23181..08ec48a47f 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -615,11 +615,11 @@ def parent_node(self) -> IRNode | None: return self.parent @property - def results_types(self) -> Sequence[Attribute]: + def result_types(self) -> Sequence[Attribute]: return tuple(r.type for r in self.results) @property - def operands_types(self) -> Sequence[Attribute]: + def operand_types(self) -> Sequence[Attribute]: return tuple(operand.type for operand in self.operands) def parent_op(self) -> Operation | None: @@ -904,7 +904,7 @@ def clone_without_regions( (value_mapper[operand] if operand in value_mapper else operand) for operand in self.operands ] - result_types = self.results_types + result_types = self.result_types attributes = self.attributes.copy() properties = self.properties.copy() successors = [ @@ -1223,7 +1223,7 @@ def __init__( self.add_ops(ops) @property - def args_types(self) -> Sequence[Attribute]: + def arg_types(self) -> Sequence[Attribute]: return tuple(arg.type for arg in self._args) @property diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index f00c2fd817..5526280494 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -1390,7 +1390,7 @@ def irdl_op_verify_regions( f"{len(region.blocks)} blocks" ) if (first_block := region.blocks.first) is not None: - entry_args_types = first_block.args_types + entry_args_types = first_block.arg_types try: region_def.entry_args.verify(entry_args_types, constraint_context) except Exception as e: diff --git a/xdsl/printer.py b/xdsl/printer.py index 2bd636458c..20aa8a5cee 100644 --- a/xdsl/printer.py +++ b/xdsl/printer.py @@ -938,7 +938,7 @@ def print_function_type( self.print_string(")") def print_operation_type(self, op: Operation) -> None: - self.print_function_type(op.operands_types, op.results_types) + self.print_function_type(op.operand_types, op.result_types) if self.print_debuginfo: self.print_string(" loc(unknown)") diff --git a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py index 98dde304bf..b6adb4a787 100644 --- a/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py +++ b/xdsl/rewriting/composable_rewriting/immutable_ir/immutable_ir.py @@ -570,7 +570,7 @@ def from_mutable( op.name, op_type, operands, - op.results_types, + op.result_types, properties, attributes, successors, diff --git a/xdsl/transforms/common_subexpression_elimination.py b/xdsl/transforms/common_subexpression_elimination.py index c361ef59e8..fc05983560 100644 --- a/xdsl/transforms/common_subexpression_elimination.py +++ b/xdsl/transforms/common_subexpression_elimination.py @@ -42,7 +42,7 @@ def __hash__(self): self.name, sum(hash(i) for i in self.op.attributes.items()), sum(hash(i) for i in self.op.properties.items()), - hash(self.op.results_types), + hash(self.op.result_types), hash(self.op.operands), ) ) @@ -55,7 +55,7 @@ def __eq__(self, other: object): and self.op.attributes == other.op.attributes and self.op.properties == other.op.properties and self.op.operands == other.op.operands - and self.op.results_types == other.op.results_types + and self.op.result_types == other.op.result_types and all( s.is_structurally_equivalent(o) for s, o in zip(self.op.regions, other.op.regions, strict=True) diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index 16e62ea7fd..dc2ebdc165 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -285,7 +285,7 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): new_load = LoadOp.create( operands=[op.field], - result_types=load.results_types, + result_types=load.result_types, attributes=load.attributes.copy(), properties=load.properties.copy(), ) @@ -306,13 +306,13 @@ class UpdateApplyArgs(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): new_arg_types = tuple(o.type for o in op.args) - if new_arg_types == op.region.block.args_types: + if new_arg_types == op.region.block.arg_types: return new_block = Block(arg_types=new_arg_types) new_apply = ApplyOp.create( operands=op.operands, - result_types=op.results_types, + result_types=op.result_types, properties=op.properties.copy(), attributes=op.attributes.copy(), regions=[Region(new_block)], @@ -406,7 +406,7 @@ def match_and_rewrite(self, op: CombineOp, rewriter: PatternRewriter): new_upper = op.upper new_lowerext = op.lowerext new_upperext = op.upperext - new_results_types = list(op.results_types) + new_results_types = list(op.result_types) new_results_types.pop(i) bounds = cast(StencilBoundsAttr, cast(TempType[Attribute], r.type).bounds) diff --git a/xdsl/transforms/stencil_to_csl_stencil.py b/xdsl/transforms/stencil_to_csl_stencil.py index c84284980b..2c123e7fda 100644 --- a/xdsl/transforms/stencil_to_csl_stencil.py +++ b/xdsl/transforms/stencil_to_csl_stencil.py @@ -250,7 +250,7 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /): ) # rebuild stencil.apply op - r_types = apply_op.results_types + r_types = apply_op.result_types assert isa(r_types, Sequence[stencil.TempType[Attribute]]) new_apply_op = stencil.ApplyOp.build( operands=[[*apply_op.args, prefetch_op.result], apply_op.dest], @@ -521,7 +521,7 @@ def get_prefetch_overhead(o: OpResult): chunk_reduce, post_process, ], - result_types=[op.results_types], + result_types=[op.result_types], ) ) diff --git a/xdsl/transforms/stencil_unroll.py b/xdsl/transforms/stencil_unroll.py index 5418cc328b..f65d91bb90 100644 --- a/xdsl/transforms/stencil_unroll.py +++ b/xdsl/transforms/stencil_unroll.py @@ -78,7 +78,7 @@ def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter, /): return # Enforced by verification - res_types = op.results_types + res_types = op.result_types assert isa(res_types, Sequence[TempType[Attribute]]) dim = res_types[0].get_num_dims()