Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Implement MemrefLayoutAttr base class and use in parser. #2718

Merged
merged 6 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1175,8 +1175,18 @@ def from_strings(name: str, value: str, type: Attribute = NoneAttr()) -> OpaqueA
return OpaqueAttr([StringAttr(name), StringAttr(value), type])


class MemrefLayoutAttr(Attribute):
"""
Some explanation here
"""

name = "memref_layout_att"

pass


@irdl_attr_definition
class StridedLayoutAttr(ParametrizedAttribute):
class StridedLayoutAttr(MemrefLayoutAttr, ParametrizedAttribute):
"""
An attribute representing a strided layout of a shaped type.
See https://mlir.llvm.org/docs/Dialects/Builtin/#stridedlayoutattr
Expand Down Expand Up @@ -1218,7 +1228,7 @@ def __init__(


@irdl_attr_definition
class AffineMapAttr(Data[AffineMap]):
class AffineMapAttr(MemrefLayoutAttr, Data[AffineMap]):
"""An Attribute containing an AffineMap object."""

name = "affine_map"
Expand Down Expand Up @@ -1525,14 +1535,14 @@ class MemRefType(

shape: ParameterDef[ArrayAttr[IntAttr]]
element_type: ParameterDef[_MemRefTypeElement]
layout: ParameterDef[Attribute]
layout: ParameterDef[MemrefLayoutAttr | NoneAttr]
memory_space: ParameterDef[Attribute]

def __init__(
self: MemRefType[_MemRefTypeElement],
element_type: _MemRefTypeElement,
shape: Iterable[int | IntAttr],
layout: Attribute = NoneAttr(),
layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
):
shape = ArrayAttr(
Expand Down Expand Up @@ -1561,7 +1571,7 @@ def get_element_type(self) -> _MemRefTypeElement:
def from_element_type_and_shape(
referenced_type: _MemRefTypeElement,
shape: Iterable[int | AnyIntegerAttr],
layout: Attribute = NoneAttr(),
layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> MemRefType[_MemRefTypeElement]:
shape_int = [i if isinstance(i, int) else i.value.data for i in shape]
Expand All @@ -1574,7 +1584,7 @@ def from_params(
shape: ArrayAttr[AnyIntegerAttr] = ArrayAttr(
[IntegerAttr.from_int_and_width(1, 64)]
),
layout: Attribute = NoneAttr(),
layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> MemRefType[_MemRefTypeElement]:
shape_int = [i.value.data for i in shape.data]
Expand Down
3 changes: 2 additions & 1 deletion xdsl/dialects/experimental/aie.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
IntAttr,
IntegerAttr,
IntegerType,
MemrefLayoutAttr,
MemRefType,
ModuleOp,
NoneAttr,
Expand Down Expand Up @@ -163,7 +164,7 @@ class ObjectFIFO(Generic[AttributeInvT], ParametrizedAttribute):
def from_element_type_and_shape(
referenced_type: AttributeInvT,
shape: Iterable[int | IntAttr],
layout: Attribute = NoneAttr(),
layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> ObjectFIFO[AttributeInvT]:
return ObjectFIFO(
Expand Down
5 changes: 3 additions & 2 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
IntAttr,
IntegerAttr,
IntegerType,
MemrefLayoutAttr,
MemRefType,
NoneAttr,
StridedLayoutAttr,
Expand Down Expand Up @@ -167,7 +168,7 @@ def get(
alignment: int | AnyIntegerAttr | None = None,
shape: Iterable[int | IntAttr] | None = None,
dynamic_sizes: Sequence[SSAValue | Operation] | None = None,
layout: Attribute = NoneAttr(),
layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> Self:
if shape is None:
Expand Down Expand Up @@ -305,7 +306,7 @@ def get(
alignment: int | AnyIntegerAttr | None = None,
shape: Iterable[int | IntAttr] | None = None,
dynamic_sizes: Sequence[SSAValue | Operation] | None = None,
layout: Attribute = NoneAttr(),
layout: MemrefLayoutAttr | NoneAttr = NoneAttr(),
memory_space: Attribute = NoneAttr(),
) -> Alloca:
if shape is None:
Expand Down
17 changes: 7 additions & 10 deletions xdsl/parser/attribute_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
IntegerAttr,
IntegerType,
LocationAttr,
MemrefLayoutAttr,
MemRefType,
NoneAttr,
NoneType,
Expand Down Expand Up @@ -520,6 +521,8 @@ def _parse_memref_attrs(
# layout is the second one
if self.parse_optional_punctuation(",") is not None:
memory_space = self.parse_attribute()
if not isinstance(memory_or_layout, MemrefLayoutAttr):
self.raise_error("Expected a MemRef layout attribute!")
return MemRefType(type, shape, memory_or_layout, memory_space)

# Otherwise, there is a single argument, so we check based on the
Expand All @@ -528,18 +531,12 @@ def _parse_memref_attrs(
# support.

# If the argument is an integer, it is a memory space
if isa(memory_or_layout, AnyIntegerAttr):
return MemRefType(type, shape, memory_space=memory_or_layout)
if isa(memory_or_layout, MemrefLayoutAttr):
return MemRefType(type, shape, layout=memory_or_layout)

# We only accept strided layouts and affine_maps
if isa(memory_or_layout, StridedLayoutAttr) or (
isinstance(memory_or_layout, UnregisteredAttr)
and memory_or_layout.attr_name.data == "affine_map"
):
return MemRefType(type, shape, layout=memory_or_layout)
self.raise_error(
"Cannot decide if the given attribute " "is a layout or a memory space!"
)
else:
return MemRefType(type, shape, memory_space=memory_or_layout)

def _parse_vector_attrs(self) -> AnyVectorType:
dims: list[int] = []
Expand Down
Loading