Skip to content

Commit

Permalink
dialects: (affine) Add custom syntax to affine.apply (#2848)
Browse files Browse the repository at this point in the history
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
  • Loading branch information
superlopuh and zero9178 authored Jul 8, 2024
1 parent 3ec488a commit 38d2ef8
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 10 deletions.
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/affine/affine_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@
// CHECK-NEXT: "affine.store"(%value, %memref) <{"map" = affine_map<() -> (0, 0)>}> : (f64, memref<2x3xf64>) -> ()

%zero = "test.op"() : () -> index
%2 = "affine.apply"(%zero, %zero) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
%2 = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%zero)[%zero]
%min = "affine.min"(%zero) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
%same_value = "affine.load"(%memref, %zero, %zero) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

// CHECK: %zero = "test.op"() : () -> index
// CHECK-NEXT: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
// CHECK-NEXT: %{{.*}} = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%{{.*}})[%{{.*}}]
// CHECK-NEXT: %{{.*}} = "affine.min"(%{{.*}}) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
// CHECK-NEXT: %same_value = "affine.load"(%memref, %zero, %zero) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

Expand All @@ -78,7 +78,7 @@
}, {
"affine.yield"() : () -> ()
}) {"condition" = affine_set<() : (0 == 0)>} : () -> ()

func.return
}
// CHECK: func.func @empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@
// CHECK-NEXT: "affine.store"(%{{.*}}, %{{.*}}) <{"map" = affine_map<() -> (0, 0)>}> : (f64, memref<2x3xf64>) -> ()

%zero = "test.op"() : () -> index
%2 = "affine.apply"(%zero, %zero) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
%2 = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%zero)[%zero]
%min = "affine.min"(%zero) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
%same_value = "affine.load"(%memref, %zero, %zero) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

// CHECK: %{{.*}} = "test.op"() : () -> index
// CHECK-NEXT: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
// CHECK-NEXT: %{{.*}} = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%{{.*}})[%{{.*}}]
// CHECK-NEXT: %{{.*}} = "affine.min"(%{{.*}}) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
// CHECK-NEXT: %{{.*}} = "affine.load"(%{{.*}}, %{{.*}}, %{{.*}}) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

Expand All @@ -78,7 +78,7 @@
}, {
"affine.yield"() : () -> ()
}) {"condition" = affine_set<() : (0 == 0)>} : () -> ()

func.return
}
// CHECK: func.func @empty() {
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/transforms/convert_linalg_to_loops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ linalg.generic {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}) <{"map" = affine_map<(d0, d1) -> ((d0 + d1))>}> : (index, index) -> index
// CHECK-NEXT: %{{.*}} = affine.apply affine_map<(d0, d1) -> ((d0 + d1))>
// CHECK-NEXT: %{{.*}} = memref.load %G[%{{.*}}] : memref<4xf64>
// CHECK-NEXT: %{{.*}} = memref.load %H[%{{.*}}] : memref<2xf64>
// CHECK-NEXT: %{{.*}} = memref.load %I[%{{.*}}] : memref<3xf64>
Expand Down
6 changes: 3 additions & 3 deletions tests/filecheck/transforms/cse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ func.func @simple_constant() -> (i32, i32) {
func.func @basic() -> (index, index) {
%2 = arith.constant 0 : index
%3 = arith.constant 0 : index
%4 = "affine.apply"(%2) <{"map" = affine_map<(d0) -> ((d0 mod 2))>}> : (index) -> index
%5 = "affine.apply"(%3) <{"map" = affine_map<(d0) -> ((d0 mod 2))>}> : (index) -> index
%4 = affine.apply affine_map<(d0) -> ((d0 mod 2))>(%2)
%5 = affine.apply affine_map<(d0) -> ((d0 mod 2))>(%3)
func.return %4, %5 : index, index
}

// CHECK: func.func @basic() -> (index, index) {
// CHECK-NEXT: %0 = arith.constant 0 : index
// CHECK-NEXT: %1 = "affine.apply"(%0) <{"map" = affine_map<(d0) -> ((d0 mod 2))>}> : (index) -> index
// CHECK-NEXT: %1 = affine.apply affine_map<(d0) -> ((d0 mod 2))> (%0)
// CHECK-NEXT: func.return %1, %1 : index, index
// CHECK-NEXT: }

Expand Down
40 changes: 40 additions & 0 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
var_operand_def,
var_result_def,
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.traits import IsTerminator, Pure
from xdsl.utils.exceptions import VerifyException

Expand Down Expand Up @@ -65,6 +67,44 @@ def verify_(self) -> None:
if len(self.map.data.results) != 1:
raise VerifyException("affine.apply expects a unidimensional map.")

@classmethod
def parse(cls, parser: Parser) -> ApplyOp:
pos = parser.pos
m = parser.parse_attribute()
if not isinstance(m, AffineMapAttr):
parser.raise_error("Expected affine map attr", at_position=pos)
dims = parser.parse_optional_comma_separated_list(
parser.Delimiter.PAREN, lambda: parser.parse_operand()
)
if dims is None:
dims = []
syms = parser.parse_optional_comma_separated_list(
parser.Delimiter.SQUARE, lambda: parser.parse_operand()
)
if syms is None:
syms = []
return ApplyOp(dims + syms, m)

def print(self, printer: Printer):
m = self.map.data
operands = tuple(self.mapOperands)
assert len(operands) == m.num_dims + m.num_symbols, f"{len(operands)} {m}"
printer.print_string_raw(" ")
printer.print_attribute(self.map)
printer.print_string_raw(" ")
if m.num_dims:
printer.print_string_raw("(")
printer.print_list(
operands[: m.num_dims], lambda el: printer.print_operand(el)
)
printer.print_string_raw(")")
if m.num_symbols:
printer.print_string_raw("[")
printer.print_list(
operands[m.num_dims :], lambda el: printer.print_operand(el)
)
printer.print_string_raw("]")


@irdl_op_definition
class For(IRDLOperation):
Expand Down

0 comments on commit 38d2ef8

Please sign in to comment.