diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index e6f4587b49..f8ed76c7db 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -1175,8 +1175,18 @@ def from_strings(name: str, value: str, type: Attribute = NoneAttr()) -> OpaqueA return OpaqueAttr([StringAttr(name), StringAttr(value), type]) +class MemrefLayoutAttr(Attribute, ABC): + """ + Interface for any attribute acceptable as a memref layout. + """ + + name = "abstract.memref_layout_att" + + pass + + @irdl_attr_definition -class StridedLayoutAttr(ParametrizedAttribute): +class StridedLayoutAttr(MemrefLayoutAttr, ParametrizedAttribute): """ An attribute representing a strided layout of a shaped type. See https://mlir.llvm.org/docs/Dialects/Builtin/#stridedlayoutattr @@ -1218,7 +1228,7 @@ def __init__( @irdl_attr_definition -class AffineMapAttr(Data[AffineMap]): +class AffineMapAttr(MemrefLayoutAttr, Data[AffineMap]): """An Attribute containing an AffineMap object.""" name = "affine_map" @@ -1525,14 +1535,14 @@ class MemRefType( shape: ParameterDef[ArrayAttr[IntAttr]] element_type: ParameterDef[_MemRefTypeElement] - layout: ParameterDef[Attribute] + layout: ParameterDef[MemrefLayoutAttr | NoneAttr] memory_space: ParameterDef[Attribute] def __init__( self: MemRefType[_MemRefTypeElement], element_type: _MemRefTypeElement, shape: Iterable[int | IntAttr], - layout: Attribute = NoneAttr(), + layout: MemrefLayoutAttr | NoneAttr = NoneAttr(), memory_space: Attribute = NoneAttr(), ): shape = ArrayAttr( @@ -1561,7 +1571,7 @@ def get_element_type(self) -> _MemRefTypeElement: def from_element_type_and_shape( referenced_type: _MemRefTypeElement, shape: Iterable[int | AnyIntegerAttr], - layout: Attribute = NoneAttr(), + layout: MemrefLayoutAttr | NoneAttr = NoneAttr(), memory_space: Attribute = NoneAttr(), ) -> MemRefType[_MemRefTypeElement]: shape_int = [i if isinstance(i, int) else i.value.data for i in shape] @@ -1574,7 +1584,7 @@ def from_params( shape: ArrayAttr[AnyIntegerAttr] = ArrayAttr( [IntegerAttr.from_int_and_width(1, 64)] ), - layout: Attribute = NoneAttr(), + layout: MemrefLayoutAttr | NoneAttr = NoneAttr(), memory_space: Attribute = NoneAttr(), ) -> MemRefType[_MemRefTypeElement]: shape_int = [i.value.data for i in shape.data] diff --git a/xdsl/dialects/experimental/aie.py b/xdsl/dialects/experimental/aie.py index e9eab2e8c9..965d8df7dc 100644 --- a/xdsl/dialects/experimental/aie.py +++ b/xdsl/dialects/experimental/aie.py @@ -22,6 +22,7 @@ IntAttr, IntegerAttr, IntegerType, + MemrefLayoutAttr, MemRefType, ModuleOp, NoneAttr, @@ -163,7 +164,7 @@ class ObjectFIFO(Generic[AttributeInvT], ParametrizedAttribute): def from_element_type_and_shape( referenced_type: AttributeInvT, shape: Iterable[int | IntAttr], - layout: Attribute = NoneAttr(), + layout: MemrefLayoutAttr | NoneAttr = NoneAttr(), memory_space: Attribute = NoneAttr(), ) -> ObjectFIFO[AttributeInvT]: return ObjectFIFO( diff --git a/xdsl/dialects/memref.py b/xdsl/dialects/memref.py index 86394ee78f..f0de04545c 100644 --- a/xdsl/dialects/memref.py +++ b/xdsl/dialects/memref.py @@ -17,6 +17,7 @@ IntAttr, IntegerAttr, IntegerType, + MemrefLayoutAttr, MemRefType, NoneAttr, StridedLayoutAttr, @@ -167,7 +168,7 @@ def get( alignment: int | AnyIntegerAttr | None = None, shape: Iterable[int | IntAttr] | None = None, dynamic_sizes: Sequence[SSAValue | Operation] | None = None, - layout: Attribute = NoneAttr(), + layout: MemrefLayoutAttr | NoneAttr = NoneAttr(), memory_space: Attribute = NoneAttr(), ) -> Self: if shape is None: @@ -305,7 +306,7 @@ def get( alignment: int | AnyIntegerAttr | None = None, shape: Iterable[int | IntAttr] | None = None, dynamic_sizes: Sequence[SSAValue | Operation] | None = None, - layout: Attribute = NoneAttr(), + layout: MemrefLayoutAttr | NoneAttr = NoneAttr(), memory_space: Attribute = NoneAttr(), ) -> Alloca: if shape is None: diff --git a/xdsl/parser/attribute_parser.py b/xdsl/parser/attribute_parser.py index 6d79a4229e..216a698672 100644 --- a/xdsl/parser/attribute_parser.py +++ b/xdsl/parser/attribute_parser.py @@ -38,6 +38,7 @@ IntegerAttr, IntegerType, LocationAttr, + MemrefLayoutAttr, MemRefType, NoneAttr, NoneType, @@ -520,27 +521,18 @@ def _parse_memref_attrs( # layout is the second one if self.parse_optional_punctuation(",") is not None: memory_space = self.parse_attribute() + if not isinstance(memory_or_layout, MemrefLayoutAttr): + self.raise_error("Expected a MemRef layout attribute") return MemRefType(type, shape, memory_or_layout, memory_space) - # Otherwise, there is a single argument, so we check based on the - # attribute type. If we don't know, we return an error. - # MLIR base itself on the `MemRefLayoutAttrInterface`, which we do not - # support. + # If the argument is a MemrefLayoutAttr, use it as layout + if isinstance(memory_or_layout, MemrefLayoutAttr): + return MemRefType(type, shape, layout=memory_or_layout) - # If the argument is an integer, it is a memory space - if isa(memory_or_layout, AnyIntegerAttr): + # Otherwise, consider it as the memory space. + else: return MemRefType(type, shape, memory_space=memory_or_layout) - # We only accept strided layouts and affine_maps - if isa(memory_or_layout, StridedLayoutAttr) or ( - isinstance(memory_or_layout, UnregisteredAttr) - and memory_or_layout.attr_name.data == "affine_map" - ): - return MemRefType(type, shape, layout=memory_or_layout) - self.raise_error( - "Cannot decide if the given attribute " "is a layout or a memory space!" - ) - def _parse_vector_attrs(self) -> AnyVectorType: dims: list[int] = [] num_scalable_dims = 0