Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (irdl) Update for mlir compatibility #3002

Merged
merged 11 commits into from
Aug 15, 2024
33 changes: 28 additions & 5 deletions tests/dialects/test_irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ResultsOp,
TypeOp,
)
from xdsl.dialects.irdl.irdl import VariadicityArrayAttr, VariadicityEnum
from xdsl.ir import Block, Region
from xdsl.irdl import IRDLOperation, irdl_op_definition
from xdsl.utils.exceptions import PyRDLOpDefinitionError
Expand All @@ -46,16 +47,38 @@ def test_named_region_op_init(
assert len(op.body.blocks) == 1


@pytest.mark.parametrize("op_type", [ParametersOp, OperandsOp, ResultsOp])
def test_parameters_init(op_type: type[ParametersOp | OperandsOp | ResultsOp]):
def test_parameter_op_init():
"""
Test __init__ of ParametersOp, OperandsOp, ResultsOp.
Test __init__ of ParametersOp.
"""

val1 = TestSSAValue(AttributeType())
val2 = TestSSAValue(AttributeType())
op = op_type([val1, val2])
op2 = op_type.create(operands=[val1, val2])
op = ParametersOp([val1, val2])
op2 = ParametersOp.create(operands=[val1, val2])

assert op.is_structurally_equivalent(op2)

assert op.args == (val1, val2)


@pytest.mark.parametrize("op_type", [OperandsOp, ResultsOp])
def test_parameters_init(op_type: type[OperandsOp | ResultsOp]):
"""
Test __init__ of OperandsOp, ResultsOp.
"""

val1 = TestSSAValue(AttributeType())
val2 = TestSSAValue(AttributeType())
op = op_type([(VariadicityEnum.Single, val1), (VariadicityEnum.Optional, val2)])
op2 = op_type.create(
operands=[val1, val2],
attributes={
"variadicity": VariadicityArrayAttr(
(VariadicityEnum.Single, VariadicityEnum.Optional)
)
},
)

assert op.is_structurally_equivalent(op2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,19 @@

# CHECK: irdl.dialect @cmath {

# CHECK-NEXT: irdl.attribute @cmath.complex {
# CHECK-NEXT: irdl.attribute @complex {
# CHECK-NEXT: %{{.*}} = irdl.any
# CHECK-NEXT: irdl.parameters(%{{.*}})
# CHECK-NEXT: }

# CHECK-NEXT: irdl.operation @cmath.norm {
# CHECK-NEXT: irdl.operation @norm {
# CHECK-NEXT: %{{.*}} = irdl.any
# CHECK-NEXT: irdl.operands(%{{.*}})
# CHECK-NEXT: %{{.*}} = irdl.any
# CHECK-NEXT: irdl.results(%{{.*}})
# CHECK-NEXT: }

# CHECK-NEXT: irdl.operation @cmath.mul {
# CHECK-NEXT: irdl.operation @mul {
# CHECK-NEXT: %{{.*}} = irdl.any
# CHECK-NEXT: %{{.*}} = irdl.any
# CHECK-NEXT: irdl.operands(%{{.*}}, %{{.*}})
Expand Down
11 changes: 11 additions & 0 deletions tests/filecheck/dialects/irdl/testd.irdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,16 @@ builtin.module {
%2 = irdl.any_of(%0, %1)
irdl.results(%2, %2)
}

// CHECK: irdl.operation @variadicity {
// CHECK-NEXT: %{{.*}} = irdl.any
// CHECK-NEXT: irdl.operands(%{{.*}}, %{{.*}}, optional %{{.*}}, variadic %{{.*}})
// CHECK-NEXT: irdl.results(%{{.*}}, %{{.*}}, optional %{{.*}}, variadic %{{.*}})
// CHECK-NEXT: }
irdl.operation @variadicity {
%0 = irdl.any
irdl.operands(%0, single %0, optional %0, variadic %0)
irdl.results(%0, single %0, optional %0, variadic %0)
}
}
}
128 changes: 116 additions & 12 deletions xdsl/dialects/irdl/irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,24 @@

from __future__ import annotations

from collections.abc import Mapping, Sequence
from collections.abc import Iterable, Mapping, Sequence

from xdsl.dialects.builtin import StringAttr, SymbolRefAttr
from xdsl.dialects.builtin import ArrayAttr, StringAttr, SymbolRefAttr
from xdsl.ir import (
Attribute,
Block,
Dialect,
EnumAttribute,
OpResult,
ParametrizedAttribute,
Region,
SpacedOpaqueSyntaxAttribute,
SSAValue,
TypeAttribute,
)
from xdsl.irdl import (
IRDLOperation,
ParameterDef,
VarOperand,
attr_def,
irdl_attr_definition,
Expand All @@ -26,7 +29,7 @@
result_def,
var_operand_def,
)
from xdsl.parser import Parser
from xdsl.parser import AttrParser, Parser
from xdsl.printer import Printer
from xdsl.traits import (
HasParent,
Expand All @@ -35,12 +38,47 @@
SymbolTable,
)
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.str_enum import StrEnum

################################################################################
# Dialect, Operation, and Attribute definitions #
################################################################################


class VariadicityEnum(StrEnum):
Single = "single"
Optional = "optional"
Variadic = "variadic"
alexarice marked this conversation as resolved.
Show resolved Hide resolved


@irdl_attr_definition
class VariadicityAttr(EnumAttribute[VariadicityEnum], SpacedOpaqueSyntaxAttribute):
name = "irdl.variadicity"
Comment on lines +56 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class VariadicityAttr(EnumAttribute[VariadicityEnum], SpacedOpaqueSyntaxAttribute):
name = "irdl.variadicity"
class VariadicityAttr(EnumAttribute[VariadicityEnum], SpacedOpaqueSyntaxAttribute):
name = "irdl.variadicity"
SINGLE = VariadicityAttr(VariadicityEnum.SINGLE)

etc



@irdl_attr_definition
class VariadicityArrayAttr(ParametrizedAttribute, SpacedOpaqueSyntaxAttribute):
name = "irdl.variadicity_array"

value: ParameterDef[ArrayAttr[VariadicityAttr]]

def __init__(self, variadicities: Iterable[VariadicityEnum]) -> None:
array_attr = ArrayAttr(tuple(VariadicityAttr(x) for x in variadicities))
super().__init__((array_attr,))

@classmethod
def parse_parameters(cls, parser: AttrParser) -> tuple[ArrayAttr[VariadicityAttr]]:
data = parser.parse_comma_separated_list(
AttrParser.Delimiter.SQUARE, lambda: VariadicityAttr.parse_parameter(parser)
)
return (ArrayAttr(VariadicityAttr(x) for x in data),)

def print_parameters(self, printer: Printer) -> None:
printer.print_string("[")
printer.print_list(self.value, lambda var: var.print_parameter(printer))
printer.print_string("]")


@irdl_attr_definition
class AttributeType(ParametrizedAttribute, TypeAttribute):
"""Type of a attribute handle."""
Expand Down Expand Up @@ -115,6 +153,24 @@ def qualified_name(self):
return f"{dialect_op.sym_name.data}.{self.sym_name.data}"


@irdl_op_definition
class CPredOp(IRDLOperation):
"""Constraints an attribute using a C++ predicate"""

name = "irdl.c_pred"

pred = attr_def(StringAttr)

output: OpResult = result_def(AttributeType())

assembly_format = "$pred attr-dict"

def __init__(self, pred: str | StringAttr):
if isinstance(pred, str):
pred = StringAttr(pred)
super().__init__(attributes={"pred": pred}, result_types=[AttributeType()])

alexarice marked this conversation as resolved.
Show resolved Hide resolved

@irdl_op_definition
class AttributeOp(IRDLOperation):
"""An attribute definition."""
Expand Down Expand Up @@ -215,6 +271,23 @@ def qualified_name(self):
return f"{dialect_op.sym_name.data}.{self.sym_name.data}"


def _parse_argument(parser: Parser) -> tuple[VariadicityEnum, SSAValue]:
variadicity = parser.parse_optional_str_enum(VariadicityEnum)
if variadicity is None:
variadicity = VariadicityEnum.Single

arg = parser.parse_operand()

return (variadicity, arg)


def _print_argument(printer: Printer, data: tuple[VariadicityAttr, SSAValue]) -> None:
variadicity = data[0].data
if variadicity != VariadicityEnum.Single:
printer.print(variadicity, " ")
printer.print(data[1])


@irdl_op_definition
class OperandsOp(IRDLOperation):
"""An operation operand definition."""
Expand All @@ -223,21 +296,34 @@ class OperandsOp(IRDLOperation):

args: VarOperand = var_operand_def(AttributeType)

variadicity = attr_def(VariadicityArrayAttr)

traits = frozenset([HasParent(OperationOp)])

def __init__(self, args: Sequence[SSAValue]):
super().__init__(operands=[args])
def __init__(self, args: Sequence[tuple[VariadicityEnum, SSAValue] | SSAValue]):
args_list = [
(VariadicityEnum.Single, x) if isinstance(x, SSAValue) else x for x in args
]
operands = tuple(operand for _, operand in args_list)
attributes = {
"variadicity": VariadicityArrayAttr(tuple(v for v, _ in args_list))
}
super().__init__(operands=[operands], attributes=attributes)

@classmethod
def parse(cls, parser: Parser) -> OperandsOp:
args = parser.parse_comma_separated_list(
parser.Delimiter.PAREN, parser.parse_operand
parser.Delimiter.PAREN, lambda: _parse_argument(parser)
)
return OperandsOp(args)

def print(self, printer: Printer) -> None:
printer.print("(")
printer.print_list(self.args, printer.print, ", ")
printer.print_list(
zip(self.variadicity.value, self.args),
lambda x: _print_argument(printer, x),
", ",
)
printer.print(")")


Expand All @@ -249,21 +335,34 @@ class ResultsOp(IRDLOperation):

args: VarOperand = var_operand_def(AttributeType)

variadicity = attr_def(VariadicityArrayAttr)

traits = frozenset([HasParent(OperationOp)])

def __init__(self, args: Sequence[SSAValue]):
super().__init__(operands=[args])
def __init__(self, args: Sequence[tuple[VariadicityEnum, SSAValue] | SSAValue]):
args_list = [
(VariadicityEnum.Single, x) if isinstance(x, SSAValue) else x for x in args
]
operands = [x[1] for x in args_list]
attributes = {
"variadicity": VariadicityArrayAttr(map(lambda x: x[0], args_list))
}
super().__init__(operands=[operands], attributes=attributes)

@classmethod
def parse(cls, parser: Parser) -> ResultsOp:
args = parser.parse_comma_separated_list(
parser.Delimiter.PAREN, parser.parse_operand
parser.Delimiter.PAREN, lambda: _parse_argument(parser)
)
return ResultsOp(args)

def print(self, printer: Printer) -> None:
printer.print("(")
printer.print_list(self.args, printer.print, ", ")
printer.print_list(
zip(self.variadicity.value, self.args),
lambda x: _print_argument(printer, x),
", ",
)
printer.print(")")


Expand Down Expand Up @@ -460,6 +559,7 @@ def print(self, printer: Printer) -> None:
[
DialectOp,
TypeOp,
CPredOp,
AttributeOp,
BaseOp,
ParametersOp,
Expand All @@ -472,5 +572,9 @@ def print(self, printer: Printer) -> None:
AnyOfOp,
AllOfOp,
],
[AttributeType],
[
AttributeType,
VariadicityAttr,
VariadicityArrayAttr,
],
)
4 changes: 2 additions & 2 deletions xdsl/dialects/irdl/pyrdl_to_irdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def op_def_to_irdl(op: type[IRDLOperation]) -> OperationOp:
if result_values:
builder.insert(ResultsOp(result_values))

return OperationOp(op_def.name, Region([block]))
return OperationOp(Dialect.split_name(op_def.name)[1], Region([block]))


def attr_def_to_irdl(
Expand All @@ -69,7 +69,7 @@ def attr_def_to_irdl(
param_values.append(constraint_to_irdl(builder, param[1]))
builder.insert(ParametersOp(param_values))

return AttributeOp(attr_def.name, Region([block]))
return AttributeOp(Dialect.split_name(attr_def.name)[1], Region([block]))


def dialect_to_irdl(dialect: Dialect, name: str) -> DialectOp:
Expand Down
Loading