From 51c8ecb0c459d9ac86fb1bb4d8ac70a6767fe4b2 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Mon, 5 Feb 2024 10:30:17 +0100 Subject: [PATCH] add custom printer and parser for doc and library call (#1) --- .../filecheck/dialects/linalg/linalg_ops.mlir | 14 +++++++++++ xdsl/dialects/linalg.py | 25 +++++++++++++++++-- 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/tests/filecheck/dialects/linalg/linalg_ops.mlir b/tests/filecheck/dialects/linalg/linalg_ops.mlir index 2d5d3cbe09..1e7deb4db4 100644 --- a/tests/filecheck/dialects/linalg/linalg_ops.mlir +++ b/tests/filecheck/dialects/linalg/linalg_ops.mlir @@ -8,6 +8,11 @@ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) linalg.yield %arg3 : f32 } +linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], doc="a_docstring", library_call="a_library_call"} ins(%0 : f32) outs(%1 : memref<1x256xf32>) { +^bb0(%arg3: f32, %arg4: f32): + linalg.yield %arg3 : f32 +} + linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : f32) outs(%1 : memref<1x256xf32>) attrs = {hello="world"} { ^bb0(%arg3: f32, %arg4: f32): linalg.yield %arg3 : f32 @@ -19,6 +24,10 @@ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) // CHECK-NEXT: ^0(%{{.*}}: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %{{.*}} : f32 // CHECK-NEXT: } +// CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], doc = "a_docstring", library_call = "a_library_call"} ins(%0 : f32) outs(%1 : memref<1x256xf32>) { +// CHECK-NEXT: ^1(%arg3_1 : f32, %arg4_1 : f32): +// CHECK-NEXT: linalg.yield %arg3_1 : f32 +// CHECK-NEXT: } // CHECK-NEXT: linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : f32) outs(%1 : memref<1x256xf32>) attrs = {"hello" = "world"} { // CHECK-NEXT: ^{{.*}}(%{{.*}}: f32, %{{.*}}: f32): // CHECK-NEXT: linalg.yield %{{.*}} : f32 @@ -29,6 +38,11 @@ linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) // CHECK-GENERIC-NEXT: ^0(%{{.*}}: f32, %{{.*}}: f32): // CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f32) -> () // CHECK-GENERIC-NEXT: }) : (f32, memref<1x256xf32>) -> () +// CHECK-GENERIC-NEXT: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type, #linalg.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^1(%arg3_1 : f32, %arg4_1 : f32): +// CHECK-GENERIC-NEXT: "linalg.yield"(%arg3_1) : (f32) -> () +// CHECK-GENERIC-NEXT: }) {"doc" = "a_docstring", "library_call" = "a_library_call"} : (f32, memref<1x256xf32>) -> () + // CHECK-GENERIC: "linalg.generic"(%0, %1) <{"indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#linalg.iterator_type, #linalg.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}}: f32, %{{.*}}: f32): // CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f32) -> () diff --git a/xdsl/dialects/linalg.py b/xdsl/dialects/linalg.py index f45fef46a5..2f70b925b7 100644 --- a/xdsl/dialects/linalg.py +++ b/xdsl/dialects/linalg.py @@ -199,7 +199,14 @@ def print(self, printer: Printer): iterator_type.data.value ), ) - printer.print_string("]}") + printer.print_string("]") + if self.doc: + printer.print_string(", doc = ") + printer.print_attribute(self.doc) + if self.library_call: + printer.print_string(", library_call = ") + printer.print_attribute(self.library_call) + printer.print_string("}") if self.inputs: printer.print_string(" ins(") @@ -281,6 +288,20 @@ def parse(cls, parser: Parser) -> Self: attrs_end_pos, ) + if "doc" in attrs: + doc = attrs["doc"] + assert isinstance(doc, StringAttr) + del attrs["doc"] + else: + doc = None + + if "library_call" in attrs: + library_call = attrs["library_call"] + assert isinstance(library_call, StringAttr) + del attrs["library_call"] + else: + library_call = None + pos = parser.pos if parser.parse_optional_characters("ins"): parser.parse_punctuation("(") @@ -321,7 +342,7 @@ def parse(cls, parser: Parser) -> Self: body = parser.parse_region() - generic = cls(ins, outs, body, indexing_maps, iterator_types) + generic = cls(ins, outs, body, indexing_maps, iterator_types, doc, library_call) generic.attributes |= attrs generic.attributes |= extra_attrs