Skip to content

Commit

Permalink
dialects: (memref_stream) add doc and library call support (#2928)
Browse files Browse the repository at this point in the history
We are using the `library_call` attribute to designate to which
accelerator we want to dispatch a certain linalg generic to. Right now,
this information gets lost when converting from linalg.generic to
memref_stream.generic, which is somewhat annoying. This PR adds support
for these generic attributes in the memref stream dialect.
  • Loading branch information
jorendumoulin authored Jul 23, 2024
1 parent d500729 commit 7edabb0
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
8 changes: 6 additions & 2 deletions tests/filecheck/dialects/memref_stream/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ memref_stream.generic {
affine_map<(d0, d1) -> (d1)>,
affine_map<(d0, d1) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel"]
iterator_types = ["parallel", "parallel"],
doc = "documentation string",
library_call = "library call"
} ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {hello = "world"} {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
memref_stream.yield %arg3 : f32
Expand All @@ -70,12 +72,14 @@ memref_stream.generic {
// CHECK-NEXT: affine_map<(d0, d1) -> (d0, d1)>
// CHECK-NEXT: ],
// CHECK-NEXT: iterator_types = ["parallel", "parallel"]
// CHECK-NEXT: doc = "documentation string",
// CHECK-NEXT: library_call = "library call"
// CHECK-NEXT: } ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {"hello" = "world"} {
// CHECK-NEXT: ^1(%arg3 : f32, %arg4 : f32, %arg5 : f32):
// CHECK-NEXT: memref_stream.yield %arg3 : f32
// CHECK-NEXT: }

// CHECK-GENERIC-NEXT: "memref_stream.generic"(%A, %B, %C) <{"bounds" = [3 : index, 2 : index], "init_indices" = [], "indexing_maps" = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type<parallel>, #memref_stream.iterator_type<parallel>], "operandSegmentSizes" = array<i32: 2, 1, 0>}> ({
// CHECK-GENERIC-NEXT: "memref_stream.generic"(%A, %B, %C) <{"bounds" = [3 : index, 2 : index], "init_indices" = [], "indexing_maps" = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type<parallel>, #memref_stream.iterator_type<parallel>], "doc" = "documentation string", "library_call" = "library call", "operandSegmentSizes" = array<i32: 2, 1, 0>}> ({
// CHECK-GENERIC-NEXT: ^1(%arg3 : f32, %arg4 : f32, %arg5 : f32):
// CHECK-GENERIC-NEXT: "memref_stream.yield"(%arg3) : (f32) -> ()
// CHECK-GENERIC-NEXT: }) {"hello" = "world"} : (memref<2xf32>, memref<3xf32>, memref<3x2xf64>) -> ()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ linalg.generic {
affine_map<() -> ()>,
affine_map<() -> ()>
],
iterator_types = []
iterator_types = [],
doc = "documentation string",
library_call = "library call"
} ins(%A, %B : memref<f64>, memref<f64>) outs(%C : memref<f64>) {
^0(%a : f64, %b : f64, %acc_old : f64):
%prod = arith.mulf %a, %b : f64
Expand All @@ -29,7 +31,9 @@ linalg.generic {
// CHECK-NEXT: affine_map<() -> ()>,
// CHECK-NEXT: affine_map<() -> ()>
// CHECK-NEXT: ],
// CHECK-NEXT: iterator_types = []
// CHECK-NEXT: iterator_types = [],
// CHECK-NEXT: doc = "documentation string",
// CHECK-NEXT: library_call = "library call"
// CHECK-NEXT: } ins(%A, %B : memref<f64>, memref<f64>) outs(%C : memref<f64>) {
// CHECK-NEXT: ^0(%a : f64, %b : f64, %acc_old : f64):
// CHECK-NEXT: %prod = arith.mulf %a, %b : f64
Expand Down
16 changes: 16 additions & 0 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
irdl_attr_definition,
irdl_op_definition,
operand_def,
opt_prop_def,
prop_def,
region_def,
var_operand_def,
Expand Down Expand Up @@ -397,6 +398,9 @@ class GenericOp(IRDLOperation):
Indices into the `outputs` that correspond to the initial values in `inits`.
"""

doc: StringAttr | None = opt_prop_def(StringAttr)
library_call: StringAttr | None = opt_prop_def(StringAttr)

body: Region = region_def("single_block")

traits = frozenset((GenericOpHasCanonicalizationPatternsTrait(),))
Expand All @@ -413,6 +417,8 @@ def __init__(
iterator_types: ArrayAttr[Attribute],
bounds: ArrayAttr[IntegerAttr[IndexType]],
init_indices: ArrayAttr[IntAttr],
doc: StringAttr | None = None,
library_call: StringAttr | None = None,
) -> None:
for m in indexing_maps:
if m.data.num_symbols:
Expand All @@ -426,6 +432,8 @@ def __init__(
"init_indices": init_indices,
"indexing_maps": indexing_maps,
"iterator_types": iterator_types,
"doc": doc,
"library_call": library_call,
},
regions=[body],
)
Expand Down Expand Up @@ -513,6 +521,12 @@ def print(self, printer: Printer):
lambda iterator_type: printer.print_string_literal(iterator_type.data),
)
printer.print_string("]")
if self.doc:
printer.print_string(",\ndoc = ")
printer.print_attribute(self.doc)
if self.library_call:
printer.print_string(",\nlibrary_call = ")
printer.print_attribute(self.library_call)
printer.print_string("\n}")

if self.inputs:
Expand Down Expand Up @@ -716,6 +730,8 @@ def parse(cls, parser: Parser) -> Self:
ArrayAttr(iterator_types),
bounds,
ArrayAttr(IntAttr(index) for index in init_indices),
doc,
library_call,
)
generic.attributes |= attrs
generic.attributes |= extra_attrs
Expand Down
2 changes: 2 additions & 0 deletions xdsl/transforms/convert_linalg_to_memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter) -> No
iterator_types,
bounds,
ArrayAttr(()),
op.doc,
op.library_call,
)
)

Expand Down

0 comments on commit 7edabb0

Please sign in to comment.