diff --git a/tests/dialects/test_irdl.py b/tests/dialects/test_irdl.py index 2b8e41104f..3913aa0ec4 100644 --- a/tests/dialects/test_irdl.py +++ b/tests/dialects/test_irdl.py @@ -3,7 +3,7 @@ import pytest -from xdsl.dialects.builtin import StringAttr, SymbolRefAttr, i32 +from xdsl.dialects.builtin import ArrayAttr, StringAttr, SymbolRefAttr, i32 from xdsl.dialects.irdl import ( AllOfOp, AnyOfOp, @@ -20,6 +20,7 @@ ResultsOp, TypeOp, ) +from xdsl.dialects.irdl.irdl import VariadicityArrayAttr, VariadicityAttr from xdsl.ir import Block, Region from xdsl.irdl import IRDLOperation, irdl_op_definition from xdsl.utils.exceptions import PyRDLOpDefinitionError @@ -46,16 +47,38 @@ def test_named_region_op_init( assert len(op.body.blocks) == 1 -@pytest.mark.parametrize("op_type", [ParametersOp, OperandsOp, ResultsOp]) -def test_parameters_init(op_type: type[ParametersOp | OperandsOp | ResultsOp]): +def test_parameter_op_init(): """ - Test __init__ of ParametersOp, OperandsOp, ResultsOp. + Test __init__ of ParametersOp. """ val1 = TestSSAValue(AttributeType()) val2 = TestSSAValue(AttributeType()) - op = op_type([val1, val2]) - op2 = op_type.create(operands=[val1, val2]) + op = ParametersOp([val1, val2]) + op2 = ParametersOp.create(operands=[val1, val2]) + + assert op.is_structurally_equivalent(op2) + + assert op.args == (val1, val2) + + +@pytest.mark.parametrize("op_type", [OperandsOp, ResultsOp]) +def test_parameters_init(op_type: type[OperandsOp | ResultsOp]): + """ + Test __init__ of OperandsOp, ResultsOp. + """ + + val1 = TestSSAValue(AttributeType()) + val2 = TestSSAValue(AttributeType()) + op = op_type([(VariadicityAttr.SINGLE, val1), (VariadicityAttr.OPTIONAL, val2)]) + op2 = op_type.create( + operands=[val1, val2], + attributes={ + "variadicity": VariadicityArrayAttr( + ArrayAttr((VariadicityAttr.SINGLE, VariadicityAttr.OPTIONAL)) + ) + }, + ) assert op.is_structurally_equivalent(op2) diff --git a/tests/filecheck/dialects/irdl/pyrdl-to-irdl/cmath-conversion.py b/tests/filecheck/dialects/irdl/pyrdl-to-irdl/cmath-conversion.py index 7b68562404..7f4bcda7b2 100644 --- a/tests/filecheck/dialects/irdl/pyrdl-to-irdl/cmath-conversion.py +++ b/tests/filecheck/dialects/irdl/pyrdl-to-irdl/cmath-conversion.py @@ -7,19 +7,19 @@ # CHECK: irdl.dialect @cmath { -# CHECK-NEXT: irdl.attribute @cmath.complex { +# CHECK-NEXT: irdl.attribute @complex { # CHECK-NEXT: %{{.*}} = irdl.any # CHECK-NEXT: irdl.parameters(%{{.*}}) # CHECK-NEXT: } -# CHECK-NEXT: irdl.operation @cmath.norm { +# CHECK-NEXT: irdl.operation @norm { # CHECK-NEXT: %{{.*}} = irdl.any # CHECK-NEXT: irdl.operands(%{{.*}}) # CHECK-NEXT: %{{.*}} = irdl.any # CHECK-NEXT: irdl.results(%{{.*}}) # CHECK-NEXT: } -# CHECK-NEXT: irdl.operation @cmath.mul { +# CHECK-NEXT: irdl.operation @mul { # CHECK-NEXT: %{{.*}} = irdl.any # CHECK-NEXT: %{{.*}} = irdl.any # CHECK-NEXT: irdl.operands(%{{.*}}, %{{.*}}) diff --git a/tests/filecheck/dialects/irdl/testd.irdl.mlir b/tests/filecheck/dialects/irdl/testd.irdl.mlir index 91faea5299..de63ddf370 100644 --- a/tests/filecheck/dialects/irdl/testd.irdl.mlir +++ b/tests/filecheck/dialects/irdl/testd.irdl.mlir @@ -133,5 +133,16 @@ builtin.module { %2 = irdl.any_of(%0, %1) irdl.results(%2, %2) } + + // CHECK: irdl.operation @variadicity { + // CHECK-NEXT: %{{.*}} = irdl.any + // CHECK-NEXT: irdl.operands(%{{.*}}, %{{.*}}, optional %{{.*}}, variadic %{{.*}}) + // CHECK-NEXT: irdl.results(%{{.*}}, %{{.*}}, optional %{{.*}}, variadic %{{.*}}) + // CHECK-NEXT: } + irdl.operation @variadicity { + %0 = irdl.any + irdl.operands(%0, single %0, optional %0, variadic %0) + irdl.results(%0, single %0, optional %0, variadic %0) + } } } diff --git a/xdsl/dialects/irdl/irdl.py b/xdsl/dialects/irdl/irdl.py index 954d7f920f..1c48f744cc 100644 --- a/xdsl/dialects/irdl/irdl.py +++ b/xdsl/dialects/irdl/irdl.py @@ -3,20 +3,24 @@ from __future__ import annotations from collections.abc import Mapping, Sequence +from typing import ClassVar -from xdsl.dialects.builtin import StringAttr, SymbolRefAttr +from xdsl.dialects.builtin import ArrayAttr, StringAttr, SymbolRefAttr from xdsl.ir import ( Attribute, Block, Dialect, + EnumAttribute, OpResult, ParametrizedAttribute, Region, + SpacedOpaqueSyntaxAttribute, SSAValue, TypeAttribute, ) from xdsl.irdl import ( IRDLOperation, + ParameterDef, VarOperand, attr_def, irdl_attr_definition, @@ -26,7 +30,7 @@ result_def, var_operand_def, ) -from xdsl.parser import Parser +from xdsl.parser import AttrParser, Parser from xdsl.printer import Printer from xdsl.traits import ( HasParent, @@ -35,12 +39,55 @@ SymbolTable, ) from xdsl.utils.exceptions import VerifyException +from xdsl.utils.str_enum import StrEnum ################################################################################ # Dialect, Operation, and Attribute definitions # ################################################################################ +class VariadicityEnum(StrEnum): + SINGLE = "single" + OPTIONAL = "optional" + VARIADIC = "variadic" + + +@irdl_attr_definition +class VariadicityAttr(EnumAttribute[VariadicityEnum], SpacedOpaqueSyntaxAttribute): + name = "irdl.variadicity" + + SINGLE: ClassVar[VariadicityAttr] + OPTIONAL: ClassVar[VariadicityAttr] + VARIADIC: ClassVar[VariadicityAttr] + + +setattr(VariadicityAttr, "SINGLE", VariadicityAttr(VariadicityEnum.SINGLE)) +setattr(VariadicityAttr, "OPTIONAL", VariadicityAttr(VariadicityEnum.OPTIONAL)) +setattr(VariadicityAttr, "VARIADIC", VariadicityAttr(VariadicityEnum.VARIADIC)) + + +@irdl_attr_definition +class VariadicityArrayAttr(ParametrizedAttribute, SpacedOpaqueSyntaxAttribute): + name = "irdl.variadicity_array" + + value: ParameterDef[ArrayAttr[VariadicityAttr]] + + def __init__(self, variadicities: ArrayAttr[VariadicityAttr]) -> None: + super().__init__((variadicities,)) + + @classmethod + def parse_parameters(cls, parser: AttrParser) -> tuple[ArrayAttr[VariadicityAttr]]: + data = parser.parse_comma_separated_list( + AttrParser.Delimiter.SQUARE, lambda: VariadicityAttr.parse_parameter(parser) + ) + return (ArrayAttr(VariadicityAttr(x) for x in data),) + + def print_parameters(self, printer: Printer) -> None: + printer.print_string("[") + printer.print_list(self.value, lambda var: var.print_parameter(printer)) + printer.print_string("]") + + @irdl_attr_definition class AttributeType(ParametrizedAttribute, TypeAttribute): """Type of a attribute handle.""" @@ -115,6 +162,24 @@ def qualified_name(self): return f"{dialect_op.sym_name.data}.{self.sym_name.data}" +@irdl_op_definition +class CPredOp(IRDLOperation): + """Constraints an attribute using a C++ predicate""" + + name = "irdl.c_pred" + + pred = attr_def(StringAttr) + + output: OpResult = result_def(AttributeType()) + + assembly_format = "$pred attr-dict" + + def __init__(self, pred: str | StringAttr): + if isinstance(pred, str): + pred = StringAttr(pred) + super().__init__(attributes={"pred": pred}, result_types=[AttributeType()]) + + @irdl_op_definition class AttributeOp(IRDLOperation): """An attribute definition.""" @@ -215,6 +280,23 @@ def qualified_name(self): return f"{dialect_op.sym_name.data}.{self.sym_name.data}" +def _parse_argument(parser: Parser) -> tuple[VariadicityAttr, SSAValue]: + variadicity = parser.parse_optional_str_enum(VariadicityEnum) + if variadicity is None: + variadicity = VariadicityEnum.SINGLE + + arg = parser.parse_operand() + + return (VariadicityAttr(variadicity), arg) + + +def _print_argument(printer: Printer, data: tuple[VariadicityAttr, SSAValue]) -> None: + variadicity = data[0].data + if variadicity != VariadicityEnum.SINGLE: + printer.print(variadicity, " ") + printer.print(data[1]) + + @irdl_op_definition class OperandsOp(IRDLOperation): """An operation operand definition.""" @@ -223,21 +305,36 @@ class OperandsOp(IRDLOperation): args: VarOperand = var_operand_def(AttributeType) + variadicity = attr_def(VariadicityArrayAttr) + traits = frozenset([HasParent(OperationOp)]) - def __init__(self, args: Sequence[SSAValue]): - super().__init__(operands=[args]) + def __init__(self, args: Sequence[tuple[VariadicityAttr, SSAValue] | SSAValue]): + args_list = [ + (VariadicityAttr.SINGLE, x) if isinstance(x, SSAValue) else x for x in args + ] + operands = tuple(operand for _, operand in args_list) + attributes = { + "variadicity": VariadicityArrayAttr( + ArrayAttr(tuple(v for v, _ in args_list)) + ) + } + super().__init__(operands=[operands], attributes=attributes) @classmethod def parse(cls, parser: Parser) -> OperandsOp: args = parser.parse_comma_separated_list( - parser.Delimiter.PAREN, parser.parse_operand + parser.Delimiter.PAREN, lambda: _parse_argument(parser) ) return OperandsOp(args) def print(self, printer: Printer) -> None: printer.print("(") - printer.print_list(self.args, printer.print, ", ") + printer.print_list( + zip(self.variadicity.value, self.args), + lambda x: _print_argument(printer, x), + ", ", + ) printer.print(")") @@ -249,21 +346,36 @@ class ResultsOp(IRDLOperation): args: VarOperand = var_operand_def(AttributeType) + variadicity = attr_def(VariadicityArrayAttr) + traits = frozenset([HasParent(OperationOp)]) - def __init__(self, args: Sequence[SSAValue]): - super().__init__(operands=[args]) + def __init__(self, args: Sequence[tuple[VariadicityAttr, SSAValue] | SSAValue]): + args_list = [ + (VariadicityAttr.SINGLE, x) if isinstance(x, SSAValue) else x for x in args + ] + operands = [x[1] for x in args_list] + attributes = { + "variadicity": VariadicityArrayAttr( + ArrayAttr(tuple(v for v, _ in args_list)) + ) + } + super().__init__(operands=[operands], attributes=attributes) @classmethod def parse(cls, parser: Parser) -> ResultsOp: args = parser.parse_comma_separated_list( - parser.Delimiter.PAREN, parser.parse_operand + parser.Delimiter.PAREN, lambda: _parse_argument(parser) ) return ResultsOp(args) def print(self, printer: Printer) -> None: printer.print("(") - printer.print_list(self.args, printer.print, ", ") + printer.print_list( + zip(self.variadicity.value, self.args), + lambda x: _print_argument(printer, x), + ", ", + ) printer.print(")") @@ -460,6 +572,7 @@ def print(self, printer: Printer) -> None: [ DialectOp, TypeOp, + CPredOp, AttributeOp, BaseOp, ParametersOp, @@ -472,5 +585,9 @@ def print(self, printer: Printer) -> None: AnyOfOp, AllOfOp, ], - [AttributeType], + [ + AttributeType, + VariadicityAttr, + VariadicityArrayAttr, + ], ) diff --git a/xdsl/dialects/irdl/pyrdl_to_irdl.py b/xdsl/dialects/irdl/pyrdl_to_irdl.py index 943d2e04a9..405b57febd 100644 --- a/xdsl/dialects/irdl/pyrdl_to_irdl.py +++ b/xdsl/dialects/irdl/pyrdl_to_irdl.py @@ -51,7 +51,7 @@ def op_def_to_irdl(op: type[IRDLOperation]) -> OperationOp: if result_values: builder.insert(ResultsOp(result_values)) - return OperationOp(op_def.name, Region([block])) + return OperationOp(Dialect.split_name(op_def.name)[1], Region([block])) def attr_def_to_irdl( @@ -69,7 +69,7 @@ def attr_def_to_irdl( param_values.append(constraint_to_irdl(builder, param[1])) builder.insert(ParametersOp(param_values)) - return AttributeOp(attr_def.name, Region([block])) + return AttributeOp(Dialect.split_name(attr_def.name)[1], Region([block])) def dialect_to_irdl(dialect: Dialect, name: str) -> DialectOp: