Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Handle affine.apply in lower-affine #2849

Merged
merged 38 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8fd2ef7
api: add printing and parsing of dynamic lists as in MLIR
superlopuh Jul 4, 2024
dc18ee3
dialects: (memref) add custom syntax for memref_subview
superlopuh Jul 4, 2024
8c15a86
custom dynamic index
superlopuh Jul 5, 2024
cfb9479
Merge branch 'sasha/ir/dynamic-lists' into sasha/memref-subview
superlopuh Jul 5, 2024
887570a
use custom dynamic index
superlopuh Jul 5, 2024
5958582
add back deleted tests
superlopuh Jul 5, 2024
bf353d6
fix types and add roundrip tests with mlir-opt custom and generic
superlopuh Jul 5, 2024
1d9b2f0
Merge branch 'main' into sasha/memref-subview
superlopuh Jul 5, 2024
679afbc
Merge branch 'main' into sasha/memref-subview
superlopuh Jul 5, 2024
131152d
add docstring
superlopuh Jul 5, 2024
fc393ae
transformations: Handle more data types in memref to riscv lowering
superlopuh Jul 5, 2024
a622d08
Merge branch 'main' into sasha/riscv/memref-riscv-test-cleanup
superlopuh Jul 5, 2024
3d9ee64
transformations: Add memref.subview lowering to convert-memref-to-riscv
superlopuh Jul 5, 2024
1e7ddce
dialects: (affine) Add custom syntax to affine.apply
superlopuh Jul 5, 2024
58baae4
space
superlopuh Jul 5, 2024
bdc33d2
transformations: Handle affine.apply in lower-affine
superlopuh Jul 5, 2024
ef8b457
Merge remote-tracking branch 'origin/main' into sasha/riscv/lower-mem…
superlopuh Jul 5, 2024
2531bd6
Merge branch 'sasha/riscv/lower-memref-subview' into sasha/affine/app…
superlopuh Jul 5, 2024
4fc9a28
Merge branch 'sasha/affine/apply-custom-syntax' into sasha/riscv/lowe…
superlopuh Jul 5, 2024
5fb46e8
swap indexing in apply lowering
superlopuh Jul 6, 2024
d356931
Merge branch 'main' into sasha/riscv/lower-memref-subview
superlopuh Jul 6, 2024
5b7cc5c
Merge branch 'sasha/riscv/lower-memref-subview' into sasha/affine/app…
superlopuh Jul 6, 2024
671a444
Merge branch 'sasha/affine/apply-custom-syntax' into sasha/riscv/lowe…
superlopuh Jul 6, 2024
9c19016
reword doc string
superlopuh Jul 6, 2024
cc80989
add helpers for strides given a shape and whether memref is contiguous
superlopuh Jul 6, 2024
508eb4d
check for non-contiguity in load and store
superlopuh Jul 6, 2024
7d1b845
remove unnecessary res_ptr
superlopuh Jul 6, 2024
c38deb6
simplify offset pointer calculation
superlopuh Jul 6, 2024
198b0b2
Merge branch 'main' into sasha/riscv/lower-memref-subview
superlopuh Jul 7, 2024
d51e7fc
backend: (riscv) preserve name hints when casting operands
superlopuh Jul 7, 2024
1ec1cd6
Merge branch 'sasha/riscv/name-hints' into sasha/riscv/lower-memref-s…
superlopuh Jul 7, 2024
1926fbe
fix merge
superlopuh Jul 7, 2024
512b96e
Merge branch 'main' into sasha/riscv/lower-memref-subview
superlopuh Jul 7, 2024
f10a16f
dialects: (affine) Add custom syntax to affine.apply
superlopuh Jul 5, 2024
5a18a3a
space
superlopuh Jul 5, 2024
b86963b
Merge branch 'sasha/affine/apply-custom-syntax' into sasha/riscv/lowe…
superlopuh Jul 7, 2024
b540e1f
Merge branch 'main' into sasha/affine/apply-custom-syntax
superlopuh Jul 8, 2024
4bf4526
Merge branch 'sasha/affine/apply-custom-syntax' into sasha/riscv/lowe…
superlopuh Jul 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/filecheck/dialects/affine/affine_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@
// CHECK-NEXT: "affine.store"(%value, %memref) <{"map" = affine_map<() -> (0, 0)>}> : (f64, memref<2x3xf64>) -> ()

%zero = "test.op"() : () -> index
%2 = "affine.apply"(%zero, %zero) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
%2 = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%zero)[%zero]
%min = "affine.min"(%zero) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
%same_value = "affine.load"(%memref, %zero, %zero) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

// CHECK: %zero = "test.op"() : () -> index
// CHECK-NEXT: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
// CHECK-NEXT: %{{.*}} = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%{{.*}})[%{{.*}}]
// CHECK-NEXT: %{{.*}} = "affine.min"(%{{.*}}) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
// CHECK-NEXT: %same_value = "affine.load"(%memref, %zero, %zero) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

Expand All @@ -78,7 +78,7 @@
}, {
"affine.yield"() : () -> ()
}) {"condition" = affine_set<() : (0 == 0)>} : () -> ()

func.return
}
// CHECK: func.func @empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@
// CHECK-NEXT: "affine.store"(%{{.*}}, %{{.*}}) <{"map" = affine_map<() -> (0, 0)>}> : (f64, memref<2x3xf64>) -> ()

%zero = "test.op"() : () -> index
%2 = "affine.apply"(%zero, %zero) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
%2 = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%zero)[%zero]
%min = "affine.min"(%zero) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
%same_value = "affine.load"(%memref, %zero, %zero) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

// CHECK: %{{.*}} = "test.op"() : () -> index
// CHECK-NEXT: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}) <{"map" = affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))>}> : (index, index) -> index
// CHECK-NEXT: %{{.*}} = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%{{.*}})[%{{.*}}]
// CHECK-NEXT: %{{.*}} = "affine.min"(%{{.*}}) <{"map" = affine_map<(d0) -> ((d0 + 41), d0)>}> : (index) -> index
// CHECK-NEXT: %{{.*}} = "affine.load"(%{{.*}}, %{{.*}}, %{{.*}}) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf64>, index, index) -> f64

Expand All @@ -78,7 +78,7 @@
}, {
"affine.yield"() : () -> ()
}) {"condition" = affine_set<() : (0 == 0)>} : () -> ()

func.return
}
// CHECK: func.func @empty() {
Expand Down
2 changes: 1 addition & 1 deletion tests/filecheck/transforms/convert_linalg_to_loops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ linalg.generic {
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}) <{"map" = affine_map<(d0, d1) -> ((d0 + d1))>}> : (index, index) -> index
// CHECK-NEXT: %{{.*}} = affine.apply affine_map<(d0, d1) -> ((d0 + d1))>
// CHECK-NEXT: %{{.*}} = memref.load %G[%{{.*}}] : memref<4xf64>
// CHECK-NEXT: %{{.*}} = memref.load %H[%{{.*}}] : memref<2xf64>
// CHECK-NEXT: %{{.*}} = memref.load %I[%{{.*}}] : memref<3xf64>
Expand Down
6 changes: 3 additions & 3 deletions tests/filecheck/transforms/cse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ func.func @simple_constant() -> (i32, i32) {
func.func @basic() -> (index, index) {
%2 = arith.constant 0 : index
%3 = arith.constant 0 : index
%4 = "affine.apply"(%2) <{"map" = affine_map<(d0) -> ((d0 mod 2))>}> : (index) -> index
%5 = "affine.apply"(%3) <{"map" = affine_map<(d0) -> ((d0 mod 2))>}> : (index) -> index
%4 = affine.apply affine_map<(d0) -> ((d0 mod 2))>(%2)
%5 = affine.apply affine_map<(d0) -> ((d0 mod 2))>(%3)
func.return %4, %5 : index, index
}

// CHECK: func.func @basic() -> (index, index) {
// CHECK-NEXT: %0 = arith.constant 0 : index
// CHECK-NEXT: %1 = "affine.apply"(%0) <{"map" = affine_map<(d0) -> ((d0 mod 2))>}> : (index) -> index
// CHECK-NEXT: %1 = affine.apply affine_map<(d0) -> ((d0 mod 2))> (%0)
// CHECK-NEXT: func.return %1, %1 : index, index
// CHECK-NEXT: }

Expand Down
99 changes: 57 additions & 42 deletions tests/filecheck/transforms/lower_affine.mlir
Original file line number Diff line number Diff line change
@@ -1,46 +1,61 @@
// RUN: xdsl-opt %s -p lower-affine --allow-unregistered-dialect --print-op-generic | filecheck %s

"builtin.module"() ({
%v0, %m = "test.op"() : () -> (f32, memref<2x3xf32>)
"affine.store"(%v0, %m) {"map" = affine_map<() -> (1, 2)>} : (f32, memref<2x3xf32>) -> ()
%v1 = "affine.load"(%m) {"map" = affine_map<() -> (1, 2)>} : (memref<2x3xf32>) -> f32
%v2 = "affine.for"(%v1) <{"lowerBoundMap" = affine_map<() -> (0)>, "upperBoundMap" = affine_map<() -> (2)>, "step" = 1 : index, "operandSegmentSizes" = array<i32: 0, 0, 1>}> ({
^0(%r : index, %acc0 : f32):
%v3 = "affine.for"(%acc0) <{"lowerBoundMap" = affine_map<() -> (0)>, "upperBoundMap" = affine_map<() -> (3)>, "step" = 1 : index, "operandSegmentSizes" = array<i32: 0, 0, 1>}> ({
^2(%c : index, %acc1 : f32):
%v4 = "affine.load"(%m, %r, %c) {"map" = affine_map<(d0, d1) -> (d0, d1)>} : (memref<2x3xf32>, index, index) -> f32
%acc_new = "test.op"(%acc1, %v4) : (f32, f32) -> f32
"affine.yield"(%acc_new) : (f32) -> ()
}) : (f32) -> f32
"affine.yield"(%v3) : (f32) -> ()
// RUN: xdsl-opt %s -p lower-affine | filecheck %s

// CHECK: builtin.module {

// CHECK-NEXT: %v0, %m = "test.op"() : () -> (f32, memref<2x3xf32>)
%v0, %m = "test.op"() : () -> (f32, memref<2x3xf32>)

// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: %{{.*}} = arith.constant 2 : index
// CHECK-NEXT: memref.store %v0, %m[%{{.*}}, %{{.*}}] : memref<2x3xf32>
"affine.store"(%v0, %m) <{"map" = affine_map<() -> (1, 2)>}> : (f32, memref<2x3xf32>) -> ()

// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: %{{.*}} = arith.constant 2 : index
// CHECK-NEXT: %v1 = memref.load %m[%{{.*}}, %{{.*}}] : memref<2x3xf32>
%v1 = "affine.load"(%m) <{"map" = affine_map<() -> (1, 2)>}> : (memref<2x3xf32>) -> f32

// CHECK-NEXT: %{{.*}} = arith.constant 0 : index
// CHECK-NEXT: %{{.*}} = arith.constant 2 : index
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: %v2 = scf.for %r = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%acc0 = %v1) -> (f32) {
%v2 = "affine.for"(%v1) <{"lowerBoundMap" = affine_map<() -> (0)>, "upperBoundMap" = affine_map<() -> (2)>, "step" = 1 : index, "operandSegmentSizes" = array<i32: 0, 0, 1>}> ({
^0(%r : index, %acc0 : f32):

// CHECK-NEXT: %{{.*}} = arith.constant 0 : index
// CHECK-NEXT: %{{.*}} = arith.constant 3 : index
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: %v3 = scf.for %c = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%acc1 = %acc0) -> (f32) {
%v3 = "affine.for"(%acc0) <{"lowerBoundMap" = affine_map<() -> (0)>, "upperBoundMap" = affine_map<() -> (3)>, "step" = 1 : index, "operandSegmentSizes" = array<i32: 0, 0, 1>}> ({
^1(%c : index, %acc1 : f32):

// CHECK-NEXT: %v4 = memref.load %m[%r, %c] : memref<2x3xf32>
%v4 = "affine.load"(%m, %r, %c) <{"map" = affine_map<(d0, d1) -> (d0, d1)>}> : (memref<2x3xf32>, index, index) -> f32

// CHECK-NEXT: %acc_new = "test.op"(%acc1, %v4) : (f32, f32) -> f32
%acc_new = "test.op"(%acc1, %v4) : (f32, f32) -> f32

// CHECK-NEXT: scf.yield %acc_new : f32
"affine.yield"(%acc_new) : (f32) -> ()

// CHECK-NEXT: }
}) : (f32) -> f32
}) : () -> ()

// CHECK: "builtin.module"() ({
// CHECK-NEXT: %v0, %m = "test.op"() : () -> (f32, memref<2x3xf32>)
// CHECK-NEXT: %0 = "arith.constant"() <{"value" = 1 : index}> : () -> index
// CHECK-NEXT: %1 = "arith.constant"() <{"value" = 2 : index}> : () -> index
// CHECK-NEXT: "memref.store"(%v0, %m, %0, %1) : (f32, memref<2x3xf32>, index, index) -> ()
// CHECK-NEXT: %2 = "arith.constant"() <{"value" = 1 : index}> : () -> index
// CHECK-NEXT: %3 = "arith.constant"() <{"value" = 2 : index}> : () -> index
// CHECK-NEXT: %v1 = "memref.load"(%m, %2, %3) : (memref<2x3xf32>, index, index) -> f32
// CHECK-NEXT: %4 = "arith.constant"() <{"value" = 0 : index}> : () -> index
// CHECK-NEXT: %5 = "arith.constant"() <{"value" = 2 : index}> : () -> index
// CHECK-NEXT: %6 = "arith.constant"() <{"value" = 1 : index}> : () -> index
// CHECK-NEXT: %v2 = "scf.for"(%4, %5, %6, %v1) ({
// CHECK-NEXT: ^0(%r : index, %acc0 : f32):
// CHECK-NEXT: %7 = "arith.constant"() <{"value" = 0 : index}> : () -> index
// CHECK-NEXT: %8 = "arith.constant"() <{"value" = 3 : index}> : () -> index
// CHECK-NEXT: %9 = "arith.constant"() <{"value" = 1 : index}> : () -> index
// CHECK-NEXT: %v3 = "scf.for"(%7, %8, %9, %acc0) ({
// CHECK-NEXT: ^1(%c : index, %acc1 : f32):
// CHECK-NEXT: %v4 = "memref.load"(%m, %r, %c) : (memref<2x3xf32>, index, index) -> f32
// CHECK-NEXT: %acc_new = "test.op"(%acc1, %v4) : (f32, f32) -> f32
// CHECK-NEXT: "scf.yield"(%acc_new) : (f32) -> ()
// CHECK-NEXT: }) : (index, index, index, f32) -> f32
// CHECK-NEXT: "scf.yield"(%v3) : (f32) -> ()
// CHECK-NEXT: }) : (index, index, index, f32) -> f32
// CHECK-NEXT: }) : () -> ()

// CHECK-NEXT: scf.yield %v3 : f32
"affine.yield"(%v3) : (f32) -> ()

// CHECK-NEXT: }
}) : (f32) -> f32

// CHECK-NEXT: %apply_dim, %apply_sym = "test.op"() : () -> (index, index)
// CHECK-NEXT: %apply_res = arith.constant 42 : index
// CHECK-NEXT: %apply_res_1 = arith.muli %apply_sym, %apply_res : index
// CHECK-NEXT: %apply_res_2 = arith.addi %apply_dim, %apply_res_1 : index
// CHECK-NEXT: %apply_res_3 = arith.constant -1 : index
// CHECK-NEXT: %apply_res_4 = arith.addi %apply_res_2, %apply_res_3 : index
%apply_dim, %apply_sym = "test.op"() : () -> (index, index)
%apply_res = affine.apply affine_map<(d0)[s0] -> (((d0 + (s0 * 42)) + -1))> (%apply_dim)[%apply_sym]

// CHECK-NEXT: }

// CHECK-NOT: "affine.
40 changes: 40 additions & 0 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
var_operand_def,
var_result_def,
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.traits import IsTerminator, Pure
from xdsl.utils.exceptions import VerifyException

Expand Down Expand Up @@ -65,6 +67,44 @@ def verify_(self) -> None:
if len(self.map.data.results) != 1:
raise VerifyException("affine.apply expects a unidimensional map.")

@classmethod
def parse(cls, parser: Parser) -> ApplyOp:
pos = parser.pos
m = parser.parse_attribute()
if not isinstance(m, AffineMapAttr):
parser.raise_error("Expected affine map attr", at_position=pos)
dims = parser.parse_optional_comma_separated_list(
parser.Delimiter.PAREN, lambda: parser.parse_operand()
)
if dims is None:
dims = []
syms = parser.parse_optional_comma_separated_list(
parser.Delimiter.SQUARE, lambda: parser.parse_operand()
)
if syms is None:
syms = []
return ApplyOp(dims + syms, m)

def print(self, printer: Printer):
m = self.map.data
operands = tuple(self.mapOperands)
assert len(operands) == m.num_dims + m.num_symbols, f"{len(operands)} {m}"
printer.print_string_raw(" ")
printer.print_attribute(self.map)
printer.print_string_raw(" ")
if m.num_dims:
printer.print_string_raw("(")
printer.print_list(
operands[: m.num_dims], lambda el: printer.print_operand(el)
)
printer.print_string_raw(")")
if m.num_symbols:
printer.print_string_raw("[")
printer.print_list(
operands[m.num_dims :], lambda el: printer.print_operand(el)
)
printer.print_string_raw("]")


@irdl_op_definition
class For(IRDLOperation):
Expand Down
22 changes: 22 additions & 0 deletions xdsl/transforms/lower_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,27 @@ def match_and_rewrite(self, op: affine.Yield, rewriter: PatternRewriter, /):
rewriter.replace_matched_op(scf.Yield(*op.arguments))


class LowerAffineApply(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: affine.ApplyOp, rewriter: PatternRewriter, /):
affine_map = op.map.data
assert len(affine_map.results) == 1

operands = op.mapOperands
assert affine_map.num_dims + affine_map.num_symbols == len(operands)

dims = operands[: affine_map.num_dims]
symbols = operands[affine_map.num_dims :]

new_ops: list[Operation] = []
new_results: list[SSAValue] = []

ops, val = affine_expr_ops(affine_map.results[0], dims, symbols)
new_ops.extend(ops)
new_results.append(val)
rewriter.replace_matched_op(new_ops, new_results)


class LowerAffinePass(ModulePass):
name = "lower-affine"

Expand All @@ -144,6 +165,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
LowerAffineLoad(),
LowerAffineFor(),
LowerAffineYield(),
LowerAffineApply(),
]
)
).rewrite_module(op)
2 changes: 2 additions & 0 deletions xdsl/transforms/test_lower_memref_stream_to_snitch_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@
from xdsl.transforms.convert_memref_stream_to_snitch_stream import (
ConvertMemrefStreamToSnitch,
)
from xdsl.transforms.lower_affine import LowerAffinePass
from xdsl.transforms.reconcile_unrealized_casts import ReconcileUnrealizedCastsPass

TEST_LOWER_MEMREF_STREAM_TO_SNITCH_STREAM: tuple[ModulePass, ...] = (
CanonicalizePass(),
ConvertMemrefToRiscvPass(),
LowerAffinePass(),
ConvertScfToRiscvPass(),
ConvertArithToRiscvPass(),
ConvertFuncToRiscvFuncPass(),
Expand Down
Loading