Skip to content

Commit

Permalink
transformations: Only stream used inputs in memref-streamify (#2775)
Browse files Browse the repository at this point in the history
The pools have unused inputs in linalg.generic.

Note stacked PR.
  • Loading branch information
superlopuh authored Jun 26, 2024
1 parent a4fefaa commit a36af1f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/filecheck/transforms/memref_streamify.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,36 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) ->
// CHECK-NEXT: func.return
// CHECK-NEXT: }

func.func public @used_only(
%X : memref<2xf64>,
%Y : memref<2xf64>,
%Z : memref<2xf64>
) {
memref_stream.generic {
bounds = [#builtin.int<2>],
indexing_maps = [
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>,
affine_map<(d0) -> (d0)>
],
iterator_types = ["parallel"]
} ins(%X, %Y : memref<2xf64>, memref<2xf64>) outs(%Z : memref<2xf64>) {
^0(%x : f64, %y : f64, %z : f64):
memref_stream.yield %x : f64
}

func.return
}

// CHECK-NEXT: func.func public @used_only(%{{.*}} : memref<2xf64>, %{{.*}} : memref<2xf64>, %{{.*}} : memref<2xf64>) {
// CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern<ub = [2], index_map = (d0) -> (d0)>, #memref_stream.stride_pattern<ub = [2], index_map = (d0) -> (d0)>]} ins(%{{.*}} : memref<2xf64>) outs(%{{.*}} : memref<2xf64>) {
// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable<f64>, %{{.*}} : !stream.writable<f64>):
// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<2>], indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%{{.*}}, %{{.*}} : !stream.readable<f64>, memref<2xf64>) outs(%{{.*}} : !stream.writable<f64>) {
// CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64):
// CHECK-NEXT: memref_stream.yield %{{.*}} : f64
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: func.return
// CHECK-NEXT: }

// CHECK-NEXT: }
1 change: 1 addition & 0 deletions xdsl/transforms/memref_streamify.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def match_and_rewrite(
(index, cast(memref.MemRefType[Attribute], value_type).element_type)
for index, value in enumerate(op.inputs)
if isinstance(value_type := value.type, memref.MemRefType)
and op.body.block.args[index].uses
)
input_count = len(op.inputs)
streamable_output_indices = tuple(
Expand Down

0 comments on commit a36af1f

Please sign in to comment.