Skip to content

Commit

Permalink
Allow to print MLIR instead of xDSL (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiasgrosser authored Aug 24, 2022
1 parent 7aee0ad commit 302b12d
Show file tree
Hide file tree
Showing 12 changed files with 509 additions and 52 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Changelog

All changes should be documented in this file.

## [Unreleased]
### Added
- Changelog file
### Changed
- Rename `module` to `builtin.module`
14 changes: 7 additions & 7 deletions src/xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from dataclasses import dataclass
from typing import TypeAlias, List, cast, Type, Sequence, Optional

from xdsl.ir import (MLContext, TYPE_CHECKING, Data, ParametrizedAttribute,
Operation)
from xdsl.ir import (MLContext, TYPE_CHECKING, Data, MLIRType,
ParametrizedAttribute, Operation)
from xdsl.irdl import (irdl_attr_definition, attr_constr_coercion,
irdl_to_attr_constraint, irdl_op_definition, builder,
ParameterDef, SingleBlockRegionDef, TypeVar, Generic,
Expand Down Expand Up @@ -263,7 +263,7 @@ def from_type_list(types: List[Attribute]) -> TupleType:


@irdl_attr_definition
class VectorType(Generic[_VectorTypeElems], ParametrizedAttribute):
class VectorType(Generic[_VectorTypeElems], ParametrizedAttribute, MLIRType):
name = "vector"

shape: ParameterDef[ArrayAttr[AnyIntegerAttr]]
Expand Down Expand Up @@ -305,7 +305,7 @@ def from_params(


@irdl_attr_definition
class TensorType(Generic[_TensorTypeElems], ParametrizedAttribute):
class TensorType(Generic[_TensorTypeElems], ParametrizedAttribute, MLIRType):
name = "tensor"

shape: ParameterDef[ArrayAttr[AnyIntegerAttr]]
Expand Down Expand Up @@ -389,14 +389,14 @@ def tensor_from_list(


@irdl_attr_definition
class Float32Type(ParametrizedAttribute):
class Float32Type(ParametrizedAttribute, MLIRType):
name = "f32"


f32 = Float32Type()


class Float64Type(ParametrizedAttribute):
class Float64Type(ParametrizedAttribute, MLIRType):
name = "f64"


Expand Down Expand Up @@ -457,7 +457,7 @@ class UnitAttr(ParametrizedAttribute):


@irdl_attr_definition
class FunctionType(ParametrizedAttribute):
class FunctionType(ParametrizedAttribute, MLIRType):
name = "fun"

inputs: ParameterDef[ArrayAttr[Attribute]]
Expand Down
4 changes: 2 additions & 2 deletions src/xdsl/dialects/cmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from dataclasses import dataclass

from xdsl.dialects.builtin import Float32Type, Float64Type
from xdsl.ir import MLContext
from xdsl.ir import MLContext, MLIRType
from xdsl.irdl import (irdl_op_definition, Operation, OperandDef,
irdl_attr_definition, ParameterDef, ParamAttrConstraint,
AnyOf, ResultDef, ParametrizedAttribute,
Expand All @@ -21,7 +21,7 @@ def __post_init__(self):


@irdl_attr_definition
class ComplexType(ParametrizedAttribute):
class ComplexType(ParametrizedAttribute, MLIRType):
name = "cmath.complex"
data: ParameterDef[Float64Type | Float32Type]

Expand Down
11 changes: 9 additions & 2 deletions src/xdsl/dialects/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from xdsl.irdl import (ParameterDef, AnyAttr, irdl_op_builder,
irdl_attr_definition, AttributeDef, OperandDef,
ResultDef, irdl_op_definition, builder)
from xdsl.ir import (MLContext, ParametrizedAttribute, Attribute, Operation)
from xdsl.ir import (MLContext, MLIRType, ParametrizedAttribute, Attribute,
Operation)
from xdsl.dialects.builtin import (StringAttr, ArrayOfConstraint, ArrayAttr,
IntegerAttr, IntegerType)

Expand All @@ -27,7 +28,7 @@ def __post_init__(self):


@irdl_attr_definition
class LLVMStructType(ParametrizedAttribute):
class LLVMStructType(ParametrizedAttribute, MLIRType):
name = "llvm.struct"

# An empty string refers to a struct without a name.
Expand All @@ -44,6 +45,12 @@ def from_type_list(types: list[Attribute]) -> LLVMStructType:
[StringAttr.from_str(""),
ArrayAttr.from_list(types)])

def print_parameters_as_mlir(self, printer: Printer) -> None:
assert self.struct_name.data == ""
printer.print("<(")
printer.print_list(self.types.data, printer.print_attribute)
printer.print(")>")


@irdl_op_definition
class LLVMExtractValue(Operation):
Expand Down
4 changes: 2 additions & 2 deletions src/xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from xdsl.dialects.builtin import (IntegerAttr, IndexType, ArrayAttr,
IntegerType, FlatSymbolRefAttr, StringAttr,
DenseIntOrFPElementsAttr)
from xdsl.ir import Operation, SSAValue, MLContext
from xdsl.ir import MLIRType, Operation, SSAValue, MLContext
from xdsl.irdl import (irdl_attr_definition, irdl_op_definition, builder,
ParameterDef, Generic, Attribute, ParametrizedAttribute,
AnyAttr, OperandDef, VarOperandDef, ResultDef,
Expand Down Expand Up @@ -36,7 +36,7 @@ def __post_init__(self):


@irdl_attr_definition
class MemRefType(Generic[_MemRefTypeElement], ParametrizedAttribute):
class MemRefType(Generic[_MemRefTypeElement], ParametrizedAttribute, MLIRType):
name = "memref"

shape: ParameterDef[ArrayAttr[AnyIntegerAttr]]
Expand Down
26 changes: 26 additions & 0 deletions src/xdsl/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,22 @@ def __hash__(self) -> int:
return hash(id(self))


@dataclass
class MLIRType:
"""
A class representing an MLIR type.
This class should only be inherited by classes inheriting Attribute.
This class is only used for printing attributes in the MLIR format,
inheriting this class prefix the attribute by `!` instead of `#`.
"""

def __post_init__(self):
if not isinstance(self, Attribute):
raise TypeError(
"MLIRType should only be inherited by classes inheriting Attribute"
)


A = TypeVar('A', bound='Attribute')


Expand Down Expand Up @@ -221,6 +237,10 @@ def parse_parameter(parser: Parser) -> DataElement:
def print_parameter(data: DataElement, printer: Printer) -> None:
"""Print the attribute parameter."""

def print_parameter_as_mlir(self, printer: Printer) -> None:
"""Print the attribute parameter in MLIR format."""
self.print_parameter(self.data, printer)


@dataclass(frozen=True)
class ParametrizedAttribute(Attribute):
Expand All @@ -234,6 +254,12 @@ def irdl_definition(cls) -> ParamAttrDef:
"""Get the IRDL attribute definition."""
...

def print_parameters_as_mlir(self, printer: Printer) -> None:
if len(self.parameters) != 0:
printer.print("<")
printer.print_list(self.parameters, printer.print_attribute)
printer.print(">")


@dataclass
class Operation:
Expand Down
Loading

0 comments on commit 302b12d

Please sign in to comment.