Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Fix bug in memref-stream-unnest-out-parameters #2771

Merged
merged 32 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
8896dab
transformations: do not insert affine.apply ops when streaming
superlopuh Jun 20, 2024
f89d526
transformations: fix yielding of values in memref_stream.generic lowe…
superlopuh Jun 20, 2024
299de80
tests: move constant initialisation around in bottom-up tests
superlopuh Jun 21, 2024
2e5a024
dialects: (memref_stream) add an inits field to memref_stream.generic
superlopuh Jun 14, 2024
b2acbca
transformations: support constant inits in memref_stream.generic lowe…
superlopuh Jun 18, 2024
ca1cf95
transformations: memref_streamify handle constant inits
superlopuh Jun 21, 2024
2fb3274
Merge remote-tracking branch 'origin/main' into sasha/memref_stream/y…
superlopuh Jun 21, 2024
c2a03e6
Merge branch 'sasha/memref_stream/yields' into sasha/memref_stream/bo…
superlopuh Jun 21, 2024
7f12d54
Merge branch 'sasha/memref_stream/bottom-up-constants' into sasha/mem…
superlopuh Jun 21, 2024
7b7f027
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 21, 2024
06e4131
Merge branch 'sasha/memref_stream/const-init-lowering' into sasha/mem…
superlopuh Jun 21, 2024
7378f6d
Merge branch 'main' into sasha/memref_stream/const-init
superlopuh Jun 21, 2024
3f70240
Merge branch 'main' into sasha/memref_stream/const-init
superlopuh Jun 23, 2024
333ce98
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 23, 2024
f38b458
Merge branch 'sasha/memref_stream/const-init-lowering' into sasha/mem…
superlopuh Jun 23, 2024
cbadcd0
inits now values not attributes
superlopuh Jun 24, 2024
a207627
inits now values not attributes
superlopuh Jun 24, 2024
f9ae2aa
inits now values not attributes
superlopuh Jun 24, 2024
f178a2a
Merge branch 'sasha/memref_stream/const-init-lowering' into sasha/mem…
superlopuh Jun 24, 2024
aed0cd4
straggling inits
superlopuh Jun 24, 2024
182f2c0
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 24, 2024
a76c29a
Merge branch 'sasha/memref_stream/const-init-lowering' into sasha/mem…
superlopuh Jun 24, 2024
968c53d
fix printing
superlopuh Jun 24, 2024
81018b5
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 24, 2024
200acc2
Merge branch 'sasha/memref_stream/const-init-lowering' into sasha/mem…
superlopuh Jun 24, 2024
f114076
fix fix
superlopuh Jun 24, 2024
9ecbc32
fix fix fix
superlopuh Jun 24, 2024
73b0ec9
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 24, 2024
3bf1725
Merge branch 'sasha/memref_stream/const-init-lowering' into sasha/mem…
superlopuh Jun 24, 2024
9bd3f61
transformations: Fix bug in memref-stream-unnest-out-parameters
superlopuh Jun 25, 2024
4d15a54
add helper
superlopuh Jun 25, 2024
90bbbf1
Merge branch 'main' into sasha/memref_stream/fix-unnest
superlopuh Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,30 @@ memref_stream.generic {
// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64
// CHECK-NEXT: memref_stream.yield %{{.*}} : f64
// CHECK-NEXT: }

%X, %Y, %Z = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>, memref<1x1x6x6xf64>)

memref_stream.generic {
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>],
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>,
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>,
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) {
^0(%x : f64, %y : f64, %acc : f64):
%prod = arith.mulf %x, %y fastmath<fast> : f64
%res = arith.addf %prod, %acc fastmath<fast> : f64
memref_stream.yield %res : f64
}

// CHECK-NEXT: %X, %Y, %Z = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>, memref<1x1x6x6xf64>)
// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, (d2 + d5), (d3 + d6))>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) {
// CHECK-NEXT: ^1(%x : f64, %y : f64, %acc : f64):
// CHECK-NEXT: %prod_1 = arith.mulf %x, %y fastmath<fast> : f64
// CHECK-NEXT: %res = arith.addf %prod_1, %acc fastmath<fast> : f64
// CHECK-NEXT: memref_stream.yield %res : f64
// CHECK-NEXT: }

// CHECK-NEXT: }
4 changes: 4 additions & 0 deletions xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,10 @@ def get_static_loop_ranges(
tuple(bound.data for bound in self.bounds.data[min_dims:]),
)

@property
def is_imperfectly_nested(self) -> bool:
return bool(self.get_static_loop_ranges()[1])

def _print_init(self, printer: Printer, init: SSAValue | None):
if init is None:
printer.print_string("None")
Expand Down
10 changes: 8 additions & 2 deletions xdsl/transforms/memref_stream_unnest_out_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,16 @@ class UnnestOutParametersPattern(RewritePattern):
def match_and_rewrite(
self, op: memref_stream.GenericOp, rewriter: PatternRewriter
) -> None:
if op.is_imperfectly_nested:
# Already unnested
return

num_outputs = len(op.outputs)
if not num_outputs:
return

num_inputs = len(op.inputs)

num_parallel = sum(
i == memref_stream.IteratorTypeAttr.parallel() for i in op.iterator_types
)
Expand All @@ -37,10 +43,10 @@ def match_and_rewrite(

parallel_dims = (True,) * num_parallel + (False,) * num_reduction

maps = op.indexing_maps.data[num_parallel:]
maps = op.indexing_maps.data[num_inputs:]
new_maps = ArrayAttr(
(
*op.indexing_maps.data[:num_parallel],
*op.indexing_maps.data[:num_inputs],
*(AffineMapAttr(m.data.compress_dims(parallel_dims)) for m in maps),
)
)
Expand Down
Loading