diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000000..23bc34a6e0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,9 @@ +# Changelog + +All changes should be documented in this file. + +## [Unreleased] +### Added +- Changelog file +### Changed +- Rename `module` to `builtin.module` \ No newline at end of file diff --git a/src/xdsl/dialects/builtin.py b/src/xdsl/dialects/builtin.py index 60701c3c41..c84c893905 100644 --- a/src/xdsl/dialects/builtin.py +++ b/src/xdsl/dialects/builtin.py @@ -3,8 +3,8 @@ from dataclasses import dataclass from typing import TypeAlias, List, cast, Type, Sequence, Optional -from xdsl.ir import (MLContext, TYPE_CHECKING, Data, ParametrizedAttribute, - Operation) +from xdsl.ir import (MLContext, TYPE_CHECKING, Data, MLIRType, + ParametrizedAttribute, Operation) from xdsl.irdl import (irdl_attr_definition, attr_constr_coercion, irdl_to_attr_constraint, irdl_op_definition, builder, ParameterDef, SingleBlockRegionDef, TypeVar, Generic, @@ -263,7 +263,7 @@ def from_type_list(types: List[Attribute]) -> TupleType: @irdl_attr_definition -class VectorType(Generic[_VectorTypeElems], ParametrizedAttribute): +class VectorType(Generic[_VectorTypeElems], ParametrizedAttribute, MLIRType): name = "vector" shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] @@ -305,7 +305,7 @@ def from_params( @irdl_attr_definition -class TensorType(Generic[_TensorTypeElems], ParametrizedAttribute): +class TensorType(Generic[_TensorTypeElems], ParametrizedAttribute, MLIRType): name = "tensor" shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] @@ -389,14 +389,14 @@ def tensor_from_list( @irdl_attr_definition -class Float32Type(ParametrizedAttribute): +class Float32Type(ParametrizedAttribute, MLIRType): name = "f32" f32 = Float32Type() -class Float64Type(ParametrizedAttribute): +class Float64Type(ParametrizedAttribute, MLIRType): name = "f64" @@ -457,7 +457,7 @@ class UnitAttr(ParametrizedAttribute): @irdl_attr_definition -class FunctionType(ParametrizedAttribute): +class FunctionType(ParametrizedAttribute, MLIRType): name = "fun" inputs: ParameterDef[ArrayAttr[Attribute]] diff --git a/src/xdsl/dialects/cmath.py b/src/xdsl/dialects/cmath.py index 442f08e1f3..d9d05c1b55 100644 --- a/src/xdsl/dialects/cmath.py +++ b/src/xdsl/dialects/cmath.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from xdsl.dialects.builtin import Float32Type, Float64Type -from xdsl.ir import MLContext +from xdsl.ir import MLContext, MLIRType from xdsl.irdl import (irdl_op_definition, Operation, OperandDef, irdl_attr_definition, ParameterDef, ParamAttrConstraint, AnyOf, ResultDef, ParametrizedAttribute, @@ -21,7 +21,7 @@ def __post_init__(self): @irdl_attr_definition -class ComplexType(ParametrizedAttribute): +class ComplexType(ParametrizedAttribute, MLIRType): name = "cmath.complex" data: ParameterDef[Float64Type | Float32Type] diff --git a/src/xdsl/dialects/llvm.py b/src/xdsl/dialects/llvm.py index 33e4abdeb7..ee1c11be6b 100644 --- a/src/xdsl/dialects/llvm.py +++ b/src/xdsl/dialects/llvm.py @@ -5,7 +5,8 @@ from xdsl.irdl import (ParameterDef, AnyAttr, irdl_op_builder, irdl_attr_definition, AttributeDef, OperandDef, ResultDef, irdl_op_definition, builder) -from xdsl.ir import (MLContext, ParametrizedAttribute, Attribute, Operation) +from xdsl.ir import (MLContext, MLIRType, ParametrizedAttribute, Attribute, + Operation) from xdsl.dialects.builtin import (StringAttr, ArrayOfConstraint, ArrayAttr, IntegerAttr, IntegerType) @@ -27,7 +28,7 @@ def __post_init__(self): @irdl_attr_definition -class LLVMStructType(ParametrizedAttribute): +class LLVMStructType(ParametrizedAttribute, MLIRType): name = "llvm.struct" # An empty string refers to a struct without a name. @@ -44,6 +45,12 @@ def from_type_list(types: list[Attribute]) -> LLVMStructType: [StringAttr.from_str(""), ArrayAttr.from_list(types)]) + def print_parameters_as_mlir(self, printer: Printer) -> None: + assert self.struct_name.data == "" + printer.print("<(") + printer.print_list(self.types.data, printer.print_attribute) + printer.print(")>") + @irdl_op_definition class LLVMExtractValue(Operation): diff --git a/src/xdsl/dialects/memref.py b/src/xdsl/dialects/memref.py index 3c100e9c8d..832518e902 100644 --- a/src/xdsl/dialects/memref.py +++ b/src/xdsl/dialects/memref.py @@ -6,7 +6,7 @@ from xdsl.dialects.builtin import (IntegerAttr, IndexType, ArrayAttr, IntegerType, FlatSymbolRefAttr, StringAttr, DenseIntOrFPElementsAttr) -from xdsl.ir import Operation, SSAValue, MLContext +from xdsl.ir import MLIRType, Operation, SSAValue, MLContext from xdsl.irdl import (irdl_attr_definition, irdl_op_definition, builder, ParameterDef, Generic, Attribute, ParametrizedAttribute, AnyAttr, OperandDef, VarOperandDef, ResultDef, @@ -36,7 +36,7 @@ def __post_init__(self): @irdl_attr_definition -class MemRefType(Generic[_MemRefTypeElement], ParametrizedAttribute): +class MemRefType(Generic[_MemRefTypeElement], ParametrizedAttribute, MLIRType): name = "memref" shape: ParameterDef[ArrayAttr[AnyIntegerAttr]] diff --git a/src/xdsl/ir.py b/src/xdsl/ir.py index e14ce03c00..7bdb755f0d 100644 --- a/src/xdsl/ir.py +++ b/src/xdsl/ir.py @@ -173,6 +173,22 @@ def __hash__(self) -> int: return hash(id(self)) +@dataclass +class MLIRType: + """ + A class representing an MLIR type. + This class should only be inherited by classes inheriting Attribute. + This class is only used for printing attributes in the MLIR format, + inheriting this class prefix the attribute by `!` instead of `#`. + """ + + def __post_init__(self): + if not isinstance(self, Attribute): + raise TypeError( + "MLIRType should only be inherited by classes inheriting Attribute" + ) + + A = TypeVar('A', bound='Attribute') @@ -221,6 +237,10 @@ def parse_parameter(parser: Parser) -> DataElement: def print_parameter(data: DataElement, printer: Printer) -> None: """Print the attribute parameter.""" + def print_parameter_as_mlir(self, printer: Printer) -> None: + """Print the attribute parameter in MLIR format.""" + self.print_parameter(self.data, printer) + @dataclass(frozen=True) class ParametrizedAttribute(Attribute): @@ -234,6 +254,12 @@ def irdl_definition(cls) -> ParamAttrDef: """Get the IRDL attribute definition.""" ... + def print_parameters_as_mlir(self, printer: Printer) -> None: + if len(self.parameters) != 0: + printer.print("<") + printer.print_list(self.parameters, printer.print_attribute) + printer.print(">") + @dataclass class Operation: diff --git a/src/xdsl/printer.py b/src/xdsl/printer.py index 3c77d99633..f22003784e 100644 --- a/src/xdsl/printer.py +++ b/src/xdsl/printer.py @@ -1,14 +1,22 @@ from __future__ import annotations +from frozenlist import FrozenList + from xdsl.diagnostic import Diagnostic -from typing import TypeVar, Any, Dict, Optional, List +from typing import TypeVar, Any, Dict, Optional, List, cast from dataclasses import dataclass, field -from xdsl.ir import (SSAValue, Block, Callable, Attribute, Region, Operation) -from xdsl.dialects.builtin import (FloatAttr, IntegerType, StringAttr, +from xdsl.dialects.memref import MemRefType +from xdsl.ir import (BlockArgument, MLIRType, SSAValue, Block, Callable, + Attribute, Region, Operation) +from xdsl.dialects.builtin import (AnyArrayAttr, AnyVectorType, + DenseIntOrFPElementsAttr, FloatAttr, + IndexType, IntegerType, StringAttr, FlatSymbolRefAttr, IntegerAttr, ArrayAttr, - ParametrizedAttribute, IntAttr, UnitAttr) + ParametrizedAttribute, IntAttr, TensorType, + UnitAttr, FunctionType, VectorType) from xdsl.irdl import Data +from enum import Enum indentNumSpaces = 2 @@ -16,11 +24,15 @@ @dataclass(eq=False, repr=False) class Printer: + class Target(Enum): + XDSL = 1 + MLIR = 2 + stream: Optional[Any] = field(default=None) print_generic_format: bool = field(default=False) - print_operand_types: bool = field(default=True) - print_result_types: bool = field(default=True) diagnostic: Diagnostic = field(default_factory=Diagnostic) + target: Target = field(default=Target.XDSL) + _indent: int = field(default=0, init=False) _ssa_values: Dict[SSAValue, str] = field(default_factory=dict, init=False) _ssa_names: Dict[str, int] = field(default_factory=dict, init=False) @@ -32,7 +44,7 @@ class Printer: _next_line_callback: List[Callable[[], None]] = field(default_factory=list, init=False) - def print(self, *argv) -> None: + def print(self, *argv: Any) -> None: for arg in argv: if isinstance(arg, str): self.print_string(arg) @@ -96,16 +108,20 @@ def _print_message(self, T = TypeVar('T') - def print_list(self, elems: List[T], print_fn: Callable[[T], - None]) -> None: + def print_list(self, + elems: FrozenList[T] | List[T], + print_fn: Callable[[T], None], + delimiter: str = ", ") -> None: if len(elems) == 0: return print_fn(elems[0]) for elem in elems[1:]: - self.print(", ") + self.print(delimiter) print_fn(elem) - def _print_new_line(self, indent=None, print_message=True) -> None: + def _print_new_line(self, + indent: int | None = None, + print_message: bool = True) -> None: indent = self._indent if indent is None else indent self.print("\n") if print_message: @@ -136,7 +152,7 @@ def _print_result_value(self, op: Operation, idx: int) -> None: name = self._get_new_valid_name_id() self._ssa_values[val] = name self.print("%s" % name) - if self.print_result_types: + if self.target == self.Target.XDSL: self.print(" : ") self.print_attribute(val.typ) @@ -174,7 +190,7 @@ def print_ssa_value(self, value: SSAValue) -> None: def _print_operand(self, operand: SSAValue) -> None: self.print_ssa_value(operand) - if self.print_operand_types: + if self.target == self.Target.XDSL: self.print(" : ") self.print_attribute(operand.typ) @@ -214,24 +230,32 @@ def _print_block_arg(self, arg: BlockArgument) -> None: def print_region(self, region: Region) -> None: if len(region.blocks) == 0: - self.print(" {}") + self.print("{}") return if len(region.blocks) == 1 and len(region.blocks[0].args) == 0: - self.print(" {") + self.print("{") self._print_ops(region.blocks[0].ops) self.print("}") return - self.print(" {") + self.print("{") self._print_new_line() for block in region.blocks: self._print_named_block(block) self.print("}") def print_regions(self, regions: List[Region]) -> None: - for region in regions: - self.print_region(region) + if len(regions) == 0: + return + + if self.target == self.Target.MLIR: + self.print(" (") + self.print_list(regions, self.print_region) + self.print(")") + else: + self.print(" ") + self.print_list(regions, self.print_region, delimiter=" ") def _print_operands(self, operands: FrozenList[SSAValue]) -> None: if len(operands) == 0: @@ -252,7 +276,10 @@ def print_attribute(self, attribute: Attribute) -> None: if isinstance(attribute, IntegerType): width = attribute.parameters[0] assert isinstance(width, IntAttr) - self.print(f'!i{width.data}') + if self.target == self.Target.MLIR: + self.print(f'i{width.data}') + else: + self.print(f'!i{width.data}') return if isinstance(attribute, StringAttr): @@ -286,6 +313,93 @@ def print_attribute(self, attribute: Attribute) -> None: self.print_string("]") return + # Function types have an alias in MLIR, but not in xDSL + if (isinstance(attribute, FunctionType) + and self.target == self.Target.MLIR): + self.print("(") + self.print_list(attribute.inputs.data, self.print_attribute) + self.print(") -> ") + outputs = attribute.outputs.data + if len(outputs) == 1 and not isinstance(outputs[0], FunctionType): + self.print_attribute(outputs[0]) + else: + self.print("(") + self.print_list(outputs, self.print_attribute) + self.print(")") + return + + # Dense element types have an alias in MLIR, but not in xDSL + if (isinstance(attribute, DenseIntOrFPElementsAttr) + and self.target == self.Target.MLIR): + + def print_dense_list(array: AnyArrayAttr): + + def print_one_elem(val: Attribute): + if isinstance(val, ArrayAttr): + print_dense_list(cast(AnyArrayAttr, val)) + elif isinstance(val, IntegerAttr): + self.print(val.value.data) + else: + raise Exception("unexpected attribute type " + "in DenseIntOrFPElementsAttr: " + f"{type(val)}") + + self.print("[") + self.print_list(array.data, print_one_elem) + self.print("]") + + self.print("dense<") + print_dense_list(attribute.data) + self.print("> : ") + self.print(attribute.type) + return + + # vector types have an alias in MLIR, but not in xDSL + if ((isinstance(attribute, VectorType) + or isinstance(attribute, TensorType)) + and self.target == self.Target.MLIR): + attribute = cast(AnyVectorType, attribute) + self.print( + "vector<" if isinstance(attribute, VectorType) else "tensor<") + self.print_list(attribute.shape.data, + lambda x: self.print(x.value.data), "x") + self.print("x", attribute.element_type) + self.print(">") + return + + # memref types have an alias in MLIR, but not in xDSL + if (isinstance(attribute, MemRefType) + and self.target == self.Target.MLIR): + attribute = cast(MemRefType[Attribute], attribute) + self.print("memref<") + self.print_list(attribute.shape.data, + lambda x: self.print(x.value.data), "x") + self.print("x", attribute.element_type) + self.print(">") + return + + # index type have an alias in MLIR, but not in xDSL + if (isinstance(attribute, IndexType) + and self.target == self.Target.MLIR): + self.print("index") + return + + if self.target == self.Target.MLIR: + # For the MLIR target, we may print differently some attributes + self.print("!" if isinstance(attribute, MLIRType) else "#") + self.print(attribute.name) + + if isinstance(attribute, Data): + self.print("<") + attribute.print_parameter_as_mlir(self) + self.print(">") + return + + assert isinstance(attribute, ParametrizedAttribute) + + attribute.print_parameters_as_mlir(self) + return + if isinstance(attribute, Data): self.print(f'!{attribute.name}<') attribute.print_parameter(attribute.data, self) @@ -303,9 +417,9 @@ def print_attribute(self, attribute: Attribute) -> None: def print_successors(self, successors: List[Block]): if len(successors) == 0: return - self.print(" (") + self.print(" (" if self.target == self.Target.XDSL else " [") self.print_list(successors, self.print_block_name) - self.print(")") + self.print(")" if self.target == self.Target.XDSL else "]") def _print_attr_string(self, attr_tuple: tuple[str, Attribute]) -> None: if isinstance(attr_tuple[1], UnitAttr): @@ -319,23 +433,51 @@ def _print_op_attributes(self, attributes: Dict[str, Attribute]) -> None: return self.print(" ") - self.print("[") + self.print("[" if self.target == Printer.Target.XDSL else "{") attribute_list = [p for p in attributes.items()] self.print_list(attribute_list, self._print_attr_string) - self.print("]") + self.print("]" if self.target == Printer.Target.XDSL else "}") def print_op_with_default_format(self, op: Operation) -> None: self._print_operands(op.operands) self.print_successors(op.successors) - self._print_op_attributes(op.attributes) - self.print_regions(op.regions) + + # We print attributes with the operation in xDSL. + if self.target == self.Target.XDSL: + self._print_op_attributes(op.attributes) + self.print_regions(op.regions) + else: + self.print_regions(op.regions) + self._print_op_attributes(op.attributes) + + # Print the operation type + if self.target == self.Target.MLIR: + self.print(" : (") + self.print_list(op.operands, + lambda operand: self.print_attribute(operand.typ)) + self.print(") -> ") + if len(op.results) == 0: + self.print("()") + elif len(op.results) == 1: + typ = op.results[0].typ + # Handle ambiguous case + if isinstance(typ, FunctionType): + self.print("(", typ, ")") + else: + self.print(typ) + else: + self.print("(") + self.print_list( + op.results, + lambda result: self.print_attribute(result.typ)) + self.print(")") def _print_op(self, op: Operation) -> None: begin_op_pos = self._current_column self._print_results(op) - if self.print_generic_format: + if self.print_generic_format or self.target == self.Target.MLIR: self.print(f'"{op.name}"') else: self.print(op.name) @@ -344,7 +486,7 @@ def _print_op(self, op: Operation) -> None: for message in self.diagnostic.op_messages[op]: self._add_message_on_next_line(message, begin_op_pos, end_op_pos) - if self.print_generic_format: + if self.print_generic_format or self.target == self.Target.MLIR: self.print_op_with_default_format(op) else: op.print(self) diff --git a/src/xdsl/xdsl_opt_main.py b/src/xdsl/xdsl_opt_main.py index e6f02dffa2..f22ca35889 100644 --- a/src/xdsl/xdsl_opt_main.py +++ b/src/xdsl/xdsl_opt_main.py @@ -142,6 +142,13 @@ def register_all_arguments(self, arg_parser: argparse.ArgumentParser): help="Prints the content of a triggered " "exception and exits with code 0") + arg_parser.add_argument( + "--use-mlir-bindings", + default=False, + action='store_true', + help="Use the MLIR bindings for printing MLIR. " + "This requires the MLIR Python bindings to be installed.") + def register_all_dialects(self): """ Register all dialects that can be used. @@ -197,9 +204,14 @@ def _output_xdsl(prog: ModuleOp, output: IOBase): printer.print_op(prog) def _output_mlir(prog: ModuleOp, output: IOBase): - converter = MLIRConverter(self.ctx) - mlir_module = converter.convert_module(prog) - print(mlir_module, file=output) + if self.args.use_mlir_bindings: + from xdsl.mlir_converter import MLIRConverter + converter = MLIRConverter(self.ctx) + mlir_module = converter.convert_module(prog) + print(mlir_module, file=output) + else: + printer = Printer(stream=output, target=Printer.Target.MLIR) + printer.print_op(prog) def _output_irdl(prog: ModuleOp, output: IOBase): irdl_to_mlir = IRDLPrinter(stream=output) @@ -207,13 +219,7 @@ def _output_irdl(prog: ModuleOp, output: IOBase): self.available_targets['xdsl'] = _output_xdsl self.available_targets['irdl'] = _output_irdl - - try: - from xdsl.mlir_converter import MLIRConverter - self.available_targets['mlir'] = _output_mlir - except ImportError: - # do not add mlir as target if import does not work - pass + self.available_targets['mlir'] = _output_mlir def setup_pipeline(self): """ diff --git a/tests/filecheck/mlir-conversion/llvm_test.xdsl b/tests/filecheck/mlir-conversion/llvm_test.xdsl index c1fb0f8b2e..a17d620727 100644 --- a/tests/filecheck/mlir-conversion/llvm_test.xdsl +++ b/tests/filecheck/mlir-conversion/llvm_test.xdsl @@ -1,4 +1,4 @@ -// RUN: xdsl-opt -t mlir %s | filecheck %s +// RUN: xdsl-opt -t mlir --use-mlir-bindings %s | filecheck %s builtin.module() { diff --git a/tests/filecheck/mlir-conversion/ops.xdsl b/tests/filecheck/mlir-conversion/ops.xdsl index 8ff03ee0ce..122a646614 100644 --- a/tests/filecheck/mlir-conversion/ops.xdsl +++ b/tests/filecheck/mlir-conversion/ops.xdsl @@ -1,4 +1,5 @@ -// RUN: xdsl-opt -t mlir %s | filecheck %s +// RUN: xdsl-opt -t mlir --use-mlir-bindings %s | filecheck %s +// RUN: xdsl-opt -t mlir --use-mlir-bindings %s | mlir-opt --mlir-print-op-generic > %t-1 && xdsl-opt -t mlir %s | mlir-opt --mlir-print-op-generic > %t-2 && diff %t-1 %t-2 // Tests if the non generic form can be printed. diff --git a/tests/filecheck/parser-printer/llvm_mlir_printer.xdsl b/tests/filecheck/parser-printer/llvm_mlir_printer.xdsl new file mode 100644 index 0000000000..b75ec9d480 --- /dev/null +++ b/tests/filecheck/parser-printer/llvm_mlir_printer.xdsl @@ -0,0 +1,17 @@ +// RUN: xdsl-opt -t mlir %s | filecheck %s + + +"builtin.module"() { + func.func() ["sym_name" = "struct_to_struct", "function_type" = !fun<[!llvm.struct<"", [!i32]>], [!llvm.struct<"", [!i32]>]>, "sym_visibility" = "private"] { + ^0(%0 : !llvm.struct<"", [!i32]>): + func.return(%0 : !llvm.struct<"", [!i32]>) + } +} + + +// CHECK: "builtin.module"() ({ +// CHECK-NEXT: "func.func"() ({ +// CHECK-NEXT: ^0(%0 : !llvm.struct<(i32)>): +// CHECK-NEXT: "func.return"(%0) : (!llvm.struct<(i32)>) -> () +// CHECK-NEXT: }) {"sym_name" = "struct_to_struct", "function_type" = (!llvm.struct<(i32)>) -> !llvm.struct<(i32)>, "sym_visibility" = "private"} : () -> () +// CHECK-NEXT:}) : () -> () diff --git a/tests/mlir_printer_test.py b/tests/mlir_printer_test.py new file mode 100644 index 0000000000..d9446d7404 --- /dev/null +++ b/tests/mlir_printer_test.py @@ -0,0 +1,249 @@ +from io import StringIO +from xdsl.ir import Attribute, Data, MLContext, MLIRType, Operation, ParametrizedAttribute +from xdsl.irdl import (AnyAttr, ParameterDef, RegionDef, VarOperandDef, + VarResultDef, irdl_attr_definition, irdl_op_definition) +from xdsl.parser import Parser +from xdsl.printer import Printer + +import re + + +@irdl_op_definition +class ModuleOp(Operation): + """Module operation. Redefined to not depend on the builtin dialect.""" + name = "module" + region = RegionDef() + + +@irdl_op_definition +class AnyOp(Operation): + """Operation only used for testing.""" + name = "any" + op = VarOperandDef(AnyAttr()) + res = VarResultDef(AnyAttr()) + + +@irdl_attr_definition +class DataAttr(Data[int]): + """Attribute only used for testing.""" + name = "data_attr" + + @staticmethod + def parse_parameter(parser: Parser) -> int: + return parser.parse_int_literal() + + @staticmethod + def print_parameter(data: int, printer: Printer) -> None: + printer.print(data) + + +@irdl_attr_definition +class DataType(Data[int], MLIRType): + """Attribute only used for testing.""" + name = "data_type" + + @staticmethod + def parse_parameter(parser: Parser) -> int: + return parser.parse_int_literal() + + @staticmethod + def print_parameter(data: int, printer: Printer) -> None: + printer.print(data) + + +@irdl_attr_definition +class ParamAttr(ParametrizedAttribute): + name = "param_attr" + + +@irdl_attr_definition +class ParamAttrWithParam(ParametrizedAttribute): + name = "param_attr_with_param" + data: ParameterDef[Attribute] + + +@irdl_attr_definition +class ParamType(ParametrizedAttribute, MLIRType): + name = "param_type" + + +@irdl_attr_definition +class DataAttrWithCustomFormat(Data[int]): + name = "data_custom_format" + + @staticmethod + def parse_parameter(parser: Parser) -> int: + return parser.parse_int_literal() + + @staticmethod + def print_parameter(data: int, printer: Printer) -> None: + printer.print(data) + + def print_parameter_as_mlir(self, printer: Printer) -> None: + printer.print(f"~{self.data}~") + + +@irdl_attr_definition +class ParamAttrWithCustomFormat(ParametrizedAttribute): + name = "param_custom_format" + param1: ParameterDef[ParamAttr] + + def print_parameters_as_mlir(self, printer: Printer) -> None: + printer.print(f"~~") + + +def print_as_mlir_and_compare(test_prog: str, expected: str): + ctx = MLContext() + + ctx.register_op(ModuleOp) + ctx.register_op(AnyOp) + ctx.register_attr(DataAttr) + ctx.register_attr(DataType) + ctx.register_attr(ParamAttr) + ctx.register_attr(ParamType) + ctx.register_attr(ParamAttrWithParam) + ctx.register_attr(DataAttrWithCustomFormat) + ctx.register_attr(ParamAttrWithCustomFormat) + + parser = Parser(ctx, test_prog) + module = parser.parse_op() + + res = StringIO() + printer = Printer(target=Printer.Target.MLIR, stream=res) + printer.print_op(module) + + # Remove all whitespace from the expected string. + regex = re.compile(r'[^\S]+') + assert (regex.sub("", res.getvalue()).strip() == \ + regex.sub("", expected).strip()) + + +def test_empty_op(): + """Test printing an empty operation.""" + print_as_mlir_and_compare( + """any()""", + """"any"() : () -> ()""", + ) + + +def test_data_attr(): + """Test printing an operation with a data attribute.""" + print_as_mlir_and_compare( + """any() [ "attr" = !data_attr<42> ]""", + """"any"() {"attr" = #data_attr<42>} : () -> ()""", + ) + + +def test_data_type(): + """Test printing an operation with a data type.""" + print_as_mlir_and_compare( + """%0 : !data_type<42> = any()""", + """%0 = "any"() : () -> !data_type<42>""", + ) + + +def test_param_attr(): + """Test printing an operation with a parametrized attribute.""" + print_as_mlir_and_compare( + """any() [ "attr" = !param_attr ]""", + """"any"() {"attr" = #param_attr } : () -> ()""", + ) + + +def test_param_type(): + """Test printing an operation with a parametrized type.""" + print_as_mlir_and_compare( + """%0 : !param_type = any()""", + """%0 = "any"() : () -> !param_type""", + ) + + +def test_param_attr_with_param(): + """ + Test printing an operation with a parametrized attribute with parameters. + """ + print_as_mlir_and_compare( + """any() [ "attr" = !param_attr_with_param ]""", + """"any"() {"attr" = #param_attr_with_param<#param_attr> } + : () -> ()""", + ) + + print_as_mlir_and_compare( + """any() [ "attr" = !param_attr_with_param ]""", + """"any"() {"attr" = #param_attr_with_param } + : () -> ()""", + ) + + +def test_op_with_region(): + """Test printing an operation with a region.""" + + print_as_mlir_and_compare( + """module() {}""", + """"module"() ({}) : () -> ()""", + ) + + +def test_op_with_results(): + """Test printing an operation with results.""" + + print_as_mlir_and_compare( + """%0 : !param_attr = any()""", + """%0 = "any"() : () -> #param_attr""", + ) + + print_as_mlir_and_compare( + """(%0 : !param_attr, %1 : !param_type) = any()""", + """(%0, %1) = "any"() : () -> (#param_attr, !param_type)""", + ) + + +def test_op_with_operands(): + """Test printing an operation with operands.""" + print_as_mlir_and_compare( + """module() { + %0 : !param_attr = any() + any(%0 : !param_attr) + }""", + """"module"() ({ + %0 = "any"() : () -> #param_attr + "any"(%0) : (#param_attr) -> () + }) : () -> () + """, + ) + + print_as_mlir_and_compare( + """module() { + %0 : !param_attr = any() + any(%0 : !param_attr, %0 : !param_attr) + }""", + """"module"() ({ + %0 = "any"() : () -> #param_attr + "any"(%0, %0) : (#param_attr, #param_attr) -> () + }) : () -> () + """, + ) + + +def test_op_with_attributes(): + """Test printing an operation with attributes.""" + print_as_mlir_and_compare( + """any() [ "attr" = !data_attr<42> ]""", + """"any"() {"attr" = #data_attr<42>} : () -> ()""", + ) + + +def test_data_custom_format(): + """Test printing an operation with a data attribute with custom format.""" + print_as_mlir_and_compare( + """any() [ "attr" = !data_custom_format<42> ]""", + """"any"() {"attr" = #data_custom_format<~42~>} : () -> ()""", + ) + + +def test_param_custom_format(): + """Test printing an operation with a param attribute with custom format.""" + print_as_mlir_and_compare( + """any() [ "attr" = !param_custom_format ]""", + """"any"() {"attr" = #param_custom_format~~} : () -> ()""", + )