Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: helpers for IRDL construct type access. #3025

Merged
merged 9 commits into from
Aug 13, 2024
2 changes: 1 addition & 1 deletion tests/dialects/test_mpi_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_lower_mpi_wait_with_status():
assert call.callee.string_value() == "MPI_Wait"
assert len(call.arguments) == 2
assert isinstance(call.arguments[1], OpResult)
assert isinstance(call.arguments[1].op, llvm.AllocaOp)
assert isinstance(call.arguments[1].owner, llvm.AllocaOp)


def test_lower_mpi_comm_rank():
Expand Down
6 changes: 2 additions & 4 deletions xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def match_and_rewrite(self, op: func.Call, rewriter: PatternRewriter) -> None:
raise ValueError("Cannot lower func.call with more than 2 results")

cast_operand_ops, register_operands = cast_to_regs(op.arguments)
operand_types = tuple(arg.type for arg in op.arguments)
operand_types = op.arguments.types
move_operand_ops, moved_operands = move_to_a_regs(
register_operands, operand_types
)
Expand Down Expand Up @@ -109,9 +109,7 @@ def match_and_rewrite(self, op: func.Return, rewriter: PatternRewriter):
raise ValueError("Cannot lower func.return with more than 2 arguments")

cast_ops, register_values = cast_to_regs(op.arguments)
move_ops, moved_values = move_to_a_regs(
register_values, tuple(arg.type for arg in op.arguments)
)
move_ops, moved_values = move_to_a_regs(register_values, op.arguments.types)

rewriter.insert_op_before_matched_op(cast_ops)
rewriter.insert_op_before_matched_op(move_ops)
Expand Down
6 changes: 2 additions & 4 deletions xdsl/backend/riscv/lowering/convert_scf_to_riscv_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter) -> None:
lb, ub, step, *args = cast_operands_to_regs(rewriter)
new_region = rewriter.move_region_contents_to_new_regions(op.body)
cast_block_args_to_regs(new_region.block, rewriter)
mv_ops, values = move_to_unallocated_regs(
args, tuple(arg.type for arg in op.iter_args)
)
mv_ops, values = move_to_unallocated_regs(args, op.iter_args.types)
rewriter.insert_op_before_matched_op(mv_ops)
cast_matched_op_results(rewriter)
new_op = riscv_scf.ForOp(lb, ub, step, values, new_region)
mv_res_ops, res_values = move_to_unallocated_regs(
new_op.results, tuple(arg.type for arg in op.iter_args)
new_op.results, op.iter_args.types
)

rewriter.replace_matched_op((new_op, *mv_res_ops), res_values)
Expand Down
4 changes: 2 additions & 2 deletions xdsl/backend/riscv/register_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def allocate_for_loop(self, loop: riscv_scf.ForOp) -> None:
self.allocate(loop.step)

# Reserve the loop carried variables for allocation within the body
regs = tuple(arg.type for arg in loop.iter_args)
regs = loop.iter_args.types
assert all(isinstance(reg, IntRegisterType | FloatRegisterType) for reg in regs)
regs = cast(tuple[IntRegisterType | FloatRegisterType], regs)
with self.available_registers.reserve_registers(regs):
Expand Down Expand Up @@ -310,7 +310,7 @@ def allocate_frep_loop(self, loop: riscv_snitch.FRepOperation) -> None:
self.allocate(loop.max_rep)

# Reserve the loop carried variables for allocation within the body
regs = tuple(arg.type for arg in loop.iter_args)
regs = loop.iter_args.types
assert all(isinstance(reg, IntRegisterType | FloatRegisterType) for reg in regs)
regs = cast(tuple[IntRegisterType | FloatRegisterType], regs)
with self.available_registers.reserve_registers(regs):
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def verify_(self) -> None:
raise VerifyException(
"Expected as many upper bound operands as upper bound dimensions and symbols."
)
iter_types = tuple(op.type for op in self.inits)
iter_types = self.inits.types
if iter_types != self.result_types:
raise VerifyException(
"Expected all operands and result pairs to have matching types"
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/comb.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ def from_int_values(inputs: Sequence[SSAValue]) -> "ConcatOp | None":
return ConcatOp(inputs, IntegerType(sum_of_width))

def verify_(self) -> None:
sum_of_width = _get_sum_of_int_width([inp.type for inp in self.inputs])
sum_of_width = _get_sum_of_int_width(self.inputs.types)
assert sum_of_width is not None
assert isinstance(self.result.type, IntegerType)
if sum_of_width != self.result.type.width.data:
Expand Down Expand Up @@ -519,7 +519,7 @@ def print(self, printer: Printer):
printer.print(" ")
printer.print_list(self.inputs, printer.print_ssa_value)
printer.print(" : ")
printer.print_list([inp.type for inp in self.inputs], printer.print_attribute)
printer.print_list(self.inputs.types, printer.print_attribute)


@irdl_op_definition
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def print_arg(arg: SSAValue):
printer.print_list(self.dest, print_arg)
else:
printer.print(") -> (")
printer.print_list((r.type for r in self.res), printer.print_attribute)
printer.print_list(self.res.types, printer.print_attribute)

printer.print(") ")
printer.print("<")
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/csl/csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def verify_(self):

# verify that block args are of the right type for the provided params
for arg, param in zip(
[a.type for a in self.layout_module.block.args[4:]],
self.layout_module.block.arg_types[4:],
self.params,
strict=True,
):
Expand All @@ -275,7 +275,7 @@ def verify_(self):
# verify that params and yielded arguments are typed correctly
# these may be followed by input-output symbols which we cannot verify, therefore setting `strict=False`
for got, (name, exp) in zip(
[a.type for a in self.program_module.block.args[2:]],
self.program_module.block.arg_types[2:],
itertools.chain(
(
(param.key.data, cast(Attribute, param.type))
Expand Down
24 changes: 5 additions & 19 deletions xdsl/dialects/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,21 +438,13 @@ def verify_(self):
if m is None:
raise VerifyException("Machine definition does not exist.")

if not (
[operand.type for operand in self.inputs]
== [result for result in m.function_type.inputs]
and len(self.inputs) == len(m.function_type.inputs)
):
if self.inputs.types != tuple(result for result in m.function_type.inputs):
raise VerifyException(
"TriggerOp input types must be consistent with the machine "
+ str(m.sym_name)
)

if not (
[operand.type for operand in self.outputs]
== [result for result in m.function_type.outputs]
and len(self.outputs) == len(m.function_type.outputs)
):
if self.outputs.types != tuple(result for result in m.function_type.outputs):
raise VerifyException(
"TriggerOp output types must be consistent with the machine "
+ str(m.sym_name)
Expand Down Expand Up @@ -506,21 +498,15 @@ def __init__(
def verify_(self):
m = SymbolTable.lookup_symbol(self, self.machine)
if isinstance(m, MachineOp):
if not (
[operand.type for operand in self.inputs]
== [result for result in m.function_type.inputs]
and len(self.inputs) == len(m.function_type.inputs)
):
if self.inputs.types != tuple(result for result in m.function_type.inputs):
raise VerifyException(
"HWInstanceOp "
+ str(self.sym_name)
+ " input type must be consistent with the machine "
+ str(m.sym_name)
)
if not (
[operand.type for operand in self.outputs]
== [result for result in m.function_type.outputs]
and len(self.outputs) == len(m.function_type.outputs)
if self.outputs.types != tuple(
result for result in m.function_type.outputs
):
raise VerifyException(
"HWInstanceOp "
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def verify_(self) -> None:
assert isinstance(func_op, FuncOp)

function_return_types = func_op.function_type.outputs.data
return_types = tuple(arg.type for arg in self.arguments)
return_types = self.arguments.types
if function_return_types != return_types:
raise VerifyException(
"Expected arguments to have the same types as the function output types"
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ def __init__(self, operands: Sequence[SSAValue | Operation]):
def verify_(self) -> None:
op = self.parent_op()
if op is not None:
yield_type = tuple(o.type for o in self.values)
yield_type = self.values.types
result_type = op.result_types
if yield_type != result_type:
raise VerifyException(
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/hw.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,7 +1301,7 @@ def print(self, printer: Printer):
printer.print(" ")
printer.print_list(self.inputs, printer.print_operand)
printer.print(" : ")
printer.print_list((x.type for x in self.inputs), printer.print_attribute)
printer.print_list(self.inputs.types, printer.print_attribute)


HW = Dialect(
Expand Down
8 changes: 4 additions & 4 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,14 @@ def print(self, printer: Printer):
printer.print_string(" ins(")
printer.print_list(self.inputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((i.type for i in self.inputs), printer.print_attribute)
printer.print_list(self.inputs.types, printer.print_attribute)
printer.print_string(")")

if self.outputs:
printer.print_string(" outs(")
printer.print_list(self.outputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((o.type for o in self.outputs), printer.print_attribute)
printer.print_list(self.outputs.types, printer.print_attribute)
printer.print_string(")")

extra_attrs = self.attributes.copy()
Expand Down Expand Up @@ -503,14 +503,14 @@ def print(self, printer: Printer):
printer.print_string(" ins(")
printer.print_list(self.inputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((i.type for i in self.inputs), printer.print_attribute)
printer.print_list(self.inputs.types, printer.print_attribute)
printer.print_string(")")

if self.outputs:
printer.print_string(" outs(")
printer.print_list(self.outputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((o.type for o in self.outputs), printer.print_attribute)
printer.print_list(self.outputs.types, printer.print_attribute)
printer.print_string(")")

if extra_attrs and not self.PRINT_ATTRS_IN_FRONT:
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ class AllocaScopeReturnOp(IRDLOperation):

def verify_(self) -> None:
parent = cast(AllocaScopeOp, self.parent_op())
if any(op.type != res.type for op, res in zip(self.ops, parent.results)):
if self.ops.types != parent.result_types:
raise VerifyException(
"Expected operand types to match parent's return types."
)
Expand Down
8 changes: 4 additions & 4 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,14 @@ def print(self, printer: Printer):
printer.print_string(" ins(")
printer.print_list(self.inputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((i.type for i in self.inputs), printer.print_attribute)
printer.print_list(self.inputs.types, printer.print_attribute)
printer.print_string(")")

if self.outputs:
printer.print_string(" outs(")
printer.print_list(self.outputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((o.type for o in self.outputs), printer.print_attribute)
printer.print_list(self.outputs.types, printer.print_attribute)
printer.print_string(")")

if self.attributes:
Expand Down Expand Up @@ -532,14 +532,14 @@ def print(self, printer: Printer):
printer.print_string(" ins(")
printer.print_list(self.inputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((i.type for i in self.inputs), printer.print_attribute)
printer.print_list(self.inputs.types, printer.print_attribute)
printer.print_string(")")

if self.outputs:
printer.print_string(" outs(")
printer.print_list(self.outputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((o.type for o in self.outputs), printer.print_attribute)
printer.print_list(self.outputs.types, printer.print_attribute)
printer.print_string(")")

if self.inits:
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/snitch_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,14 +292,14 @@ def print(self, printer: Printer):
printer.print_string(" ins(")
printer.print_list(self.inputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((i.type for i in self.inputs), printer.print_attribute)
printer.print_list(self.inputs.types, printer.print_attribute)
printer.print_string(")")

if self.outputs:
printer.print_string(" outs(")
printer.print_list(self.outputs, printer.print_ssa_value)
printer.print_string(" : ")
printer.print_list((o.type for o in self.outputs), printer.print_attribute)
printer.print_list(self.outputs.types, printer.print_attribute)
printer.print_string(")")

if self.attributes:
Expand Down
8 changes: 3 additions & 5 deletions xdsl/dialects/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,15 +485,15 @@ def print_destination_operand(dest: SSAValue):

printer.print("(")
printer.print_list(
zip(self.region.block.args, self.args, (a.type for a in self.args)),
zip(self.region.block.args, self.args, self.args.types),
print_assign_argument,
)
if self.dest:
printer.print(") outs (")
printer.print_list(self.dest, print_destination_operand)
else:
printer.print(") -> (")
printer.print_list((r.type for r in self.res), printer.print_attribute)
printer.print_list(self.res.types, printer.print_attribute)
printer.print(") ")
printer.print_op_attributes(self.attributes, print_keyword=True)
printer.print_region(self.region, print_entry_block_args=False)
Expand Down Expand Up @@ -1433,9 +1433,7 @@ def get(res: Sequence[SSAValue | Operation]):

def verify_(self) -> None:
unroll_factor = self.unroll_factor
types = [
o.type.elem if isinstance(o.type, ResultType) else o.type for o in self.arg
]
types = [ot.elem if isinstance(ot, ResultType) else ot for ot in self.arg.types]
apply = cast(ApplyOp, self.parent_op())
if len(apply.res) > 0:
res_types = [
Expand Down
8 changes: 2 additions & 6 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,9 +552,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_list(
(o.type for o in getattr(op, self.name)), printer.print_attribute
)
printer.print_list(getattr(op, self.name).types, printer.print_attribute)
state.last_was_punctuation = False
state.should_emit_space = True

Expand Down Expand Up @@ -694,9 +692,7 @@ def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_list(
(r.type for r in getattr(op, self.name)), printer.print_attribute
)
printer.print_list(getattr(op, self.name).types, printer.print_attribute)
state.last_was_punctuation = False
state.should_emit_space = True

Expand Down
25 changes: 21 additions & 4 deletions xdsl/irdl/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,10 @@ def __init__(
self.constr = range_constr_coercion(attr)


VarOperand: TypeAlias = tuple[SSAValue, ...]
class VarOperand(tuple[Operand, ...]):
@property
def types(self):
return tuple(o.type for o in self)


@dataclass(init=False)
Expand Down Expand Up @@ -341,7 +344,10 @@ def __init__(
self.constr = range_constr_coercion(attr)


VarOpResult: TypeAlias = tuple[OpResult, ...]
class VarOpResult(tuple[OpResult, ...]):
@property
def types(self):
return tuple(r.type for r in self)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is great if this doesn't need an or [[]]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't actually needed where you used it 🙂 Happy to discuss it in chat if you're curious



@dataclass(init=False)
Expand Down Expand Up @@ -1339,7 +1345,12 @@ def get_operand_result_or_region(
construct: VarIRConstruct,
) -> (
None
| SSAValue
| Operand
| VarOperand
| OptOperand
| OpResult
| VarOpResult
| OptOpResult
| Sequence[SSAValue]
| Sequence[OpResult]
| Region
Expand Down Expand Up @@ -1374,7 +1385,13 @@ def get_operand_result_or_region(
return args[begin_arg]
if isinstance(defs[arg_def_idx][1], VariadicDef):
arg_size = variadic_sizes[previous_var_args]
return args[begin_arg : begin_arg + arg_size]
values = args[begin_arg : begin_arg + arg_size]
if isinstance(defs[arg_def_idx][1], OperandDef):
return VarOperand(cast(Sequence[Operand], values))
elif isinstance(defs[arg_def_idx][1], ResultDef):
return VarOpResult(cast(Sequence[OpResult], values))
else:
return values
else:
return args[begin_arg]

Expand Down
Loading
Loading