diff --git a/tests/dialects/test_mpi_lowering.py b/tests/dialects/test_mpi_lowering.py index c87ab8a0d1..c08398e7d4 100644 --- a/tests/dialects/test_mpi_lowering.py +++ b/tests/dialects/test_mpi_lowering.py @@ -101,7 +101,7 @@ def test_lower_mpi_wait_with_status(): assert call.callee.string_value() == "MPI_Wait" assert len(call.arguments) == 2 assert isinstance(call.arguments[1], OpResult) - assert isinstance(call.arguments[1].op, llvm.AllocaOp) + assert isinstance(call.arguments[1].owner, llvm.AllocaOp) def test_lower_mpi_comm_rank(): diff --git a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py index c479cb4d99..54a39bfdc3 100644 --- a/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py +++ b/xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py @@ -73,7 +73,7 @@ def match_and_rewrite(self, op: func.Call, rewriter: PatternRewriter) -> None: raise ValueError("Cannot lower func.call with more than 2 results") cast_operand_ops, register_operands = cast_to_regs(op.arguments) - operand_types = tuple(arg.type for arg in op.arguments) + operand_types = op.arguments.types move_operand_ops, moved_operands = move_to_a_regs( register_operands, operand_types ) @@ -109,9 +109,7 @@ def match_and_rewrite(self, op: func.Return, rewriter: PatternRewriter): raise ValueError("Cannot lower func.return with more than 2 arguments") cast_ops, register_values = cast_to_regs(op.arguments) - move_ops, moved_values = move_to_a_regs( - register_values, tuple(arg.type for arg in op.arguments) - ) + move_ops, moved_values = move_to_a_regs(register_values, op.arguments.types) rewriter.insert_op_before_matched_op(cast_ops) rewriter.insert_op_before_matched_op(move_ops) diff --git a/xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py b/xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py index 4ce1f1aec9..2de053b4ad 100644 --- a/xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py +++ b/xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py @@ -22,14 +22,12 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter) -> None: lb, ub, step, *args = cast_operands_to_regs(rewriter) new_region = rewriter.move_region_contents_to_new_regions(op.body) cast_block_args_to_regs(new_region.block, rewriter) - mv_ops, values = move_to_unallocated_regs( - args, tuple(arg.type for arg in op.iter_args) - ) + mv_ops, values = move_to_unallocated_regs(args, op.iter_args.types) rewriter.insert_op_before_matched_op(mv_ops) cast_matched_op_results(rewriter) new_op = riscv_scf.ForOp(lb, ub, step, values, new_region) mv_res_ops, res_values = move_to_unallocated_regs( - new_op.results, tuple(arg.type for arg in op.iter_args) + new_op.results, op.iter_args.types ) rewriter.replace_matched_op((new_op, *mv_res_ops), res_values) diff --git a/xdsl/backend/riscv/register_allocation.py b/xdsl/backend/riscv/register_allocation.py index 75f34a7759..bb0151f297 100644 --- a/xdsl/backend/riscv/register_allocation.py +++ b/xdsl/backend/riscv/register_allocation.py @@ -270,7 +270,7 @@ def allocate_for_loop(self, loop: riscv_scf.ForOp) -> None: self.allocate(loop.step) # Reserve the loop carried variables for allocation within the body - regs = tuple(arg.type for arg in loop.iter_args) + regs = loop.iter_args.types assert all(isinstance(reg, IntRegisterType | FloatRegisterType) for reg in regs) regs = cast(tuple[IntRegisterType | FloatRegisterType], regs) with self.available_registers.reserve_registers(regs): @@ -310,7 +310,7 @@ def allocate_frep_loop(self, loop: riscv_snitch.FRepOperation) -> None: self.allocate(loop.max_rep) # Reserve the loop carried variables for allocation within the body - regs = tuple(arg.type for arg in loop.iter_args) + regs = loop.iter_args.types assert all(isinstance(reg, IntRegisterType | FloatRegisterType) for reg in regs) regs = cast(tuple[IntRegisterType | FloatRegisterType], regs) with self.available_registers.reserve_registers(regs): diff --git a/xdsl/dialects/affine.py b/xdsl/dialects/affine.py index 23dc742d69..8e75e798bc 100644 --- a/xdsl/dialects/affine.py +++ b/xdsl/dialects/affine.py @@ -142,7 +142,7 @@ def verify_(self) -> None: raise VerifyException( "Expected as many upper bound operands as upper bound dimensions and symbols." ) - iter_types = tuple(op.type for op in self.inits) + iter_types = self.inits.types if iter_types != self.result_types: raise VerifyException( "Expected all operands and result pairs to have matching types" diff --git a/xdsl/dialects/comb.py b/xdsl/dialects/comb.py index e4c6c811e3..3d8c58e3df 100644 --- a/xdsl/dialects/comb.py +++ b/xdsl/dialects/comb.py @@ -490,7 +490,7 @@ def from_int_values(inputs: Sequence[SSAValue]) -> "ConcatOp | None": return ConcatOp(inputs, IntegerType(sum_of_width)) def verify_(self) -> None: - sum_of_width = _get_sum_of_int_width([inp.type for inp in self.inputs]) + sum_of_width = _get_sum_of_int_width(self.inputs.types) assert sum_of_width is not None assert isinstance(self.result.type, IntegerType) if sum_of_width != self.result.type.width.data: @@ -519,7 +519,7 @@ def print(self, printer: Printer): printer.print(" ") printer.print_list(self.inputs, printer.print_ssa_value) printer.print(" : ") - printer.print_list([inp.type for inp in self.inputs], printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) @irdl_op_definition diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index 0f0d1224b3..b2952ed72d 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -251,7 +251,7 @@ def print_arg(arg: SSAValue): printer.print_list(self.dest, print_arg) else: printer.print(") -> (") - printer.print_list((r.type for r in self.res), printer.print_attribute) + printer.print_list(self.res.types, printer.print_attribute) printer.print(") ") printer.print("<") diff --git a/xdsl/dialects/csl/csl_wrapper.py b/xdsl/dialects/csl/csl_wrapper.py index 57ce331b7e..544ee254e1 100644 --- a/xdsl/dialects/csl/csl_wrapper.py +++ b/xdsl/dialects/csl/csl_wrapper.py @@ -256,7 +256,7 @@ def verify_(self): # verify that block args are of the right type for the provided params for arg, param in zip( - [a.type for a in self.layout_module.block.args[4:]], + self.layout_module.block.arg_types[4:], self.params, strict=True, ): @@ -275,7 +275,7 @@ def verify_(self): # verify that params and yielded arguments are typed correctly # these may be followed by input-output symbols which we cannot verify, therefore setting `strict=False` for got, (name, exp) in zip( - [a.type for a in self.program_module.block.args[2:]], + self.program_module.block.arg_types[2:], itertools.chain( ( (param.key.data, cast(Attribute, param.type)) diff --git a/xdsl/dialects/fsm.py b/xdsl/dialects/fsm.py index d4736fa969..edb83a0d60 100644 --- a/xdsl/dialects/fsm.py +++ b/xdsl/dialects/fsm.py @@ -438,21 +438,13 @@ def verify_(self): if m is None: raise VerifyException("Machine definition does not exist.") - if not ( - [operand.type for operand in self.inputs] - == [result for result in m.function_type.inputs] - and len(self.inputs) == len(m.function_type.inputs) - ): + if self.inputs.types != tuple(result for result in m.function_type.inputs): raise VerifyException( "TriggerOp input types must be consistent with the machine " + str(m.sym_name) ) - if not ( - [operand.type for operand in self.outputs] - == [result for result in m.function_type.outputs] - and len(self.outputs) == len(m.function_type.outputs) - ): + if self.outputs.types != tuple(result for result in m.function_type.outputs): raise VerifyException( "TriggerOp output types must be consistent with the machine " + str(m.sym_name) @@ -506,21 +498,15 @@ def __init__( def verify_(self): m = SymbolTable.lookup_symbol(self, self.machine) if isinstance(m, MachineOp): - if not ( - [operand.type for operand in self.inputs] - == [result for result in m.function_type.inputs] - and len(self.inputs) == len(m.function_type.inputs) - ): + if self.inputs.types != tuple(result for result in m.function_type.inputs): raise VerifyException( "HWInstanceOp " + str(self.sym_name) + " input type must be consistent with the machine " + str(m.sym_name) ) - if not ( - [operand.type for operand in self.outputs] - == [result for result in m.function_type.outputs] - and len(self.outputs) == len(m.function_type.outputs) + if self.outputs.types != tuple( + result for result in m.function_type.outputs ): raise VerifyException( "HWInstanceOp " diff --git a/xdsl/dialects/func.py b/xdsl/dialects/func.py index b42daba236..a21f02b249 100644 --- a/xdsl/dialects/func.py +++ b/xdsl/dialects/func.py @@ -326,7 +326,7 @@ def verify_(self) -> None: assert isinstance(func_op, FuncOp) function_return_types = func_op.function_type.outputs.data - return_types = tuple(arg.type for arg in self.arguments) + return_types = self.arguments.types if function_return_types != return_types: raise VerifyException( "Expected arguments to have the same types as the function output types" diff --git a/xdsl/dialects/gpu.py b/xdsl/dialects/gpu.py index 7ac61ffc32..c4622d1ce9 100644 --- a/xdsl/dialects/gpu.py +++ b/xdsl/dialects/gpu.py @@ -757,7 +757,7 @@ def __init__(self, operands: Sequence[SSAValue | Operation]): def verify_(self) -> None: op = self.parent_op() if op is not None: - yield_type = tuple(o.type for o in self.values) + yield_type = self.values.types result_type = op.result_types if yield_type != result_type: raise VerifyException( diff --git a/xdsl/dialects/hw.py b/xdsl/dialects/hw.py index 8c433fdfcc..341e38134e 100644 --- a/xdsl/dialects/hw.py +++ b/xdsl/dialects/hw.py @@ -1301,7 +1301,7 @@ def print(self, printer: Printer): printer.print(" ") printer.print_list(self.inputs, printer.print_operand) printer.print(" : ") - printer.print_list((x.type for x in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) HW = Dialect( diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index 1391a92c08..bc0eedbd03 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -221,14 +221,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") extra_attrs = self.attributes.copy() @@ -503,14 +503,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") if extra_attrs and not self.PRINT_ATTRS_IN_FRONT: diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 2c121523f3..8721f40c7f 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -297,7 +297,7 @@ class AllocaScopeReturnOp(IRDLOperation): def verify_(self) -> None: parent = cast(AllocaScopeOp, self.parent_op()) - if any(op.type != res.type for op, res in zip(self.ops, parent.results)): + if self.ops.types != parent.result_types: raise VerifyException( "Expected operand types to match parent's return types." ) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 92658b8b1d..f799bf93a0 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -265,14 +265,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") if self.attributes: @@ -532,14 +532,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") if self.inits: diff --git a/xdsl/dialects/snitch_stream.py b/xdsl/dialects/snitch_stream.py index 668f3dcfa7..d89db93078 100644 --- a/xdsl/dialects/snitch_stream.py +++ b/xdsl/dialects/snitch_stream.py @@ -292,14 +292,14 @@ def print(self, printer: Printer): printer.print_string(" ins(") printer.print_list(self.inputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((i.type for i in self.inputs), printer.print_attribute) + printer.print_list(self.inputs.types, printer.print_attribute) printer.print_string(")") if self.outputs: printer.print_string(" outs(") printer.print_list(self.outputs, printer.print_ssa_value) printer.print_string(" : ") - printer.print_list((o.type for o in self.outputs), printer.print_attribute) + printer.print_list(self.outputs.types, printer.print_attribute) printer.print_string(")") if self.attributes: diff --git a/xdsl/dialects/stencil.py b/xdsl/dialects/stencil.py index 5aa54d2bc5..43c5f9a023 100644 --- a/xdsl/dialects/stencil.py +++ b/xdsl/dialects/stencil.py @@ -485,7 +485,7 @@ def print_destination_operand(dest: SSAValue): printer.print("(") printer.print_list( - zip(self.region.block.args, self.args, (a.type for a in self.args)), + zip(self.region.block.args, self.args, self.args.types), print_assign_argument, ) if self.dest: @@ -493,7 +493,7 @@ def print_destination_operand(dest: SSAValue): printer.print_list(self.dest, print_destination_operand) else: printer.print(") -> (") - printer.print_list((r.type for r in self.res), printer.print_attribute) + printer.print_list(self.res.types, printer.print_attribute) printer.print(") ") printer.print_op_attributes(self.attributes, print_keyword=True) printer.print_region(self.region, print_entry_block_args=False) @@ -1433,9 +1433,7 @@ def get(res: Sequence[SSAValue | Operation]): def verify_(self) -> None: unroll_factor = self.unroll_factor - types = [ - o.type.elem if isinstance(o.type, ResultType) else o.type for o in self.arg - ] + types = [ot.elem if isinstance(ot, ResultType) else ot for ot in self.arg.types] apply = cast(ApplyOp, self.parent_op()) if len(apply.res) > 0: res_types = [ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index 6a0255fea4..c7f70d6db7 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -552,9 +552,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - printer.print_list( - (o.type for o in getattr(op, self.name)), printer.print_attribute - ) + printer.print_list(getattr(op, self.name).types, printer.print_attribute) state.last_was_punctuation = False state.should_emit_space = True @@ -694,9 +692,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool: def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if state.should_emit_space or not state.last_was_punctuation: printer.print(" ") - printer.print_list( - (r.type for r in getattr(op, self.name)), printer.print_attribute - ) + printer.print_list(getattr(op, self.name).types, printer.print_attribute) state.last_was_punctuation = False state.should_emit_space = True diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index 5526280494..a98cd59f99 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -305,7 +305,10 @@ def __init__( self.constr = range_constr_coercion(attr) -VarOperand: TypeAlias = tuple[SSAValue, ...] +class VarOperand(tuple[Operand, ...]): + @property + def types(self): + return tuple(o.type for o in self) @dataclass(init=False) @@ -341,7 +344,10 @@ def __init__( self.constr = range_constr_coercion(attr) -VarOpResult: TypeAlias = tuple[OpResult, ...] +class VarOpResult(tuple[OpResult, ...]): + @property + def types(self): + return tuple(r.type for r in self) @dataclass(init=False) @@ -1339,7 +1345,12 @@ def get_operand_result_or_region( construct: VarIRConstruct, ) -> ( None - | SSAValue + | Operand + | VarOperand + | OptOperand + | OpResult + | VarOpResult + | OptOpResult | Sequence[SSAValue] | Sequence[OpResult] | Region @@ -1374,7 +1385,13 @@ def get_operand_result_or_region( return args[begin_arg] if isinstance(defs[arg_def_idx][1], VariadicDef): arg_size = variadic_sizes[previous_var_args] - return args[begin_arg : begin_arg + arg_size] + values = args[begin_arg : begin_arg + arg_size] + if isinstance(defs[arg_def_idx][1], OperandDef): + return VarOperand(cast(Sequence[Operand], values)) + elif isinstance(defs[arg_def_idx][1], ResultDef): + return VarOpResult(cast(Sequence[OpResult], values)) + else: + return values else: return args[begin_arg] diff --git a/xdsl/transforms/csl_stencil_bufferize.py b/xdsl/transforms/csl_stencil_bufferize.py index cdc6e0f990..2d0b2c8960 100644 --- a/xdsl/transforms/csl_stencil_bufferize.py +++ b/xdsl/transforms/csl_stencil_bufferize.py @@ -93,7 +93,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, # create new op buf_apply_op = csl_stencil.ApplyOp( operands=[op.communicated_stencil, buf_iter_arg.memref, op.args, op.dest], - result_types=[t.type for t in op.res] or [[]], + result_types=op.res.types or [[]], regions=[ self._get_empty_bufferized_region(op.chunk_reduce.block.args), self._get_empty_bufferized_region(op.post_process.block.args), @@ -267,7 +267,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op( func.FuncOp.build( operands=op.operands, - result_types=[r.type for r in op.results], + result_types=op.result_types, regions=[op.detach_region(op.body)], properties={**op.properties, "function_type": function_type}, attributes=op.attributes.copy(), diff --git a/xdsl/transforms/lower_mpi.py b/xdsl/transforms/lower_mpi.py index d103846ad1..362fde7a11 100644 --- a/xdsl/transforms/lower_mpi.py +++ b/xdsl/transforms/lower_mpi.py @@ -1,4 +1,5 @@ from abc import ABC +from collections.abc import Sequence from dataclasses import dataclass from math import prod from typing import TypeVar, cast @@ -753,7 +754,9 @@ class MpiAddExternalFuncDefs(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, /): # collect all func calls to MPI functions - funcs_to_emit: dict[str, tuple[list[Attribute], list[Attribute]]] = dict() + funcs_to_emit: dict[str, tuple[Sequence[Attribute], Sequence[Attribute]]] = ( + dict() + ) for op in module.walk(): if not isinstance(op, func.Call): @@ -761,8 +764,8 @@ def match_and_rewrite(self, module: builtin.ModuleOp, rewriter: PatternRewriter, if op.callee.string_value() not in self.mpi_func_call_names: continue funcs_to_emit[op.callee.string_value()] = ( - [arg.type for arg in op.arguments], - [res.type for res in op.results], + op.arguments.types, + op.result_types, ) # for each func found, add a FuncOp to the top of the module. diff --git a/xdsl/transforms/reconcile_unrealized_casts.py b/xdsl/transforms/reconcile_unrealized_casts.py index 050c2523fb..b77afd8ec1 100644 --- a/xdsl/transforms/reconcile_unrealized_casts.py +++ b/xdsl/transforms/reconcile_unrealized_casts.py @@ -69,9 +69,7 @@ def gen_all_uses_cast( # because types are homogeneous (e.g. {A -> B, B -> A}) # otherwise it means the cast is not unifiable with its uses assert len(cast.results) == len(op.inputs) - has_trivial_cycle = all( - r.type == i.type for r, i in zip(cast.results, op.inputs) - ) + has_trivial_cycle = cast.result_types == op.inputs.types if is_live and not has_trivial_cycle: if warn_on_failure: warn( diff --git a/xdsl/transforms/stencil_bufferize.py b/xdsl/transforms/stencil_bufferize.py index c5a663bc77..72bd54db43 100644 --- a/xdsl/transforms/stencil_bufferize.py +++ b/xdsl/transforms/stencil_bufferize.py @@ -272,16 +272,10 @@ def match_and_rewrite(self, op: StoreOp, rewriter: PatternRewriter): properties=apply.properties.copy(), attributes=apply.attributes.copy(), regions=[ - Region(Block(arg_types=[SSAValue.get(a).type for a in apply.args])), + apply.detach_region(0), ], ) - rewriter.inline_block( - apply.region.block, - InsertPoint.at_start(new_apply.region.block), - new_apply.region.block.args, - ) - new_load = LoadOp.create( operands=[op.field], result_types=load.result_types, @@ -304,7 +298,7 @@ class UpdateApplyArgs(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: ApplyOp, rewriter: PatternRewriter): - new_arg_types = tuple(o.type for o in op.args) + new_arg_types = op.args.types if new_arg_types == op.region.block.arg_types: return