Skip to content

Commit

Permalink
transformations: Fix bug in memref-stream-unnest-out-parameters (#2771)
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh authored Jun 26, 2024
1 parent 18aa22f commit 14ec3ce
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,30 @@ memref_stream.generic {
// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64
// CHECK-NEXT: memref_stream.yield %{{.*}} : f64
// CHECK-NEXT: }

%X, %Y, %Z = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>, memref<1x1x6x6xf64>)

memref_stream.generic {
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>],
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>,
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>,
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) {
^0(%x : f64, %y : f64, %acc : f64):
%prod = arith.mulf %x, %y fastmath<fast> : f64
%res = arith.addf %prod, %acc fastmath<fast> : f64
memref_stream.yield %res : f64
}

// CHECK-NEXT: %X, %Y, %Z = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>, memref<1x1x6x6xf64>)
// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, (d2 + d5), (d3 + d6))>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) {
// CHECK-NEXT: ^1(%x : f64, %y : f64, %acc : f64):
// CHECK-NEXT: %prod_1 = arith.mulf %x, %y fastmath<fast> : f64
// CHECK-NEXT: %res = arith.addf %prod_1, %acc fastmath<fast> : f64
// CHECK-NEXT: memref_stream.yield %res : f64
// CHECK-NEXT: }

// CHECK-NEXT: }
4 changes: 4 additions & 0 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,10 @@ def get_static_loop_ranges(
tuple(bound.data for bound in self.bounds.data[min_dims:]),
)

@property
def is_imperfectly_nested(self) -> bool:
return bool(self.get_static_loop_ranges()[1])

def _print_init(self, printer: Printer, init: SSAValue | None):
if init is None:
printer.print_string("None")
Expand Down
10 changes: 8 additions & 2 deletions xdsl/transforms/memref_stream_unnest_out_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ class UnnestOutParametersPattern(RewritePattern):
def match_and_rewrite(
self, op: memref_stream.GenericOp, rewriter: PatternRewriter
) -> None:
if op.is_imperfectly_nested:
# Already unnested
return

num_outputs = len(op.outputs)
if not num_outputs:
return

num_inputs = len(op.inputs)

num_parallel = sum(
i == memref_stream.IteratorTypeAttr.parallel() for i in op.iterator_types
)
Expand All @@ -37,10 +43,10 @@ def match_and_rewrite(

parallel_dims = (True,) * num_parallel + (False,) * num_reduction

maps = op.indexing_maps.data[num_parallel:]
maps = op.indexing_maps.data[num_inputs:]
new_maps = ArrayAttr(
(
*op.indexing_maps.data[:num_parallel],
*op.indexing_maps.data[:num_inputs],
*(AffineMapAttr(m.data.compress_dims(parallel_dims)) for m in maps),
)
)
Expand Down

0 comments on commit 14ec3ce

Please sign in to comment.