From 8896dabea12a4ea965a295f61f671a7d52fdc98f Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Thu, 20 Jun 2024 14:41:11 +0100 Subject: [PATCH 01/14] transformations: do not insert affine.apply ops when streaming --- .../convert_memref_stream_to_loops.mlir | 41 +++++++++++ xdsl/transforms/convert_linalg_to_loops.py | 35 +++++++-- .../convert_memref_stream_to_loops.py | 43 +++++++---- xdsl/transforms/loop_nest_lowering_utils.py | 73 +++++++++++++------ 4 files changed, 147 insertions(+), 45 deletions(-) diff --git a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir index c7e0ea6e0c..f10ea1c143 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir @@ -166,4 +166,45 @@ func.func @main(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64> // CHECK-NEXT: func.return %{{.*}} : memref<4x3xf64> // CHECK-NEXT: } +func.func @elide_affine(%A : memref<6xf64>, %B : memref) -> memref { + memref_stream.streaming_region { + patterns = [ + #memref_stream.stride_pattern (d0 * 3 + d1)> + ] + } ins(%A : memref<6xf64>) { + ^0(%0 : !stream.readable): + memref_stream.generic { + bounds = [#builtin.int<2>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1) -> (d0 * 3 + d1)>, + affine_map<(d0, d1) -> ()> + ], + iterator_types = ["parallel", "reduction"] + } ins(%0 : !stream.readable) outs(%B : memref) { + ^1(%a : f64, %acc_old : f64): + %acc_new = arith.addf %acc_old, %a : f64 + memref_stream.yield %acc_new : f64 + } + } + func.return %B : memref +} +// CHECK-NEXT: func.func @elide_affine(%{{.*}} : memref<6xf64>, %{{.*}} : memref) -> memref { +// CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (((d0 * 3) + d1))>]} ins(%{{.*}} : memref<6xf64>) { +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable): +// CHECK-NEXT: %{{.*}} = arith.constant 2 : index +// CHECK-NEXT: 
%{{.*}} = arith.constant 3 : index +// CHECK-NEXT: %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: %{{.*}} = arith.constant 1 : index +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: %{{.*}} = memref_stream.read from %{{.*}} : f64 +// CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[] : memref +// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: memref.store %{{.*}}, %{{.*}}[] : memref +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func.return %{{.*}} : memref +// CHECK-NEXT: } + // CHECK-NEXT: } diff --git a/xdsl/transforms/convert_linalg_to_loops.py b/xdsl/transforms/convert_linalg_to_loops.py index f3e928b69b..b2abf00325 100644 --- a/xdsl/transforms/convert_linalg_to_loops.py +++ b/xdsl/transforms/convert_linalg_to_loops.py @@ -2,7 +2,7 @@ from xdsl.context import MLContext from xdsl.dialects import linalg, memref -from xdsl.dialects.builtin import MemRefType, ModuleOp +from xdsl.dialects.builtin import AffineMapAttr, MemRefType, ModuleOp from xdsl.ir import SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -13,16 +13,23 @@ op_type_rewrite_pattern, ) from xdsl.rewriter import InsertPoint -from xdsl.transforms.loop_nest_lowering_utils import rewrite_generic_to_loops +from xdsl.transforms.loop_nest_lowering_utils import ( + indices_for_map, + rewrite_generic_to_loops, +) -def load( +def insert_load( value: SSAValue, - indices: Sequence[SSAValue], + affine_map_attr: AffineMapAttr, + ind_vars: Sequence[SSAValue], rewriter: PatternRewriter, insertion_target: InsertPoint, ) -> SSAValue: if isinstance(value.type, MemRefType): + indices = indices_for_map( + rewriter, insertion_target, affine_map_attr.data, ind_vars + ) op = memref.Load.get(value, indices) rewriter.insert_op(op, insertion_target) return op.res @@ -30,6 +37,22 @@ def load( return value +def insert_store( + value: SSAValue, + 
destination: SSAValue, + affine_map_attr: AffineMapAttr, + ind_vars: Sequence[SSAValue], + rewriter: PatternRewriter, + insertion_target: InsertPoint, +): + indices = indices_for_map( + rewriter, insertion_target, affine_map_attr.data, ind_vars + ) + op = memref.Store.get(value, destination, indices) + rewriter.insert_op(op, insertion_target) + return op + + class LowerGenericOpPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter) -> None: @@ -47,8 +70,8 @@ def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter) -> No op.operands, op.outputs, op.body.block, - load, - memref.Store.get, + insert_load, + insert_store, ) diff --git a/xdsl/transforms/convert_memref_stream_to_loops.py b/xdsl/transforms/convert_memref_stream_to_loops.py index 97ef15e12b..640cf64c16 100644 --- a/xdsl/transforms/convert_memref_stream_to_loops.py +++ b/xdsl/transforms/convert_memref_stream_to_loops.py @@ -2,9 +2,7 @@ from xdsl.context import MLContext from xdsl.dialects import memref, memref_stream, stream -from xdsl.dialects.builtin import ( - ModuleOp, -) +from xdsl.dialects.builtin import AffineMapAttr, ModuleOp from xdsl.ir import Operation, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -16,34 +14,49 @@ ) from xdsl.rewriter import InsertPoint from xdsl.transforms.loop_nest_lowering_utils import ( + indices_for_map, rewrite_generic_to_imperfect_loops, rewrite_generic_to_loops, ) -def load( +def insert_load( source: SSAValue, - indices: Sequence[SSAValue], + affine_map_attr: AffineMapAttr, + ind_vars: Sequence[SSAValue], rewriter: PatternRewriter, - insert_point: InsertPoint, + insertion_point: InsertPoint, ) -> SSAValue: if isinstance(source.type, memref.MemRefType): + indices = indices_for_map( + rewriter, insertion_point, affine_map_attr.data, ind_vars + ) op = memref.Load.get(source, indices) elif isinstance(source.type, stream.ReadableStreamType): op = 
memref_stream.ReadOp(source) else: return source - rewriter.insert_op(op, insert_point) + rewriter.insert_op(op, insertion_point) return op.res -def store( - value: SSAValue, destination: SSAValue, indices: Sequence[SSAValue] +def insert_store( + value: SSAValue, + destination: SSAValue, + affine_map_attr: AffineMapAttr, + ind_vars: Sequence[SSAValue], + rewriter: PatternRewriter, + insertion_point: InsertPoint, ) -> Operation: if isinstance(destination.type, memref.MemRefType): - return memref.Store.get(value, destination, indices) + indices = indices_for_map( + rewriter, insertion_point, affine_map_attr.data, ind_vars + ) + op = memref.Store.get(value, destination, indices) else: - return memref_stream.WriteOp(value, destination) + op = memref_stream.WriteOp(value, destination) + rewriter.insert_op(op, insertion_point) + return op class LowerGenericOpPattern(RewritePattern): @@ -69,8 +82,8 @@ def match_and_rewrite( op.body.block.args[ins_count:], op.body.block.args[:ins_count], op.body.block, - load, - store, + insert_load, + insert_store, ) else: rewrite_generic_to_loops( @@ -82,8 +95,8 @@ def match_and_rewrite( op.operands, op.outputs, op.body.block, - load, - store, + insert_load, + insert_store, ) diff --git a/xdsl/transforms/loop_nest_lowering_utils.py b/xdsl/transforms/loop_nest_lowering_utils.py index bda8b86677..2cbf40b407 100644 --- a/xdsl/transforms/loop_nest_lowering_utils.py +++ b/xdsl/transforms/loop_nest_lowering_utils.py @@ -1,5 +1,6 @@ from collections.abc import Callable, Sequence from itertools import compress +from typing import TypeAlias from xdsl.dialects import affine, arith, scf from xdsl.dialects.builtin import AffineMapAttr, IndexType, IntegerAttr @@ -62,6 +63,31 @@ def indices_for_map( return output_indices +INSERT_LOAD: TypeAlias = Callable[ + [ + SSAValue, + AffineMapAttr, + Sequence[SSAValue], + PatternRewriter, + InsertPoint, + ], + SSAValue, +] + + +INSERT_STORE: TypeAlias = Callable[ + [ + SSAValue, + SSAValue, + AffineMapAttr, + 
Sequence[SSAValue], + PatternRewriter, + InsertPoint, + ], + Operation, +] + + def _insert_loop_nest( rewriter: PatternRewriter, insertion_point: InsertPoint, @@ -131,9 +157,7 @@ def _insert_load_ops( affine_map_attrs: Sequence[AffineMapAttr], operands: Sequence[SSAValue], args: Sequence[BlockArgument], - load: Callable[ - [SSAValue, Sequence[SSAValue], PatternRewriter, InsertPoint], SSAValue - ], + insert_load: INSERT_LOAD, ) -> Sequence[tuple[int, SSAValue]]: """ Inserts the load operations at the specified insertion point. @@ -143,6 +167,8 @@ def _insert_load_ops( The `args` are the block arguments corresponding to the use of the load; if there are no uses, the loads are not inserted. The `affine_map_attrs`, `operands`, and `args` must have the same length. + Returns a tuple of integers indicating the locations of the returned values, and + the values themselves. """ res: list[tuple[int, SSAValue]] = [] for i, (affine_map_attr, operand, arg) in enumerate( @@ -150,9 +176,13 @@ def _insert_load_ops( ): if not arg.uses: continue - affine_map = affine_map_attr.data - indices = indices_for_map(rewriter, insertion_point, affine_map, ind_vars) - res_val = load(operand, indices, rewriter, insertion_point) + res_val = insert_load( + operand, + affine_map_attr, + ind_vars, + rewriter, + insertion_point, + ) res.append((i, res_val)) return res @@ -164,7 +194,7 @@ def _insert_store_ops( output_indexing_maps: Sequence[AffineMapAttr], yield_operands: Sequence[SSAValue], output_operands: Sequence[SSAValue], - store: Callable[[SSAValue, SSAValue, Sequence[SSAValue]], Operation], + insert_store: INSERT_STORE, ): """ Inserts the store operations at the specified insertion point. 
@@ -178,10 +208,9 @@ def _insert_store_ops( for affine_map_attr, yield_value, ref in zip( output_indexing_maps, yield_operands, output_operands, strict=True ): - affine_map = affine_map_attr.data - indices = indices_for_map(rewriter, insertion_point, affine_map, ind_vars) - store_op = store(yield_value, ref, indices) - rewriter.insert_op(store_op, insertion_point) + insert_store( + yield_value, ref, affine_map_attr, ind_vars, rewriter, insertion_point + ) def rewrite_generic_to_loops( @@ -193,10 +222,8 @@ def rewrite_generic_to_loops( load_operands: Sequence[SSAValue], store_operands: Sequence[SSAValue], block: Block, - load: Callable[ - [SSAValue, Sequence[SSAValue], PatternRewriter, InsertPoint], SSAValue - ], - store: Callable[[SSAValue, SSAValue, Sequence[SSAValue]], Operation], + insert_load: INSERT_LOAD, + insert_store: INSERT_STORE, ) -> None: # Create loop nest lb (0), step (1), and ubs # ubs are calculated from affine maps and memref dimensions @@ -227,7 +254,7 @@ def make_body( load_indexing_maps, load_operands, block.args, - load, + insert_load, ) for i, val in loaded_values: @@ -251,7 +278,7 @@ def make_body( store_indexing_maps, yield_op.operands, store_operands, - store, + insert_store, ) return () @@ -283,10 +310,8 @@ def rewrite_generic_to_imperfect_loops( outer_load_block_args: Sequence[BlockArgument], inner_load_block_args: Sequence[BlockArgument], block: Block, - load: Callable[ - [SSAValue, Sequence[SSAValue], PatternRewriter, InsertPoint], SSAValue - ], - store: Callable[[SSAValue, SSAValue, Sequence[SSAValue]], Operation], + insert_load: INSERT_LOAD, + insert_store: INSERT_STORE, ) -> None: # Create loop nest lb (0), step (1), and ubs # ubs are calculated from affine maps and memref dimensions @@ -323,7 +348,7 @@ def outer_make_body( outer_load_indexing_maps, outer_load_operands, outer_load_block_args, - load, + insert_load, ) def inner_make_body( @@ -340,7 +365,7 @@ def inner_make_body( inner_load_indexing_maps, inner_load_operands, 
inner_load_block_args, - load, + insert_load, ) # Replace block argument use with iter args @@ -390,7 +415,7 @@ def inner_make_body( store_indexing_maps, inner_loop_nest_results, store_operands, - store, + insert_store, ) return () From f89d526104e245118ef5c3e1521f858cbf17a962 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Thu, 20 Jun 2024 16:38:53 +0100 Subject: [PATCH 02/14] transformations: fix yielding of values in memref_stream.generic lowering --- .../convert_memref_stream_to_loops.mlir | 48 +++++++++++++++++++ xdsl/transforms/loop_nest_lowering_utils.py | 7 ++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir index f10ea1c143..9844dd49e5 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir @@ -207,4 +207,52 @@ func.func @elide_affine(%A : memref<6xf64>, %B : memref) -> memref { // CHECK-NEXT: func.return %{{.*}} : memref // CHECK-NEXT: } +func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref { + memref_stream.streaming_region { + patterns = [ + #memref_stream.stride_pattern (d0, d1, d2)> + ] + } ins(%A : memref<2x3x4xf64>) { + ^0(%0 : !stream.readable): + memref_stream.generic { + bounds = [#builtin.int<2>, #builtin.int<3>, #builtin.int<4>], + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<() -> ()> + ], + iterator_types = ["reduction", "reduction", "reduction"] + } ins(%0 : !stream.readable) outs(%B : memref) { + ^1(%a : f64, %acc_old : f64): + %acc_new = arith.addf %acc_old, %a : f64 + memref_stream.yield %acc_new : f64 + } + } + func.return %B : memref +} + +// CHECK-NEXT: func.func @nested_imperfect(%{{.*}} : memref<2x3x4xf64>, %{{.*}} : memref) -> memref { +// CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1, d2)>]} 
ins(%{{.*}} : memref<2x3x4xf64>) { +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable): +// CHECK-NEXT: %{{.*}} = arith.constant 2 : index +// CHECK-NEXT: %{{.*}} = arith.constant 3 : index +// CHECK-NEXT: %{{.*}} = arith.constant 4 : index +// CHECK-NEXT: %{{.*}} = arith.constant 0 : index +// CHECK-NEXT: %{{.*}} = arith.constant 1 : index +// CHECK-NEXT: %{{.*}} = memref.load %{{.*}}[] : memref +// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f64) { +// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f64) { +// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (f64) { +// CHECK-NEXT: %{{.*}} = memref_stream.read from %{{.*}} : f64 +// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: scf.yield %{{.*}} : f64 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %{{.*}} : f64 +// CHECK-NEXT: } +// CHECK-NEXT: scf.yield %{{.*}} : f64 +// CHECK-NEXT: } +// CHECK-NEXT: memref.store %{{.*}}, %{{.*}}[] : memref +// CHECK-NEXT: } +// CHECK-NEXT: func.return %{{.*}} : memref +// CHECK-NEXT: } + // CHECK-NEXT: } diff --git a/xdsl/transforms/loop_nest_lowering_utils.py b/xdsl/transforms/loop_nest_lowering_utils.py index 2cbf40b407..8286659ba4 100644 --- a/xdsl/transforms/loop_nest_lowering_utils.py +++ b/xdsl/transforms/loop_nest_lowering_utils.py @@ -129,7 +129,9 @@ def _insert_loop_nest( iter_args = loop.body.block.args[1:] loops.append(loop) rewriter.insert_op(loop, insertion_point) - results = loop.results + if i: + # Do not insert yield outside of outermost loop + rewriter.insert_op(scf.Yield(*loop.results), InsertPoint.after(loop)) if i + 1 == len(bounds): # Innermost loop iteration @@ -144,7 +146,8 @@ def _insert_loop_nest( "Unexpected number of results from `make_body` helper " f"({len(results)}), expected {len(iter_args)}" ) - rewriter.insert_op(scf.Yield(*results), 
InsertPoint.at_end(loop.body.block)) + rewriter.insert_op(scf.Yield(*results), InsertPoint.at_end(loop.body.block)) + insertion_point = InsertPoint.at_start(loop.body.block) return loops[0].results From 299de80693b3f2b4e88266b6998ebc4bc7d20809 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 21 Jun 2024 09:06:24 +0100 Subject: [PATCH 03/14] tests: move constant initialisation around in bottom-up tests --- .../riscv-backend-paper/bottom_up.mlir | 223 +++++++++--------- 1 file changed, 110 insertions(+), 113 deletions(-) diff --git a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir index 4b5c5f0360..5bbd23be07 100644 --- a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir @@ -5,13 +5,6 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( %Y: memref<1x1x3x3xf64>, %Z: memref<1x1x6x6xf64> ) -> () { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c3 = arith.constant 3 : index - %c6 = arith.constant 6 : index - - %zero_float = arith.constant 0.0 : f64 - memref_stream.streaming_region { patterns = [ #memref_stream.stride_pattern (d0, d4, d2 + d5, d3 + d6)>, @@ -20,6 +13,12 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( ] } ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { ^0(%x_stream : !stream.readable, %y_stream : !stream.readable, %z_stream : !stream.writable): + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c6 = arith.constant 6 : index + + %zero_float = arith.constant 0.0 : f64 scf.for %i0 = %c0 to %c1 step %c1 { scf.for %i1 = %c0 to %c1 step %c1 { @@ -51,46 +50,46 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( // CHECK-NEXT: .globl conv_2d_nchw_fchw_d1_s1_3x3 // CHECK-NEXT: .p2align 2 // CHECK-NEXT: conv_2d_nchw_fchw_d1_s1_3x3: -// CHECK-NEXT: mv t3, a0 -// CHECK-NEXT: mv t0, a1 -// CHECK-NEXT: mv 
t1, a2 -// CHECK-NEXT: fcvt.d.w ft3, zero -// CHECK-NEXT: li t4, 2 -// CHECK-NEXT: scfgwi t4, 64 -// CHECK-NEXT: li t4, 2 -// CHECK-NEXT: scfgwi t4, 96 -// CHECK-NEXT: li t4, 5 -// CHECK-NEXT: scfgwi t4, 128 -// CHECK-NEXT: li t4, 5 -// CHECK-NEXT: scfgwi t4, 160 -// CHECK-NEXT: li t4, 8 -// CHECK-NEXT: scfgwi t4, 192 -// CHECK-NEXT: li t4, 48 -// CHECK-NEXT: scfgwi t4, 224 -// CHECK-NEXT: li t4, -136 -// CHECK-NEXT: scfgwi t4, 256 -// CHECK-NEXT: li t4, -120 -// CHECK-NEXT: scfgwi t4, 288 -// CHECK-NEXT: li t4, 2 -// CHECK-NEXT: scfgwi t4, 65 -// CHECK-NEXT: li t4, 2 -// CHECK-NEXT: scfgwi t4, 97 -// CHECK-NEXT: li t4, 35 -// CHECK-NEXT: scfgwi t4, 129 -// CHECK-NEXT: li t4, 8 -// CHECK-NEXT: scfgwi t4, 193 -// CHECK-NEXT: li t4, 8 -// CHECK-NEXT: scfgwi t4, 225 -// CHECK-NEXT: li t4, -64 -// CHECK-NEXT: scfgwi t4, 257 -// CHECK-NEXT: li t4, 35 -// CHECK-NEXT: scfgwi t4, 66 -// CHECK-NEXT: li t4, 8 -// CHECK-NEXT: scfgwi t4, 194 -// CHECK-NEXT: scfgwi t3, 864 -// CHECK-NEXT: scfgwi t0, 833 -// CHECK-NEXT: scfgwi t1, 898 +// CHECK-NEXT: mv t0, a0 +// CHECK-NEXT: mv t1, a1 +// CHECK-NEXT: mv t2, a2 +// CHECK-NEXT: li t3, 2 +// CHECK-NEXT: scfgwi t3, 64 +// CHECK-NEXT: li t3, 2 +// CHECK-NEXT: scfgwi t3, 96 +// CHECK-NEXT: li t3, 5 +// CHECK-NEXT: scfgwi t3, 128 +// CHECK-NEXT: li t3, 5 +// CHECK-NEXT: scfgwi t3, 160 +// CHECK-NEXT: li t3, 8 +// CHECK-NEXT: scfgwi t3, 192 +// CHECK-NEXT: li t3, 48 +// CHECK-NEXT: scfgwi t3, 224 +// CHECK-NEXT: li t3, -136 +// CHECK-NEXT: scfgwi t3, 256 +// CHECK-NEXT: li t3, -120 +// CHECK-NEXT: scfgwi t3, 288 +// CHECK-NEXT: li t3, 2 +// CHECK-NEXT: scfgwi t3, 65 +// CHECK-NEXT: li t3, 2 +// CHECK-NEXT: scfgwi t3, 97 +// CHECK-NEXT: li t3, 35 +// CHECK-NEXT: scfgwi t3, 129 +// CHECK-NEXT: li t3, 8 +// CHECK-NEXT: scfgwi t3, 193 +// CHECK-NEXT: li t3, 8 +// CHECK-NEXT: scfgwi t3, 225 +// CHECK-NEXT: li t3, -64 +// CHECK-NEXT: scfgwi t3, 257 +// CHECK-NEXT: li t3, 35 +// CHECK-NEXT: scfgwi t3, 66 +// CHECK-NEXT: li t3, 8 +// 
CHECK-NEXT: scfgwi t3, 194 +// CHECK-NEXT: scfgwi t0, 864 +// CHECK-NEXT: scfgwi t1, 833 +// CHECK-NEXT: scfgwi t2, 898 // CHECK-NEXT: csrrsi zero, 1984, 1 +// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t1, 36 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge @@ -407,14 +406,6 @@ func.func public @pooling_nchw_max_d1_s2_3x3( %X: memref<1x1x16x16xf64>, %Y: memref<1x1x7x7xf64> ) -> () { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c3 = arith.constant 3 : index - %c7 = arith.constant 7 : index - %c512 = arith.constant 512 : index - - %min_val = arith.constant -10000.0 : f64 - memref_stream.streaming_region { patterns = [ #memref_stream.stride_pattern (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, @@ -422,6 +413,13 @@ func.func public @pooling_nchw_max_d1_s2_3x3( ] } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { ^0(%x_stream : !stream.readable, %y_stream : !stream.writable): + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c7 = arith.constant 7 : index + %c512 = arith.constant 512 : index + + %min_val = arith.constant -10000.0 : f64 scf.for %i0 = %c0 to %c1 step %c1 { scf.for %i1 = %c0 to %c1 step %c1 { scf.for %i2 = %c0 to %c7 step %c1 { @@ -450,33 +448,33 @@ func.func public @pooling_nchw_max_d1_s2_3x3( // CHECK-NEXT: .globl pooling_nchw_max_d1_s2_3x3 // CHECK-NEXT: .p2align 2 // CHECK-NEXT: pooling_nchw_max_d1_s2_3x3: -// CHECK-NEXT: mv t0, a0 -// CHECK-NEXT: mv t1, a1 -// CHECK-NEXT: li t3, -10000 -// CHECK-NEXT: fcvt.d.w ft3, t3 -// CHECK-NEXT: li t3, 2 -// CHECK-NEXT: scfgwi t3, 64 -// CHECK-NEXT: li t3, 2 -// CHECK-NEXT: scfgwi t3, 96 -// CHECK-NEXT: li t3, 6 -// CHECK-NEXT: scfgwi t3, 128 -// CHECK-NEXT: li t3, 6 -// CHECK-NEXT: scfgwi t3, 160 -// CHECK-NEXT: li t3, 8 -// CHECK-NEXT: scfgwi t3, 192 -// CHECK-NEXT: li t3, 112 -// CHECK-NEXT: scfgwi t3, 224 -// CHECK-NEXT: li t3, -256 -// CHECK-NEXT: scfgwi t3, 256 -// CHECK-NEXT: li t3, -112 -// 
CHECK-NEXT: scfgwi t3, 288 -// CHECK-NEXT: li t3, 48 -// CHECK-NEXT: scfgwi t3, 65 -// CHECK-NEXT: li t3, 8 -// CHECK-NEXT: scfgwi t3, 193 -// CHECK-NEXT: scfgwi t0, 864 -// CHECK-NEXT: scfgwi t1, 897 +// CHECK-NEXT: mv t1, a0 +// CHECK-NEXT: mv t2, a1 +// CHECK-NEXT: li t0, 2 +// CHECK-NEXT: scfgwi t0, 64 +// CHECK-NEXT: li t0, 2 +// CHECK-NEXT: scfgwi t0, 96 +// CHECK-NEXT: li t0, 6 +// CHECK-NEXT: scfgwi t0, 128 +// CHECK-NEXT: li t0, 6 +// CHECK-NEXT: scfgwi t0, 160 +// CHECK-NEXT: li t0, 8 +// CHECK-NEXT: scfgwi t0, 192 +// CHECK-NEXT: li t0, 112 +// CHECK-NEXT: scfgwi t0, 224 +// CHECK-NEXT: li t0, -256 +// CHECK-NEXT: scfgwi t0, 256 +// CHECK-NEXT: li t0, -112 +// CHECK-NEXT: scfgwi t0, 288 +// CHECK-NEXT: li t0, 48 +// CHECK-NEXT: scfgwi t0, 65 +// CHECK-NEXT: li t0, 8 +// CHECK-NEXT: scfgwi t0, 193 +// CHECK-NEXT: scfgwi t1, 864 +// CHECK-NEXT: scfgwi t2, 897 // CHECK-NEXT: csrrsi zero, 1984, 1 +// CHECK-NEXT: li t1, -10000 +// CHECK-NEXT: fcvt.d.w ft3, t1 // CHECK-NEXT: li t1, 49 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge @@ -540,14 +538,6 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( %X: memref<1x1x16x16xf64>, %Y: memref<1x1x7x7xf64> ) -> () { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c3 = arith.constant 3 : index - %c7 = arith.constant 7 : index - %c512 = arith.constant 512 : index - - %zero_float = arith.constant 0.0 : f64 - memref_stream.streaming_region { patterns = [ #memref_stream.stride_pattern (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, @@ -555,6 +545,13 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( ] } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { ^0(%x_stream : !stream.readable, %y_stream : !stream.writable): + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c7 = arith.constant 7 : index + %c512 = arith.constant 512 : index + + %zero_float = arith.constant 0.0 : f64 scf.for %i0 = %c0 to %c1 step %c1 { scf.for %i1 = %c0 
to %c1 step %c1 { scf.for %i2 = %c0 to %c7 step %c1 { @@ -583,32 +580,32 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( // CHECK-NEXT: .globl pooling_nchw_sum_d1_s2_3x3 // CHECK-NEXT: .p2align 2 // CHECK-NEXT: pooling_nchw_sum_d1_s2_3x3: -// CHECK-NEXT: mv t0, a0 -// CHECK-NEXT: mv t1, a1 -// CHECK-NEXT: fcvt.d.w ft3, zero -// CHECK-NEXT: li t3, 2 -// CHECK-NEXT: scfgwi t3, 64 -// CHECK-NEXT: li t3, 2 -// CHECK-NEXT: scfgwi t3, 96 -// CHECK-NEXT: li t3, 6 -// CHECK-NEXT: scfgwi t3, 128 -// CHECK-NEXT: li t3, 6 -// CHECK-NEXT: scfgwi t3, 160 -// CHECK-NEXT: li t3, 8 -// CHECK-NEXT: scfgwi t3, 192 -// CHECK-NEXT: li t3, 112 -// CHECK-NEXT: scfgwi t3, 224 -// CHECK-NEXT: li t3, -256 -// CHECK-NEXT: scfgwi t3, 256 -// CHECK-NEXT: li t3, -112 -// CHECK-NEXT: scfgwi t3, 288 -// CHECK-NEXT: li t3, 48 -// CHECK-NEXT: scfgwi t3, 65 -// CHECK-NEXT: li t3, 8 -// CHECK-NEXT: scfgwi t3, 193 -// CHECK-NEXT: scfgwi t0, 864 -// CHECK-NEXT: scfgwi t1, 897 +// CHECK-NEXT: mv t1, a0 +// CHECK-NEXT: mv t2, a1 +// CHECK-NEXT: li t0, 2 +// CHECK-NEXT: scfgwi t0, 64 +// CHECK-NEXT: li t0, 2 +// CHECK-NEXT: scfgwi t0, 96 +// CHECK-NEXT: li t0, 6 +// CHECK-NEXT: scfgwi t0, 128 +// CHECK-NEXT: li t0, 6 +// CHECK-NEXT: scfgwi t0, 160 +// CHECK-NEXT: li t0, 8 +// CHECK-NEXT: scfgwi t0, 192 +// CHECK-NEXT: li t0, 112 +// CHECK-NEXT: scfgwi t0, 224 +// CHECK-NEXT: li t0, -256 +// CHECK-NEXT: scfgwi t0, 256 +// CHECK-NEXT: li t0, -112 +// CHECK-NEXT: scfgwi t0, 288 +// CHECK-NEXT: li t0, 48 +// CHECK-NEXT: scfgwi t0, 65 +// CHECK-NEXT: li t0, 8 +// CHECK-NEXT: scfgwi t0, 193 +// CHECK-NEXT: scfgwi t1, 864 +// CHECK-NEXT: scfgwi t2, 897 // CHECK-NEXT: csrrsi zero, 1984, 1 +// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t1, 49 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge From 2e5a0241721424ee00ede6c289f510f3d7982b14 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 14 Jun 2024 10:05:07 +0100 Subject: [PATCH 04/14] dialects: (memref_stream) add an inits 
field to memref_stream.generic --- .../filecheck/dialects/memref_stream/ops.mlir | 48 ++++++++++++-- .../dialects/memref_stream/verify.mlir | 15 +++-- .../convert_linalg_to_memref_stream.mlir | 6 +- .../convert_memref_stream_to_loops.mlir | 16 +++-- .../memref_stream_unnest_out_parameters.mlir | 5 +- .../transforms/memref_streamify.mlir | 15 +++-- .../test_memref_stream_interpreter.py | 56 +++++++++++++++++ xdsl/dialects/memref_stream.py | 62 ++++++++++++++++++- xdsl/interpreters/memref_stream.py | 11 +++- .../convert_linalg_to_memref_stream.py | 7 +-- .../convert_memref_stream_to_loops.py | 5 +- xdsl/transforms/memref_streamify.py | 15 +++-- 12 files changed, 217 insertions(+), 44 deletions(-) diff --git a/tests/filecheck/dialects/memref_stream/ops.mlir b/tests/filecheck/dialects/memref_stream/ops.mlir index 9c84668e63..3185961daf 100644 --- a/tests/filecheck/dialects/memref_stream/ops.mlir +++ b/tests/filecheck/dialects/memref_stream/ops.mlir @@ -50,18 +50,19 @@ memref_stream.generic { affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"] + iterator_types = ["parallel", "parallel"], + inits = [unit] } ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {hello = "world"} { ^bb0(%arg3: f32, %arg4: f32): memref_stream.yield %arg3 : f32 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {"hello" = "world"} { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : 
memref<3x2xf64>) attrs = {"hello" = "world"} { // CHECK-NEXT: ^1(%arg3 : f32, %arg4 : f32): // CHECK-NEXT: memref_stream.yield %arg3 : f32 // CHECK-NEXT: } -// CHECK-GENERIC-NEXT: "memref_stream.generic"(%A, %B, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "indexing_maps" = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%A, %B, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "inits" = [unit], "indexing_maps" = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^1(%arg3 : f32, %arg4 : f32): // CHECK-GENERIC-NEXT: "memref_stream.yield"(%arg3) : (f32) -> () // CHECK-GENERIC-NEXT: }) {"hello" = "world"} : (memref<2xf32>, memref<3xf32>, memref<3x2xf64>) -> () @@ -72,21 +73,56 @@ memref_stream.generic { affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"] + iterator_types = ["parallel", "parallel"], + inits = [unit] } ins(%D : f64) outs(%C : memref<3x2xf64>) { ^bb0(%d : f64, %c : f64): memref_stream.yield %d : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%D : f64) outs(%C : memref<3x2xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%D : f64) outs(%C : memref<3x2xf64>) { // CHECK-NEXT: ^2(%d : f64, %c_1 : f64): // CHECK-NEXT: memref_stream.yield %d : f64 // 
CHECK-NEXT: } -// CHECK-GENERIC-NEXT: "memref_stream.generic"(%D, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%D, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "inits" = [unit], "indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^2(%d : f64, %c_1 : f64): // CHECK-GENERIC-NEXT: "memref_stream.yield"(%d) : (f64) -> () // CHECK-GENERIC-NEXT: }) : (f64, memref<3x2xf64>) -> () +%E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) +// CHECK-NEXT: %E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) +// CHECK-GENERIC-NEXT: %E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) + +memref_stream.generic { + bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"], + inits = [0.000000e+00 : f64] +} ins(%E, %F : memref<4x2xf64>, memref<2x3xf64>) outs(%G : memref<4x3xf64>) { +^0(%e : f64, %f : f64, %acc_old : f64): + %prod = arith.mulf %e, %f : f64 + %acc_new = arith.addf %acc_old, %prod : f64 + linalg.yield %acc_new : f64 +} + +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], inits = [0.000000e+00 : f64]} ins(%{{.*}}, %{{.*}} : 
memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { +// CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): +// CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: linalg.yield %{{.*}} : f64 +// CHECK-NEXT: } + +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%{{.*}}, %{{.*}}, %{{.*}}) <{"bounds" = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], "inits" = [0.000000e+00 : f64], "indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f64) -> () +// CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -> () + // CHECK-NEXT: } // CHECK-GENERIC-NEXT: }) : () -> () diff --git a/tests/filecheck/dialects/memref_stream/verify.mlir b/tests/filecheck/dialects/memref_stream/verify.mlir index 83b8d12f29..20486c8550 100644 --- a/tests/filecheck/dialects/memref_stream/verify.mlir +++ b/tests/filecheck/dialects/memref_stream/verify.mlir @@ -9,7 +9,8 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)> ], - iterator_types = ["parallel", "reduction", "parallel"] + iterator_types = ["parallel", "reduction", "parallel"], + inits = [unit] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -29,7 +30,8 @@ memref_stream.generic { 
affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = ["parallel", "parallel", "reduction"], + inits = [unit] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -50,7 +52,8 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = ["parallel", "parallel", "reduction"], + inits = [unit] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -71,7 +74,8 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = ["parallel", "parallel", "reduction"], + inits = [unit] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -93,7 +97,8 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = ["parallel", "parallel", "reduction"], + inits = [unit, unit] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C, %D : memref<4x3xf64>, memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old0 : f64, %acc_old1 : f64): %prod = arith.mulf %a, %b : f64 diff --git a/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir b/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir index 707fd13e6e..86bd539b1b 100644 --- a/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir +++ b/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir @@ -22,7 +22,7 @@ linalg.generic { %acc_new = arith.addf 
%acc_old, %prod : f64 linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [], indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%A, %B : memref, memref) outs(%C : memref) { +// CHECK-NEXT: memref_stream.generic {bounds = [], indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = [], inits = [unit]} ins(%A, %B : memref, memref) outs(%C : memref) { // CHECK-NEXT: ^0(%a : f64, %b : f64, %acc_old : f64): // CHECK-NEXT: %prod = arith.mulf %a, %b : f64 // CHECK-NEXT: %acc_new = arith.addf %acc_old, %prod : f64 @@ -44,7 +44,7 @@ linalg.generic { linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<2>, #builtin.int<3>, #builtin.int<4>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%D, %E : memref<2x3xf64>, memref<3x4xf64>) outs(%F : memref<2x4xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<2>, #builtin.int<3>, #builtin.int<4>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>], iterator_types = ["parallel", "parallel", "reduction"], inits = [unit]} ins(%D, %E : memref<2x3xf64>, memref<3x4xf64>) outs(%F : memref<2x4xf64>) { // CHECK-NEXT: ^1(%d : f64, %e : f64, %acc_old_1 : f64): // CHECK-NEXT: %prod_1 = arith.mulf %d, %e : f64 // CHECK-NEXT: %acc_new_1 = arith.addf %acc_old_1, %prod_1 : f64 @@ -65,7 +65,7 @@ linalg.generic { linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ((d0 + d1))>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%G, %H : memref<4xf64>, memref<2xf64>) outs(%I : memref<3xf64>) 
{ +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ((d0 + d1))>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], inits = [unit]} ins(%G, %H : memref<4xf64>, memref<2xf64>) outs(%I : memref<3xf64>) { // CHECK-NEXT: ^2(%g : f64, %h : f64, %acc_old_2 : f64): // CHECK-NEXT: %prod_2 = arith.mulf %g, %h : f64 // CHECK-NEXT: %acc_new_2 = arith.addf %acc_old_2, %prod_2 : f64 diff --git a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir index 9844dd49e5..84c993bc19 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir @@ -11,7 +11,7 @@ ] } ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { ^0(%0 : !stream.readable, %1 : !stream.readable, %2 : !stream.writable): - memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { + memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { ^1(%in : f64, %in_0 : f64, %out : f64): %3 = arith.addf %in, %in_0 : f64 memref_stream.yield %3 : f64 @@ -48,7 +48,7 @@ ] } ins(%arg0_1 : memref<16x16xf64>) outs(%arg1_1 : memref<16x16xf64>) { ^2(%4 : !stream.readable, %5 : !stream.writable): - memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> 
(d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%4 : !stream.readable) outs(%5 : !stream.writable) { + memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%4 : !stream.readable) outs(%5 : !stream.writable) { ^3(%in_1 : f64, %out_1 : f64): %6 = arith.maximumf %in_1, %cst : f64 memref_stream.yield %6 : f64 @@ -90,7 +90,8 @@ func.func public @fill(%arg0 : memref<16x16xf64>) -> memref<16x16xf64> { affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"] + iterator_types = ["parallel", "parallel"], + inits = [unit] } ins(%zero : f64) outs(%7 : !stream.writable) { ^4(%in: f64, %out: f64): memref_stream.yield %in : f64 @@ -131,7 +132,8 @@ func.func @main(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64> affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = ["parallel", "parallel", "reduction"], + inits = [unit] } ins(%0, %1 : !stream.readable, !stream.readable) outs(%C : memref<4x3xf64>) { ^1(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -179,7 +181,8 @@ func.func @elide_affine(%A : memref<6xf64>, %B : memref) -> memref { affine_map<(d0, d1) -> (d0 * 3 + d1)>, affine_map<(d0, d1) -> ()> ], - iterator_types = ["parallel", "reduction"] + iterator_types = ["parallel", "reduction"], + inits = [unit] } ins(%0 : !stream.readable) outs(%B : memref) { ^1(%a : f64, %acc_old : f64): %acc_new = arith.addf %acc_old, %a : f64 @@ -220,7 +223,8 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref< affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<() -> ()> ], - iterator_types = ["reduction", "reduction", "reduction"] + iterator_types = ["reduction", "reduction", 
"reduction"], + inits = [unit] } ins(%0 : !stream.readable) outs(%B : memref) { ^1(%a : f64, %acc_old : f64): %acc_new = arith.addf %acc_old, %a : f64 diff --git a/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir b/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir index a585eeac2d..1878fd9f0e 100644 --- a/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir +++ b/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir @@ -10,7 +10,8 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"] + iterator_types = ["parallel", "parallel", "reduction"], + inits = [unit] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -20,7 +21,7 @@ memref_stream.generic { // CHECK: builtin.module { // CHECK-NEXT: %{{.*}}, %{{.*}}, %{{.*}} = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], inits = [unit]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: %{{.*}} = 
arith.addf %{{.*}}, %{{.*}} : f64 diff --git a/tests/filecheck/transforms/memref_streamify.mlir b/tests/filecheck/transforms/memref_streamify.mlir index 454d9f1da5..287a0fb193 100644 --- a/tests/filecheck/transforms/memref_streamify.mlir +++ b/tests/filecheck/transforms/memref_streamify.mlir @@ -10,7 +10,8 @@ func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 memref_stream.generic { bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] + iterator_types = ["parallel", "parallel"], + inits = [unit] } ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { ^0(%in : f64, %in_0 : f64, %out : f64): %0 = arith.addf %in, %in_0 : f64 @@ -22,7 +23,7 @@ func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 // CHECK-NEXT: func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 : memref<8x16xf64>) -> memref<8x16xf64> { // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>]} ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { // CHECK-NEXT: ^0(%0 : !stream.readable, %1 : !stream.readable, %2 : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = 
["parallel", "parallel"], inits = [unit]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { // CHECK-NEXT: ^1(%in : f64, %in_1 : f64, %out : f64): // CHECK-NEXT: %3 = arith.addf %in, %in_1 : f64 // CHECK-NEXT: memref_stream.yield %3 : f64 @@ -36,7 +37,8 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> memref_stream.generic { bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] + iterator_types = ["parallel", "parallel"], + inits = [unit] } ins(%arg0 : memref<16x16xf64>) outs(%arg1 : memref<16x16xf64>) { ^1(%in_1 : f64, %out_1 : f64): %1 = arith.maximumf %in_1, %cst : f64 @@ -49,7 +51,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>]} ins(%arg0 : memref<16x16xf64>) outs(%arg1 : memref<16x16xf64>) { // CHECK-NEXT: ^0(%0 : !stream.readable, %1 : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : !stream.readable) outs(%1 : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%0 : !stream.readable) outs(%1 : !stream.writable) { // CHECK-NEXT: ^1(%in : f64, %out : f64): // CHECK-NEXT: %2 = arith.maximumf %in, %cst : f64 // CHECK-NEXT: memref_stream.yield %2 : f64 @@ -68,7 +70,8 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> 
affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"] + iterator_types = ["parallel", "parallel"], + inits = [unit] } ins(%X : f64) outs(%Y : memref<16x16xf64>) { ^bb0(%d : f64, %c : f64): memref_stream.yield %d : f64 @@ -80,7 +83,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: func.func @fill(%{{.*}} : f64, %{{.*}} : memref<16x16xf64>) { // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>]} outs(%{{.*}} : memref<16x16xf64>) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%{{.*}} : f64) outs(%{{.*}} : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%{{.*}} : f64) outs(%{{.*}} : !stream.writable) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: memref_stream.yield %{{.*}} : f64 // CHECK-NEXT: } diff --git a/tests/interpreters/test_memref_stream_interpreter.py b/tests/interpreters/test_memref_stream_interpreter.py index 4790598944..a74e96673c 100644 --- a/tests/interpreters/test_memref_stream_interpreter.py +++ b/tests/interpreters/test_memref_stream_interpreter.py @@ -4,9 +4,11 @@ AffineMapAttr, ArrayAttr, Float32Type, + FloatAttr, IntAttr, MemRefType, ModuleOp, + UnitAttr, i32, ) from xdsl.interpreter import Interpreter @@ -31,6 +33,7 @@ def test_memref_stream_generic(): ), (TestSSAValue(MemRefType(i32, [1, 6])),), Region(Block(arg_types=(i32, i32))), + ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(2)), @@ -83,6 +86,7 @@ def 
test_memref_stream_generic_scalar(): ), (TestSSAValue(MemRefType(i32, [1, 6])),), Region(Block(arg_types=(i32, i32))), + ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(2)), @@ -135,6 +139,7 @@ def test_memref_stream_generic_reduction(): ), (TestSSAValue(MemRefType(i32, [])),), Region(Block(arg_types=(i32, i32, i32))), + ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(1)), @@ -176,6 +181,7 @@ def test_memref_stream_generic_imperfect_nesting(): ), (TestSSAValue(MemRefType(f32, [3, 3])),), Region(Block(arg_types=(f32, f32, f32))), + ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.from_callable(lambda n, m, k: (n, k))), @@ -209,3 +215,53 @@ def test_memref_stream_generic_imperfect_nesting(): TypedPtr.new_float32([6.0, 7.0, 21.0, 16.0, 17.0, 47.0, 26.0, 27.0, 73.0]), [3, 3], ) + + +def test_memref_stream_generic_reduction_with_initial_value(): + interpreter = Interpreter(ModuleOp([])) + interpreter.register_implementations(MemrefStreamFunctions()) + interpreter.register_implementations(ArithFunctions()) + + f32 = Float32Type() + + op = memref_stream.GenericOp( + ( + TestSSAValue(MemRefType(f32, [3, 2])), + TestSSAValue(MemRefType(f32, [2, 3])), + ), + (TestSSAValue(MemRefType(f32, [3, 3])),), + Region(Block(arg_types=(f32, f32, f32))), + ArrayAttr((FloatAttr(0.5, f32),)), + ArrayAttr( + ( + AffineMapAttr(AffineMap.from_callable(lambda n, m, k: (n, k))), + AffineMapAttr(AffineMap.from_callable(lambda n, m, k: (k, m))), + AffineMapAttr(AffineMap.from_callable(lambda n, m: (n, m))), + ) + ), + ArrayAttr( + ( + memref_stream.IteratorTypeAttr.parallel(), + memref_stream.IteratorTypeAttr.parallel(), + memref_stream.IteratorTypeAttr.reduction(), + ) + ), + ArrayAttr((IntAttr(3), IntAttr(3), IntAttr(2))), + ) + + with ImplicitBuilder(op.body) as (lhs, rhs, acc): + sum = arith.Mulf(lhs, rhs).result + new_acc = arith.Addf(sum, acc).result + memref_stream.YieldOp(new_acc) + + op.verify() + + a = 
ShapedArray(TypedPtr.new_float32([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), [3, 2]) + b = ShapedArray(TypedPtr.new_float32([4.0, 3.0, 5.0, 1.0, 2.0, 8.0]), [2, 3]) + c = ShapedArray(TypedPtr.new_float32([0.0] * 9), [3, 3]) + + interpreter.run_op(op, (a, b, c)) + assert c == ShapedArray( + TypedPtr.new_float32([6.5, 7.5, 21.5, 16.5, 17.5, 47.5, 26.5, 27.5, 73.5]), + [3, 3], + ) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index bf280bf8ee..ba7a02eb1d 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -15,7 +15,15 @@ from typing_extensions import Self from xdsl.dialects import memref, stream -from xdsl.dialects.builtin import AffineMapAttr, ArrayAttr, IntAttr, StringAttr +from xdsl.dialects.builtin import ( + AffineMapAttr, + AnyFloatAttr, + AnyIntegerAttr, + ArrayAttr, + IntAttr, + StringAttr, + UnitAttr, +) from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( Attribute, @@ -39,6 +47,7 @@ from xdsl.printer import Printer from xdsl.traits import IsTerminator, NoTerminator from xdsl.utils.exceptions import VerifyException +from xdsl.utils.hints import isa from xdsl.utils.str_enum import StrEnum @@ -310,6 +319,12 @@ class GenericOp(IRDLOperation): pattern defines the order in which the elements of the input buffers will be written to. """ + inits = prop_def(ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr]) + """ + Initial values for outputs. If `UnitAttr`, then the value is read from the output + buffer. Otherwise, the value is created at runtime with an `arith.constant` operation + during lowering. The inits may be set only for the imperfectly nested form. + """ indexing_maps = prop_def(ArrayAttr[AffineMapAttr]) """ Stride patterns that define the order of the input and output streams. 
@@ -332,6 +347,7 @@ def __init__( inputs: Sequence[SSAValue], outputs: Sequence[SSAValue], body: Region, + inits: ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr], indexing_maps: ArrayAttr[AffineMapAttr], iterator_types: ArrayAttr[Attribute], bounds: ArrayAttr[IntAttr], @@ -345,6 +361,7 @@ def __init__( operands=[inputs, outputs], properties={ "bounds": bounds, + "inits": inits, "indexing_maps": ArrayAttr(indexing_maps), "iterator_types": ArrayAttr(iterator_types), }, @@ -380,6 +397,12 @@ def print(self, printer: Printer): lambda iterator_type: printer.print_string_literal(iterator_type.data), ) printer.print_string("]") + printer.print_string(", inits = [") + printer.print_list( + self.inits, + lambda val: printer.print_attribute(val), + ) + printer.print_string("]") printer.print_string("}") if self.inputs: @@ -405,6 +428,8 @@ def print(self, printer: Printer): del extra_attrs["doc"] if "library_call" in extra_attrs: del extra_attrs["library_call"] + if "inits" in extra_attrs: + del extra_attrs["inits"] if extra_attrs: printer.print(" attrs = ") @@ -438,7 +463,7 @@ def parse(cls, parser: Parser) -> Self: del attrs["indexing_maps"] else: parser.raise_error( - "Expected indexing_maps for linalg.generic", + "Expected indexing_maps for memref_stream.generic", attrs_start_pos, attrs_end_pos, ) @@ -469,7 +494,19 @@ def parse(cls, parser: Parser) -> Self: ) else: parser.raise_error( - "Expected iterator_types for linalg.generic", + "Expected iterator_types for memref_stream.generic", + attrs_start_pos, + attrs_end_pos, + ) + + if "inits" in attrs: + inits = attrs["inits"] + if not isa(inits, ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr]): + parser.raise_error("Expected inits for memref_stream.generic") + del attrs["inits"] + else: + parser.raise_error( + "Expected inits for memref_stream.generic", attrs_start_pos, attrs_end_pos, ) @@ -532,6 +569,7 @@ def parse(cls, parser: Parser) -> Self: ins, outs, body, + inits, indexing_maps, ArrayAttr(iterator_types), 
bounds, @@ -542,6 +580,13 @@ def parse(cls, parser: Parser) -> Self: return generic def verify_(self) -> None: + # Verify that the number of initial values for outputs is the same as the number + # of outputs + if len(self.inits) != len(self.outputs): + raise VerifyException( + f"Mismatching number of outputs and initial values: {len(self.outputs)} != {self.inits}" + ) + # Parallel iterator types must preceed reduction iterators iterator_types = self.iterator_types.data num_parallel = iterator_types.count(IteratorTypeAttr.parallel()) @@ -586,6 +631,17 @@ def verify_(self) -> None: f"{len(iterator_types)} or {num_parallel}" ) + # The non-unit values of the inits must correspond to outputs where the domain + # of the affine map has the same number of dimensions as the number of parallel + # iterators + for i, (m, init) in enumerate(zip(output_maps, self.inits, strict=True)): + if init != UnitAttr(): + if m.data.num_dims != num_parallel: + raise VerifyException( + "Incompatible affine map and initial value for output at index " + f"{i}" + ) + @irdl_op_definition class YieldOp(AbstractYieldOperation[Attribute]): diff --git a/xdsl/interpreters/memref_stream.py b/xdsl/interpreters/memref_stream.py index 559248f364..7ba522856d 100644 --- a/xdsl/interpreters/memref_stream.py +++ b/xdsl/interpreters/memref_stream.py @@ -2,6 +2,7 @@ from typing import Any, cast from xdsl.dialects import memref_stream +from xdsl.dialects.builtin import UnitAttr from xdsl.interpreter import ( Interpreter, InterpreterFunctions, @@ -33,18 +34,22 @@ def run_generic( outer_ubs, inner_ubs = op.get_static_loop_ranges() + inits = op.inits.data + if inner_ubs: inputs: tuple[ShapedArray[float] | float, ...] 
= args[:inputs_count] input_indexing_maps = indexing_maps[:inputs_count] for outer_indices in product(*(range(outer_ub) for outer_ub in outer_ubs)): output_loop_args = tuple( ( - (cast(ShapedArray[Any], o)).load( + (cast(ShapedArray[int | float], o)).load( indexing_map.eval(outer_indices, ()) ) + if isinstance(init, UnitAttr) + else init.value.data ) - for o, indexing_map in zip( - outputs, output_indexing_maps, strict=True + for o, indexing_map, init in zip( + outputs, output_indexing_maps, inits, strict=True ) ) for inner_indices in product( diff --git a/xdsl/transforms/convert_linalg_to_memref_stream.py b/xdsl/transforms/convert_linalg_to_memref_stream.py index c6fef810a7..8a59ac5c37 100644 --- a/xdsl/transforms/convert_linalg_to_memref_stream.py +++ b/xdsl/transforms/convert_linalg_to_memref_stream.py @@ -1,10 +1,6 @@ from xdsl.context import MLContext from xdsl.dialects import linalg, memref_stream -from xdsl.dialects.builtin import ( - ArrayAttr, - IntAttr, - ModuleOp, -) +from xdsl.dialects.builtin import ArrayAttr, IntAttr, ModuleOp, UnitAttr from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -47,6 +43,7 @@ def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter) -> No op.inputs, op.outputs, rewriter.move_region_contents_to_new_regions(op.body), + ArrayAttr(tuple(UnitAttr() for _ in range(len(op.outputs)))), op.indexing_maps, iterator_types, bounds, diff --git a/xdsl/transforms/convert_memref_stream_to_loops.py b/xdsl/transforms/convert_memref_stream_to_loops.py index 640cf64c16..91cb5423cf 100644 --- a/xdsl/transforms/convert_memref_stream_to_loops.py +++ b/xdsl/transforms/convert_memref_stream_to_loops.py @@ -2,7 +2,7 @@ from xdsl.context import MLContext from xdsl.dialects import memref, memref_stream, stream -from xdsl.dialects.builtin import AffineMapAttr, ModuleOp +from xdsl.dialects.builtin import AffineMapAttr, ModuleOp, UnitAttr from xdsl.ir import Operation, SSAValue from 
xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -64,6 +64,9 @@ class LowerGenericOpPattern(RewritePattern): def match_and_rewrite( self, op: memref_stream.GenericOp, rewriter: PatternRewriter ) -> None: + if any(not isinstance(init, UnitAttr) for init in op.inits): + raise NotImplementedError("Operation has inits that are not UnitAttr") + outer_ubs, inner_ubs = op.get_static_loop_ranges() if inner_ubs: # Imperfectly nested diff --git a/xdsl/transforms/memref_streamify.py b/xdsl/transforms/memref_streamify.py index 77632c876b..0997e5f6fa 100644 --- a/xdsl/transforms/memref_streamify.py +++ b/xdsl/transforms/memref_streamify.py @@ -3,10 +3,7 @@ from xdsl.context import MLContext from xdsl.dialects import memref, memref_stream, stream -from xdsl.dialects.builtin import ( - ArrayAttr, - ModuleOp, -) +from xdsl.dialects.builtin import ArrayAttr, ModuleOp, UnitAttr from xdsl.ir import Attribute, Block, Region from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -26,6 +23,15 @@ class StreamifyGenericOpPattern(RewritePattern): def match_and_rewrite( self, op: memref_stream.GenericOp, rewriter: PatternRewriter ) -> None: + if any(isinstance(operand.type, stream.StreamType) for operand in op.operands): + # Already streamified + return + + if any(not isinstance(init, UnitAttr) for init in op.inits): + raise NotImplementedError( + "Cannot streamify operation that has inits that are not UnitAttr" + ) + # Currently can only stream memrefs that are not inout streamable_input_indices = tuple( (index, cast(memref.MemRefType[Attribute], value_type).element_type) @@ -87,6 +93,7 @@ def match_and_rewrite( new_operands[:input_count], new_operands[input_count:], rewriter.move_region_contents_to_new_regions(op.body), + op.inits, op.indexing_maps, op.iterator_types, op.bounds, From b2acbca8103b1dfa556989887613c95e82521c9f Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 18 Jun 2024 11:40:19 +0100 Subject: [PATCH 05/14] transformations: 
support constant inits in memref_stream.generic lowering --- .../riscv-backend-paper/bottom_up.mlir | 116 ++++++------------ .../convert_memref_stream_to_loops.mlir | 51 ++++++++ xdsl/transforms/convert_linalg_to_loops.py | 1 + .../convert_memref_stream_to_loops.py | 45 ++++++- xdsl/transforms/loop_nest_lowering_utils.py | 9 +- 5 files changed, 140 insertions(+), 82 deletions(-) diff --git a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir index 5bbd23be07..84f5075cfe 100644 --- a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir @@ -13,32 +13,20 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( ] } ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { ^0(%x_stream : !stream.readable, %y_stream : !stream.readable, %z_stream : !stream.writable): - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c3 = arith.constant 3 : index - %c6 = arith.constant 6 : index - - %zero_float = arith.constant 0.0 : f64 - - scf.for %i0 = %c0 to %c1 step %c1 { - scf.for %i1 = %c0 to %c1 step %c1 { - scf.for %i2 = %c0 to %c6 step %c1 { - scf.for %i3 = %c0 to %c6 step %c1 { - %z = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %zero_float) -> (f64) { - %z3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) { - %x = memref_stream.read from %x_stream : f64 - %y = memref_stream.read from %y_stream : f64 - %prod = arith.mulf %x, %y fastmath : f64 - %res = arith.addf %prod, %acc1 fastmath : f64 - scf.yield %res : f64 - } - scf.yield %z3 : f64 - } - - memref_stream.write %z to %z_stream : f64 - } - } - } + memref_stream.generic { + bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, + affine_map<(d0, d1, d2, 
d3, d4, d5, d6) -> (d1, d4, d5, d6)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], + inits = [0.0 : f64] + } ins(%x_stream, %y_stream : !stream.readable, !stream.readable) outs(%z_stream : !stream.writable) { + ^0(%x : f64, %y : f64, %acc : f64): + %prod = arith.mulf %x, %y fastmath : f64 + %res = arith.addf %prod, %acc fastmath : f64 + memref_stream.yield %res : f64 } } @@ -413,30 +401,18 @@ func.func public @pooling_nchw_max_d1_s2_3x3( ] } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { ^0(%x_stream : !stream.readable, %y_stream : !stream.writable): - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c3 = arith.constant 3 : index - %c7 = arith.constant 7 : index - %c512 = arith.constant 512 : index - - %min_val = arith.constant -10000.0 : f64 - scf.for %i0 = %c0 to %c1 step %c1 { - scf.for %i1 = %c0 to %c1 step %c1 { - scf.for %i2 = %c0 to %c7 step %c1 { - scf.for %i3 = %c0 to %c7 step %c1 { - %y = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %min_val) -> (f64) { - %y3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) { - %x = memref_stream.read from %x_stream : f64 - %res = arith.maximumf %x, %acc1 : f64 - scf.yield %res : f64 - } - scf.yield %y3 : f64 - } - - memref_stream.write %y to %y_stream : f64 - } - } - } + memref_stream.generic { + bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], + inits = [-10000.0 : f64] + } ins(%x_stream : !stream.readable) outs(%y_stream : !stream.writable) { + ^0(%x : f64, %acc : f64): + %res = arith.maximumf %x, %acc : f64 + memref_stream.yield %res 
: f64 } } @@ -473,8 +449,8 @@ func.func public @pooling_nchw_max_d1_s2_3x3( // CHECK-NEXT: scfgwi t1, 864 // CHECK-NEXT: scfgwi t2, 897 // CHECK-NEXT: csrrsi zero, 1984, 1 -// CHECK-NEXT: li t1, -10000 -// CHECK-NEXT: fcvt.d.w ft3, t1 +// CHECK-NEXT: li t2, -10000 +// CHECK-NEXT: fcvt.d.w ft3, t2 // CHECK-NEXT: li t1, 49 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge @@ -545,30 +521,18 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( ] } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { ^0(%x_stream : !stream.readable, %y_stream : !stream.writable): - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %c3 = arith.constant 3 : index - %c7 = arith.constant 7 : index - %c512 = arith.constant 512 : index - - %zero_float = arith.constant 0.0 : f64 - scf.for %i0 = %c0 to %c1 step %c1 { - scf.for %i1 = %c0 to %c1 step %c1 { - scf.for %i2 = %c0 to %c7 step %c1 { - scf.for %i3 = %c0 to %c7 step %c1 { - %y = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %zero_float) -> (f64) { - %y3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) { - %x = memref_stream.read from %x_stream : f64 - %res = arith.addf %x, %acc1 : f64 - scf.yield %res : f64 - } - scf.yield %y3 : f64 - } - - memref_stream.write %y to %y_stream : f64 - } - } - } + memref_stream.generic { + bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], + inits = [0.0 : f64] + } ins(%x_stream : !stream.readable) outs(%y_stream : !stream.writable) { + ^0(%x : f64, %acc : f64): + %res = arith.addf %x, %acc : f64 + memref_stream.yield %res : f64 } } diff --git a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir 
b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir index 84c993bc19..1a5e9c7fe6 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir @@ -259,4 +259,55 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref< // CHECK-NEXT: func.return %{{.*}} : memref // CHECK-NEXT: } +func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64>) -> memref<4x3xf64> { + memref_stream.streaming_region { + patterns = [ + #memref_stream.stride_pattern (d0, d2)>, + #memref_stream.stride_pattern (d2, d1)> + ] + } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) { + ^0(%0 : !stream.readable, %1 : !stream.readable): + memref_stream.generic { + bounds = [#builtin.int<4>, #builtin.int<3>, #builtin.int<2>], + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"], + inits = [0.0 : f64] + } ins(%0, %1 : !stream.readable, !stream.readable) outs(%C : memref<4x3xf64>) { + ^1(%a : f64, %b : f64, %acc_old : f64): + %prod = arith.mulf %a, %b : f64 + %acc_new = arith.addf %acc_old, %prod : f64 + memref_stream.yield %acc_new : f64 + } + } + func.return %C : memref<4x3xf64> +} +// CHECK-NEXT: func.func @main_inits(%{{.*}} : memref<4x2xf64>, %{{.*}} : memref<2x3xf64>, %{{.*}} : memref<4x3xf64>) -> memref<4x3xf64> { +// CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d2)>, #memref_stream.stride_pattern (d2, d1)>]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) { +// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable): +// CHECK-NEXT: %2 = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %{{.*}} = arith.constant 4 : index +// CHECK-NEXT: %{{.*}} = arith.constant 3 : index +// CHECK-NEXT: %{{.*}} = arith.constant 2 : index +// CHECK-NEXT: 
%{{.*}} = arith.constant 0 : index +// CHECK-NEXT: %{{.*}} = arith.constant 1 : index +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { +// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %2) -> (f64) { +// CHECK-NEXT: %{{.*}} = memref_stream.read from %{{.*}} : f64 +// CHECK-NEXT: %{{.*}} = memref_stream.read from %{{.*}} : f64 +// CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: scf.yield %{{.*}} : f64 +// CHECK-NEXT: } +// CHECK-NEXT: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<4x3xf64> +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func.return %{{.*}} : memref<4x3xf64> +// CHECK-NEXT: } + // CHECK-NEXT: } diff --git a/xdsl/transforms/convert_linalg_to_loops.py b/xdsl/transforms/convert_linalg_to_loops.py index b2abf00325..8d7ecf9292 100644 --- a/xdsl/transforms/convert_linalg_to_loops.py +++ b/xdsl/transforms/convert_linalg_to_loops.py @@ -20,6 +20,7 @@ def insert_load( + value_index: int, value: SSAValue, affine_map_attr: AffineMapAttr, ind_vars: Sequence[SSAValue], diff --git a/xdsl/transforms/convert_memref_stream_to_loops.py b/xdsl/transforms/convert_memref_stream_to_loops.py index 91cb5423cf..6ea1cbc79f 100644 --- a/xdsl/transforms/convert_memref_stream_to_loops.py +++ b/xdsl/transforms/convert_memref_stream_to_loops.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from xdsl.context import MLContext -from xdsl.dialects import memref, memref_stream, stream +from xdsl.dialects import arith, memref, memref_stream, stream from xdsl.dialects.builtin import AffineMapAttr, ModuleOp, UnitAttr from xdsl.ir import Operation, SSAValue from xdsl.passes import ModulePass @@ -20,7 +20,8 @@ ) -def insert_load( +def _insert_load( + source_index: int, source: SSAValue, affine_map_attr: AffineMapAttr, ind_vars: 
Sequence[SSAValue], @@ -64,13 +65,49 @@ class LowerGenericOpPattern(RewritePattern): def match_and_rewrite( self, op: memref_stream.GenericOp, rewriter: PatternRewriter ) -> None: + ins_count = len(op.inputs) if any(not isinstance(init, UnitAttr) for init in op.inits): - raise NotImplementedError("Operation has inits that are not UnitAttr") + insertion_point = InsertPoint.before(op) + constant_ops = tuple( + None if isinstance(attr, UnitAttr) else arith.Constant(attr) + for attr in op.inits + ) + for constant_op in constant_ops: + if constant_op is not None: + rewriter.insert_op(constant_op, insertion_point) + constant_vals = tuple( + None if constant_op is None else constant_op.result + for constant_op in constant_ops + ) + + def insert_load( + source_index: int, + source: SSAValue, + affine_map_attr: AffineMapAttr, + ind_vars: Sequence[SSAValue], + rewriter: PatternRewriter, + insertion_point: InsertPoint, + ) -> SSAValue: + if source_index >= ins_count: + constant_val = constant_vals[source_index - ins_count] + if constant_val is not None: + return constant_val + + return _insert_load( + source_index, + source, + affine_map_attr, + ind_vars, + rewriter, + insertion_point, + ) + + else: + insert_load = _insert_load outer_ubs, inner_ubs = op.get_static_loop_ranges() if inner_ubs: # Imperfectly nested - ins_count = len(op.inputs) rewrite_generic_to_imperfect_loops( rewriter, InsertPoint.before(op), diff --git a/xdsl/transforms/loop_nest_lowering_utils.py b/xdsl/transforms/loop_nest_lowering_utils.py index 8286659ba4..e5d3a95679 100644 --- a/xdsl/transforms/loop_nest_lowering_utils.py +++ b/xdsl/transforms/loop_nest_lowering_utils.py @@ -65,6 +65,7 @@ def indices_for_map( INSERT_LOAD: TypeAlias = Callable[ [ + int, SSAValue, AffineMapAttr, Sequence[SSAValue], @@ -161,6 +162,7 @@ def _insert_load_ops( operands: Sequence[SSAValue], args: Sequence[BlockArgument], insert_load: INSERT_LOAD, + index_increment: int = 0, ) -> Sequence[tuple[int, SSAValue]]: """ Inserts the 
load operations at the specified insertion point. @@ -172,6 +174,7 @@ def _insert_load_ops( The `affine_map_attrs`, `operands`, and `args` must have the same length. Returns a tuple of integers indicating the locations of the returned values, and the values themselves. + The integer values are incremented by `index_increment`. """ res: list[tuple[int, SSAValue]] = [] for i, (affine_map_attr, operand, arg) in enumerate( @@ -180,13 +183,14 @@ def _insert_load_ops( if not arg.uses: continue res_val = insert_load( + i + index_increment, operand, affine_map_attr, ind_vars, rewriter, insertion_point, ) - res.append((i, res_val)) + res.append((i + index_increment, res_val)) return res @@ -352,6 +356,7 @@ def outer_make_body( outer_load_operands, outer_load_block_args, insert_load, + index_increment=len(inner_load_block_args), ) def inner_make_body( @@ -377,7 +382,7 @@ def inner_make_body( inner_iter_args, strict=True, ): - block.args[i + len(inner_loaded_values)].replace_by(arg) + block.args[i].replace_by(arg) # Replace block argument use with load op results for i, val in inner_loaded_values: From ca1cf95fe4113fcfef0cd14ebb419a09a8b1b11f Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Fri, 21 Jun 2024 09:27:12 +0100 Subject: [PATCH 06/14] transformations: memref_streamify handle constant inits --- .../riscv-backend-paper/bottom_up.mlir | 95 +++++++------------ .../transforms/memref_streamify.mlir | 38 ++++++++ xdsl/transforms/memref_streamify.py | 7 +- 3 files changed, 75 insertions(+), 65 deletions(-) diff --git a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir index 84f5075cfe..acb9946bd8 100644 --- a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir @@ -5,29 +5,20 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( %Y: memref<1x1x3x3xf64>, %Z: memref<1x1x6x6xf64> ) -> () { - memref_stream.streaming_region { - 
patterns = [ - #memref_stream.stride_pattern (d0, d4, d2 + d5, d3 + d6)>, - #memref_stream.stride_pattern (d1, d4, d5, d6)>, - #memref_stream.stride_pattern (d0, d1, d2, d3)> - ] + memref_stream.generic { + bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], + inits = [0.0 : f64] } ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { - ^0(%x_stream : !stream.readable, %y_stream : !stream.readable, %z_stream : !stream.writable): - memref_stream.generic { - bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], - indexing_maps = [ - affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, - affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, - affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> - ], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], - inits = [0.0 : f64] - } ins(%x_stream, %y_stream : !stream.readable, !stream.readable) outs(%z_stream : !stream.writable) { - ^0(%x : f64, %y : f64, %acc : f64): - %prod = arith.mulf %x, %y fastmath : f64 - %res = arith.addf %prod, %acc fastmath : f64 - memref_stream.yield %res : f64 - } + ^0(%x : f64, %y : f64, %acc : f64): + %prod = arith.mulf %x, %y fastmath : f64 + %res = arith.addf %prod, %acc fastmath : f64 + memref_stream.yield %res : f64 } func.return @@ -394,26 +385,18 @@ func.func public @pooling_nchw_max_d1_s2_3x3( %X: memref<1x1x16x16xf64>, %Y: memref<1x1x7x7xf64> ) -> () { - memref_stream.streaming_region { - patterns = [ - 
#memref_stream.stride_pattern (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, - #memref_stream.stride_pattern (d0, d1, d2, d3)> - ] + memref_stream.generic { + bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], + inits = [-10000.0 : f64] } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { - ^0(%x_stream : !stream.readable, %y_stream : !stream.writable): - memref_stream.generic { - bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], - indexing_maps = [ - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, - affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> - ], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], - inits = [-10000.0 : f64] - } ins(%x_stream : !stream.readable) outs(%y_stream : !stream.writable) { - ^0(%x : f64, %acc : f64): - %res = arith.maximumf %x, %acc : f64 - memref_stream.yield %res : f64 - } + ^0(%x : f64, %acc : f64): + %res = arith.maximumf %x, %acc : f64 + memref_stream.yield %res : f64 } func.return @@ -514,26 +497,18 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( %X: memref<1x1x16x16xf64>, %Y: memref<1x1x7x7xf64> ) -> () { - memref_stream.streaming_region { - patterns = [ - #memref_stream.stride_pattern (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, - #memref_stream.stride_pattern (d0, d1, d2, d3)> - ] + memref_stream.generic { + bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ], + iterator_types = 
["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], + inits = [0.0 : f64] } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { - ^0(%x_stream : !stream.readable, %y_stream : !stream.writable): - memref_stream.generic { - bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], - indexing_maps = [ - affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, - affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> - ], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], - inits = [0.0 : f64] - } ins(%x_stream : !stream.readable) outs(%y_stream : !stream.writable) { - ^0(%x : f64, %acc : f64): - %res = arith.addf %x, %acc : f64 - memref_stream.yield %res : f64 - } + ^0(%x : f64, %acc : f64): + %res = arith.addf %x, %acc : f64 + memref_stream.yield %res : f64 } func.return diff --git a/tests/filecheck/transforms/memref_streamify.mlir b/tests/filecheck/transforms/memref_streamify.mlir index 287a0fb193..153b60fbc5 100644 --- a/tests/filecheck/transforms/memref_streamify.mlir +++ b/tests/filecheck/transforms/memref_streamify.mlir @@ -91,4 +91,42 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: func.return // CHECK-NEXT: } + func.func public @conv_2d_nchw_fchw_d1_s1_3x3( + %X : memref<1x1x8x8xf64>, + %Y : memref<1x1x3x3xf64>, + %Z : memref<1x1x6x6xf64> + ) { + memref_stream.generic { + bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> + ], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], + inits = [0.000000e+00 : f64] + } ins(%X, %Y : 
memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { + ^0(%x : f64, %y : f64, %acc : f64): + %prod = arith.mulf %x, %y fastmath : f64 + %res = arith.addf %prod, %acc fastmath : f64 + memref_stream.yield %res : f64 + } + + func.return + } + +// CHECK-NEXT: func.func public @conv_2d_nchw_fchw_d1_s1_3x3(%X : memref<1x1x8x8xf64>, %Y : memref<1x1x3x3xf64>, %Z : memref<1x1x6x6xf64>) { +// CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d4, (d2 + d5), (d3 + d6))>, #memref_stream.stride_pattern (d1, d4, d5, d6)>, #memref_stream.stride_pattern (d0, d1, d2, d3)>]} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { +// CHECK-NEXT: ^0(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable, %{{.*}} : !stream.writable): +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, (d2 + d5), (d3 + d6))>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], inits = [0.000000e+00 : f64]} ins(%{{.*}}, %{{.*}} : !stream.readable, !stream.readable) outs(%{{.*}} : !stream.writable) { +// CHECK-NEXT: ^{{\d+}}(%x : f64, %y : f64, %acc : f64): +// CHECK-NEXT: %prod = arith.mulf %x, %y fastmath : f64 +// CHECK-NEXT: %res = arith.addf %prod, %acc fastmath : f64 +// CHECK-NEXT: memref_stream.yield %res : f64 +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK-NEXT: func.return +// CHECK-NEXT: } + + // CHECK-NEXT: } diff --git a/xdsl/transforms/memref_streamify.py b/xdsl/transforms/memref_streamify.py index 0997e5f6fa..1a657c1709 100644 --- a/xdsl/transforms/memref_streamify.py +++ b/xdsl/transforms/memref_streamify.py @@ -27,10 +27,7 @@ def match_and_rewrite( # 
Already streamified return - if any(not isinstance(init, UnitAttr) for init in op.inits): - raise NotImplementedError( - "Cannot streamify operation that has inits that are not UnitAttr" - ) + init_values = tuple(not isinstance(init, UnitAttr) for init in op.inits) # Currently can only stream memrefs that are not inout streamable_input_indices = tuple( @@ -43,7 +40,7 @@ def match_and_rewrite( (index, cast(memref.MemRefType[Attribute], value_type).element_type) for index, value in enumerate(op.outputs) if isinstance(value_type := value.type, memref.MemRefType) - if not op.body.block.args[index + input_count].uses + if init_values[index] or not op.body.block.args[index + input_count].uses ) if not streamable_input_indices and not streamable_output_indices: # No memrefs to convert to streams From cbadcd01d90a13cf6aca03d81749d862dedfbb47 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 24 Jun 2024 10:14:01 +0100 Subject: [PATCH 07/14] inits now values not attributes --- .../filecheck/dialects/memref_stream/ops.mlir | 31 ++-- .../dialects/memref_stream/verify.mlir | 17 +-- .../riscv-backend-paper/bottom_up.mlir | 26 ++-- .../convert_linalg_to_memref_stream.mlir | 6 +- .../convert_memref_stream_to_loops.mlir | 26 ++-- .../memref_stream_unnest_out_parameters.mlir | 5 +- .../transforms/memref_streamify.mlir | 24 ++- .../test_memref_stream_interpreter.py | 19 ++- xdsl/dialects/memref_stream.py | 138 ++++++++++++------ xdsl/interpreters/memref_stream.py | 25 ++-- .../convert_linalg_to_memref_stream.py | 5 +- .../convert_memref_stream_to_loops.py | 17 +-- xdsl/transforms/memref_streamify.py | 11 +- 13 files changed, 189 insertions(+), 161 deletions(-) diff --git a/tests/filecheck/dialects/memref_stream/ops.mlir b/tests/filecheck/dialects/memref_stream/ops.mlir index 3185961daf..b1fef4c430 100644 --- a/tests/filecheck/dialects/memref_stream/ops.mlir +++ b/tests/filecheck/dialects/memref_stream/ops.mlir @@ -50,19 +50,18 @@ memref_stream.generic { affine_map<(d0, d1) 
-> (d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {hello = "world"} { ^bb0(%arg3: f32, %arg4: f32): memref_stream.yield %arg3 : f32 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {"hello" = "world"} { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {"hello" = "world"} { // CHECK-NEXT: ^1(%arg3 : f32, %arg4 : f32): // CHECK-NEXT: memref_stream.yield %arg3 : f32 // CHECK-NEXT: } -// CHECK-GENERIC-NEXT: "memref_stream.generic"(%A, %B, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "inits" = [unit], "indexing_maps" = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%A, %B, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "init_indices" = [], "indexing_maps" = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^1(%arg3 : f32, %arg4 : f32): // CHECK-GENERIC-NEXT: "memref_stream.yield"(%arg3) : (f32) -> () // CHECK-GENERIC-NEXT: }) {"hello" = 
"world"} : (memref<2xf32>, memref<3xf32>, memref<3x2xf64>) -> () @@ -73,26 +72,25 @@ memref_stream.generic { affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%D : f64) outs(%C : memref<3x2xf64>) { ^bb0(%d : f64, %c : f64): memref_stream.yield %d : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%D : f64) outs(%C : memref<3x2xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%D : f64) outs(%C : memref<3x2xf64>) { // CHECK-NEXT: ^2(%d : f64, %c_1 : f64): // CHECK-NEXT: memref_stream.yield %d : f64 // CHECK-NEXT: } -// CHECK-GENERIC-NEXT: "memref_stream.generic"(%D, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "inits" = [unit], "indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%D, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "init_indices" = [], "indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^2(%d : f64, %c_1 : f64): // CHECK-GENERIC-NEXT: "memref_stream.yield"(%d) : (f64) -> () // CHECK-GENERIC-NEXT: }) : (f64, memref<3x2xf64>) -> () -%E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -// CHECK-NEXT: %E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) 
-// CHECK-GENERIC-NEXT: %E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) +%E, %F, %G, %H = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) +// CHECK-NEXT: %E, %F, %G, %H = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) +// CHECK-GENERIC-NEXT: %E, %F, %G, %H = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) memref_stream.generic { bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], @@ -101,28 +99,27 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [0.000000e+00 : f64] -} ins(%E, %F : memref<4x2xf64>, memref<2x3xf64>) outs(%G : memref<4x3xf64>) { + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%E, %F : memref<4x2xf64>, memref<2x3xf64>) outs(%G : memref<4x3xf64>) inits(%H : f64) { ^0(%e : f64, %f : f64, %acc_old : f64): %prod = arith.mulf %e, %f : f64 %acc_new = arith.addf %acc_old, %prod : f64 linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], inits = [0.000000e+00 : f64]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) inits(%H : f64) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: %{{.*}} = 
arith.mulf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: linalg.yield %{{.*}} : f64 // CHECK-NEXT: } -// CHECK-GENERIC-NEXT: "memref_stream.generic"(%{{.*}}, %{{.*}}, %{{.*}}) <{"bounds" = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], "inits" = [0.000000e+00 : f64], "indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%{{.*}}, %{{.*}}, %{{.*}}, %H) <{"bounds" = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], "init_indices" = [#builtin.int<0>], "indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 // CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 // CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f64) -> () -// CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -> () +// CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) -> () // CHECK-NEXT: } // CHECK-GENERIC-NEXT: }) : () -> () diff --git a/tests/filecheck/dialects/memref_stream/verify.mlir b/tests/filecheck/dialects/memref_stream/verify.mlir index 20486c8550..ec5399136f 100644 --- a/tests/filecheck/dialects/memref_stream/verify.mlir +++ b/tests/filecheck/dialects/memref_stream/verify.mlir @@ -9,8 +9,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> 
(d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)> ], - iterator_types = ["parallel", "reduction", "parallel"], - inits = [unit] + iterator_types = ["parallel", "reduction", "parallel"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -30,8 +29,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -39,7 +37,7 @@ memref_stream.generic { memref_stream.yield %acc_new : f64 } -// CHECK: Operation does not verify: The number of affine maps must match the number of operands +// CHECK: Operation does not verify: The number of affine maps must match the number of inputs and outputs // ----- @@ -52,8 +50,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -74,8 +71,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -97,8 +93,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", 
"parallel", "reduction"], - inits = [unit, unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C, %D : memref<4x3xf64>, memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old0 : f64, %acc_old1 : f64): %prod = arith.mulf %a, %b : f64 diff --git a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir index acb9946bd8..30e8b6cc0f 100644 --- a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir @@ -5,6 +5,7 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( %Y: memref<1x1x3x3xf64>, %Z: memref<1x1x6x6xf64> ) -> () { + %zero_float = arith.constant 0.0 : f64 memref_stream.generic { bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [ @@ -12,9 +13,8 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> ], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], - inits = [0.0 : f64] - } ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] + } ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) inits(%zero_float : f64) { ^0(%x : f64, %y : f64, %acc : f64): %prod = arith.mulf %x, %y fastmath : f64 %res = arith.addf %prod, %acc fastmath : f64 @@ -32,6 +32,7 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( // CHECK-NEXT: mv t0, a0 // CHECK-NEXT: mv t1, a1 // CHECK-NEXT: mv t2, a2 +// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t3, 2 // CHECK-NEXT: scfgwi t3, 64 // CHECK-NEXT: li t3, 2 @@ -68,7 +69,6 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( // 
CHECK-NEXT: scfgwi t1, 833 // CHECK-NEXT: scfgwi t2, 898 // CHECK-NEXT: csrrsi zero, 1984, 1 -// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t1, 36 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge @@ -385,15 +385,15 @@ func.func public @pooling_nchw_max_d1_s2_3x3( %X: memref<1x1x16x16xf64>, %Y: memref<1x1x7x7xf64> ) -> () { + %min_val = arith.constant -10000 : f64 memref_stream.generic { bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> ], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], - inits = [-10000.0 : f64] - } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"] + } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) inits(%min_val : f64) { ^0(%x : f64, %acc : f64): %res = arith.maximumf %x, %acc : f64 memref_stream.yield %res : f64 @@ -409,6 +409,8 @@ func.func public @pooling_nchw_max_d1_s2_3x3( // CHECK-NEXT: pooling_nchw_max_d1_s2_3x3: // CHECK-NEXT: mv t1, a0 // CHECK-NEXT: mv t2, a1 +// CHECK-NEXT: li t0, -10000 +// CHECK-NEXT: fcvt.d.w ft3, t0 // CHECK-NEXT: li t0, 2 // CHECK-NEXT: scfgwi t0, 64 // CHECK-NEXT: li t0, 2 @@ -432,8 +434,6 @@ func.func public @pooling_nchw_max_d1_s2_3x3( // CHECK-NEXT: scfgwi t1, 864 // CHECK-NEXT: scfgwi t2, 897 // CHECK-NEXT: csrrsi zero, 1984, 1 -// CHECK-NEXT: li t2, -10000 -// CHECK-NEXT: fcvt.d.w ft3, t2 // CHECK-NEXT: li t1, 49 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge @@ -497,15 +497,15 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( %X: memref<1x1x16x16xf64>, %Y: memref<1x1x7x7xf64> ) -> () { + %zero_float = arith.constant 0.0 : f64 memref_stream.generic { bounds = [#builtin.int<1>, 
#builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [ affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> ], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], - inits = [0.0 : f64] - } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"] + } ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) inits(%zero_float : f64) { ^0(%x : f64, %acc : f64): %res = arith.addf %x, %acc : f64 memref_stream.yield %res : f64 @@ -521,6 +521,7 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( // CHECK-NEXT: pooling_nchw_sum_d1_s2_3x3: // CHECK-NEXT: mv t1, a0 // CHECK-NEXT: mv t2, a1 +// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t0, 2 // CHECK-NEXT: scfgwi t0, 64 // CHECK-NEXT: li t0, 2 @@ -544,7 +545,6 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( // CHECK-NEXT: scfgwi t1, 864 // CHECK-NEXT: scfgwi t2, 897 // CHECK-NEXT: csrrsi zero, 1984, 1 -// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t1, 49 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge diff --git a/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir b/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir index 86bd539b1b..707fd13e6e 100644 --- a/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir +++ b/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir @@ -22,7 +22,7 @@ linalg.generic { %acc_new = arith.addf %acc_old, %prod : f64 linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [], indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = [], inits = [unit]} ins(%A, %B : memref, memref) outs(%C : memref) { +// CHECK-NEXT: memref_stream.generic {bounds = [], indexing_maps = [affine_map<() -> 
()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%A, %B : memref, memref) outs(%C : memref) { // CHECK-NEXT: ^0(%a : f64, %b : f64, %acc_old : f64): // CHECK-NEXT: %prod = arith.mulf %a, %b : f64 // CHECK-NEXT: %acc_new = arith.addf %acc_old, %prod : f64 @@ -44,7 +44,7 @@ linalg.generic { linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<2>, #builtin.int<3>, #builtin.int<4>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>], iterator_types = ["parallel", "parallel", "reduction"], inits = [unit]} ins(%D, %E : memref<2x3xf64>, memref<3x4xf64>) outs(%F : memref<2x4xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<2>, #builtin.int<3>, #builtin.int<4>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%D, %E : memref<2x3xf64>, memref<3x4xf64>) outs(%F : memref<2x4xf64>) { // CHECK-NEXT: ^1(%d : f64, %e : f64, %acc_old_1 : f64): // CHECK-NEXT: %prod_1 = arith.mulf %d, %e : f64 // CHECK-NEXT: %acc_new_1 = arith.addf %acc_old_1, %prod_1 : f64 @@ -65,7 +65,7 @@ linalg.generic { linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ((d0 + d1))>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], inits = [unit]} ins(%G, %H : memref<4xf64>, memref<2xf64>) outs(%I : memref<3xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ((d0 + d1))>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%G, %H : memref<4xf64>, memref<2xf64>) outs(%I : memref<3xf64>) { // CHECK-NEXT: ^2(%g : f64, 
%h : f64, %acc_old_2 : f64): // CHECK-NEXT: %prod_2 = arith.mulf %g, %h : f64 // CHECK-NEXT: %acc_new_2 = arith.addf %acc_old_2, %prod_2 : f64 diff --git a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir index 1a5e9c7fe6..2f2d7bdfba 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir @@ -11,7 +11,7 @@ ] } ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { ^0(%0 : !stream.readable, %1 : !stream.readable, %2 : !stream.writable): - memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { + memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { ^1(%in : f64, %in_0 : f64, %out : f64): %3 = arith.addf %in, %in_0 : f64 memref_stream.yield %3 : f64 @@ -48,7 +48,7 @@ ] } ins(%arg0_1 : memref<16x16xf64>) outs(%arg1_1 : memref<16x16xf64>) { ^2(%4 : !stream.readable, %5 : !stream.writable): - memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%4 : !stream.readable) outs(%5 : !stream.writable) { + memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", 
"parallel"]} ins(%4 : !stream.readable) outs(%5 : !stream.writable) { ^3(%in_1 : f64, %out_1 : f64): %6 = arith.maximumf %in_1, %cst : f64 memref_stream.yield %6 : f64 @@ -90,8 +90,7 @@ func.func public @fill(%arg0 : memref<16x16xf64>) -> memref<16x16xf64> { affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%zero : f64) outs(%7 : !stream.writable) { ^4(%in: f64, %out: f64): memref_stream.yield %in : f64 @@ -132,8 +131,7 @@ func.func @main(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64> affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%0, %1 : !stream.readable, !stream.readable) outs(%C : memref<4x3xf64>) { ^1(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -181,8 +179,7 @@ func.func @elide_affine(%A : memref<6xf64>, %B : memref) -> memref { affine_map<(d0, d1) -> (d0 * 3 + d1)>, affine_map<(d0, d1) -> ()> ], - iterator_types = ["parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "reduction"] } ins(%0 : !stream.readable) outs(%B : memref) { ^1(%a : f64, %acc_old : f64): %acc_new = arith.addf %acc_old, %a : f64 @@ -223,8 +220,7 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref< affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<() -> ()> ], - iterator_types = ["reduction", "reduction", "reduction"], - inits = [unit] + iterator_types = ["reduction", "reduction", "reduction"] } ins(%0 : !stream.readable) outs(%B : memref) { ^1(%a : f64, %acc_old : f64): %acc_new = arith.addf %acc_old, %a : f64 @@ -260,6 +256,7 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref< // CHECK-NEXT: } func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64>) -> 
memref<4x3xf64> { + %zero_float = arith.constant 0.000000e+00 : f64 memref_stream.streaming_region { patterns = [ #memref_stream.stride_pattern (d0, d2)>, @@ -274,9 +271,8 @@ func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [0.0 : f64] - } ins(%0, %1 : !stream.readable, !stream.readable) outs(%C : memref<4x3xf64>) { + iterator_types = ["parallel", "parallel", "reduction"] + } ins(%0, %1 : !stream.readable, !stream.readable) outs(%C : memref<4x3xf64>) inits(%zero_float : f64) { ^1(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 %acc_new = arith.addf %acc_old, %prod : f64 @@ -286,9 +282,9 @@ func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x func.return %C : memref<4x3xf64> } // CHECK-NEXT: func.func @main_inits(%{{.*}} : memref<4x2xf64>, %{{.*}} : memref<2x3xf64>, %{{.*}} : memref<4x3xf64>) -> memref<4x3xf64> { +// CHECK-NEXT: %zero_float = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d2)>, #memref_stream.stride_pattern (d2, d1)>]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable): -// CHECK-NEXT: %2 = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: %{{.*}} = arith.constant 4 : index // CHECK-NEXT: %{{.*}} = arith.constant 3 : index // CHECK-NEXT: %{{.*}} = arith.constant 2 : index @@ -296,7 +292,7 @@ func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x // CHECK-NEXT: %{{.*}} = arith.constant 1 : index // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { // CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { -// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %2) -> (f64) { +// 
CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %zero_float) -> (f64) { // CHECK-NEXT: %{{.*}} = memref_stream.read from %{{.*}} : f64 // CHECK-NEXT: %{{.*}} = memref_stream.read from %{{.*}} : f64 // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 diff --git a/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir b/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir index 1878fd9f0e..a585eeac2d 100644 --- a/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir +++ b/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir @@ -10,8 +10,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -21,7 +20,7 @@ memref_stream.generic { // CHECK: builtin.module { // CHECK-NEXT: %{{.*}}, %{{.*}}, %{{.*}} = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], inits = [unit]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { // 
CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 diff --git a/tests/filecheck/transforms/memref_streamify.mlir b/tests/filecheck/transforms/memref_streamify.mlir index 153b60fbc5..1e60375b8a 100644 --- a/tests/filecheck/transforms/memref_streamify.mlir +++ b/tests/filecheck/transforms/memref_streamify.mlir @@ -10,8 +10,7 @@ func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 memref_stream.generic { bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { ^0(%in : f64, %in_0 : f64, %out : f64): %0 = arith.addf %in, %in_0 : f64 @@ -23,7 +22,7 @@ func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 // CHECK-NEXT: func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 : memref<8x16xf64>) -> memref<8x16xf64> { // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>]} ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { // CHECK-NEXT: ^0(%0 : !stream.readable, %1 : !stream.readable, %2 : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = 
[#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { // CHECK-NEXT: ^1(%in : f64, %in_1 : f64, %out : f64): // CHECK-NEXT: %3 = arith.addf %in, %in_1 : f64 // CHECK-NEXT: memref_stream.yield %3 : f64 @@ -37,8 +36,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> memref_stream.generic { bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%arg0 : memref<16x16xf64>) outs(%arg1 : memref<16x16xf64>) { ^1(%in_1 : f64, %out_1 : f64): %1 = arith.maximumf %in_1, %cst : f64 @@ -51,7 +49,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>]} ins(%arg0 : memref<16x16xf64>) outs(%arg1 : memref<16x16xf64>) { // CHECK-NEXT: ^0(%0 : !stream.readable, %1 : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%0 : !stream.readable) outs(%1 : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : !stream.readable) outs(%1 : !stream.writable) { // CHECK-NEXT: ^1(%in : f64, %out : f64): // CHECK-NEXT: %2 = arith.maximumf %in, %cst : 
f64 // CHECK-NEXT: memref_stream.yield %2 : f64 @@ -70,8 +68,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%X : f64) outs(%Y : memref<16x16xf64>) { ^bb0(%d : f64, %c : f64): memref_stream.yield %d : f64 @@ -83,7 +80,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: func.func @fill(%{{.*}} : f64, %{{.*}} : memref<16x16xf64>) { // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>]} outs(%{{.*}} : memref<16x16xf64>) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%{{.*}} : f64) outs(%{{.*}} : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%{{.*}} : f64) outs(%{{.*}} : !stream.writable) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: memref_stream.yield %{{.*}} : f64 // CHECK-NEXT: } @@ -96,6 +93,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> %Y : memref<1x1x3x3xf64>, %Z : memref<1x1x6x6xf64> ) { + %zero_float = arith.constant 0.000000e+00 : f64 memref_stream.generic { bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [ @@ -103,9 +101,8 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, 
d1, d2, d3) -> (d0, d1, d2, d3)> ], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], - inits = [0.000000e+00 : f64] - } ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] + } ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) inits(%zero_float : f64) { ^0(%x : f64, %y : f64, %acc : f64): %prod = arith.mulf %x, %y fastmath : f64 %res = arith.addf %prod, %acc fastmath : f64 @@ -116,9 +113,10 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> } // CHECK-NEXT: func.func public @conv_2d_nchw_fchw_d1_s1_3x3(%X : memref<1x1x8x8xf64>, %Y : memref<1x1x3x3xf64>, %Z : memref<1x1x6x6xf64>) { +// CHECK-NEXT: %zero_float = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d4, (d2 + d5), (d3 + d6))>, #memref_stream.stride_pattern (d1, d4, d5, d6)>, #memref_stream.stride_pattern (d0, d1, d2, d3)>]} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { // CHECK-NEXT: ^0(%{{.*}} : !stream.readable, %{{.*}} : !stream.readable, %{{.*}} : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, (d2 + d5), (d3 + d6))>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], inits = [0.000000e+00 : f64]} ins(%{{.*}}, %{{.*}} : !stream.readable, !stream.readable) outs(%{{.*}} : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = 
[#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, (d2 + d5), (d3 + d6))>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%{{.*}}, %{{.*}} : !stream.readable, !stream.readable) outs(%{{.*}} : !stream.writable) inits(%zero_float : f64) { // CHECK-NEXT: ^{{\d+}}(%x : f64, %y : f64, %acc : f64): // CHECK-NEXT: %prod = arith.mulf %x, %y fastmath : f64 // CHECK-NEXT: %res = arith.addf %prod, %acc fastmath : f64 diff --git a/tests/interpreters/test_memref_stream_interpreter.py b/tests/interpreters/test_memref_stream_interpreter.py index a74e96673c..3448e7b78f 100644 --- a/tests/interpreters/test_memref_stream_interpreter.py +++ b/tests/interpreters/test_memref_stream_interpreter.py @@ -4,11 +4,9 @@ AffineMapAttr, ArrayAttr, Float32Type, - FloatAttr, IntAttr, MemRefType, ModuleOp, - UnitAttr, i32, ) from xdsl.interpreter import Interpreter @@ -32,8 +30,8 @@ def test_memref_stream_generic(): TestSSAValue(MemRefType(i32, [3, 2])), ), (TestSSAValue(MemRefType(i32, [1, 6])),), + (), Region(Block(arg_types=(i32, i32))), - ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(2)), @@ -57,6 +55,7 @@ def test_memref_stream_generic(): ) ), ArrayAttr((IntAttr(2), IntAttr(3))), + ArrayAttr(()), ) with ImplicitBuilder(op.body) as (a, b): @@ -85,8 +84,8 @@ def test_memref_stream_generic_scalar(): TestSSAValue(i32), ), (TestSSAValue(MemRefType(i32, [1, 6])),), + (), Region(Block(arg_types=(i32, i32))), - ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(2)), @@ -110,6 +109,7 @@ def test_memref_stream_generic_scalar(): ) ), ArrayAttr((IntAttr(2), IntAttr(3))), + ArrayAttr(()), ) with ImplicitBuilder(op.body) as (a, b): @@ -138,8 +138,8 @@ 
def test_memref_stream_generic_reduction(): TestSSAValue(MemRefType(i32, [3])), ), (TestSSAValue(MemRefType(i32, [])),), + (), Region(Block(arg_types=(i32, i32, i32))), - ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(1)), @@ -149,6 +149,7 @@ def test_memref_stream_generic_reduction(): ), ArrayAttr((memref_stream.IteratorTypeAttr.reduction(),)), ArrayAttr((IntAttr(3),)), + ArrayAttr(()), ) with ImplicitBuilder(op.body) as (lhs, rhs, acc): @@ -180,8 +181,8 @@ def test_memref_stream_generic_imperfect_nesting(): TestSSAValue(MemRefType(f32, [2, 3])), ), (TestSSAValue(MemRefType(f32, [3, 3])),), + (), Region(Block(arg_types=(f32, f32, f32))), - ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.from_callable(lambda n, m, k: (n, k))), @@ -197,6 +198,7 @@ def test_memref_stream_generic_imperfect_nesting(): ) ), ArrayAttr((IntAttr(3), IntAttr(3), IntAttr(2))), + ArrayAttr(()), ) with ImplicitBuilder(op.body) as (lhs, rhs, acc): @@ -230,8 +232,8 @@ def test_memref_stream_generic_reduction_with_initial_value(): TestSSAValue(MemRefType(f32, [2, 3])), ), (TestSSAValue(MemRefType(f32, [3, 3])),), + (TestSSAValue(f32),), Region(Block(arg_types=(f32, f32, f32))), - ArrayAttr((FloatAttr(0.5, f32),)), ArrayAttr( ( AffineMapAttr(AffineMap.from_callable(lambda n, m, k: (n, k))), @@ -247,6 +249,7 @@ def test_memref_stream_generic_reduction_with_initial_value(): ) ), ArrayAttr((IntAttr(3), IntAttr(3), IntAttr(2))), + ArrayAttr((IntAttr(0),)), ) with ImplicitBuilder(op.body) as (lhs, rhs, acc): @@ -260,7 +263,7 @@ def test_memref_stream_generic_reduction_with_initial_value(): b = ShapedArray(TypedPtr.new_float32([4.0, 3.0, 5.0, 1.0, 2.0, 8.0]), [2, 3]) c = ShapedArray(TypedPtr.new_float32([0.0] * 9), [3, 3]) - interpreter.run_op(op, (a, b, c)) + interpreter.run_op(op, (a, b, c, 0.5)) assert c == ShapedArray( TypedPtr.new_float32([6.5, 7.5, 21.5, 16.5, 17.5, 47.5, 26.5, 27.5, 73.5]), [3, 3], diff --git a/xdsl/dialects/memref_stream.py 
b/xdsl/dialects/memref_stream.py index ba7a02eb1d..c81891045f 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -17,12 +17,9 @@ from xdsl.dialects import memref, stream from xdsl.dialects.builtin import ( AffineMapAttr, - AnyFloatAttr, - AnyIntegerAttr, ArrayAttr, IntAttr, StringAttr, - UnitAttr, ) from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( @@ -47,7 +44,6 @@ from xdsl.printer import Printer from xdsl.traits import IsTerminator, NoTerminator from xdsl.utils.exceptions import VerifyException -from xdsl.utils.hints import isa from xdsl.utils.str_enum import StrEnum @@ -319,11 +315,10 @@ class GenericOp(IRDLOperation): pattern defines the order in which the elements of the input buffers will be written to. """ - inits = prop_def(ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr]) + inits = var_operand_def() """ - Initial values for outputs. If `NoneAttr`, then the value is read from the output - buffer. Otherwise, the value is created at runtime with an `arith.constant` operation - during lowering. The inits may be set only for the imperfectly nested form. + Initial values for outputs. The outputs are at corresponding `init_indices`. The inits + may be set only for the imperfectly nested form. """ indexing_maps = prop_def(ArrayAttr[AffineMapAttr]) """ @@ -337,6 +332,10 @@ class GenericOp(IRDLOperation): """ iterator_types = prop_def(ArrayAttr[IteratorTypeAttr]) + init_indices = prop_def(ArrayAttr[IntAttr]) + """ + Indices into the `outputs` that correspond to the initial values in `inits`. 
+ """ body: Region = region_def("single_block") @@ -346,11 +345,12 @@ def __init__( self, inputs: Sequence[SSAValue], outputs: Sequence[SSAValue], + inits: Sequence[SSAValue], body: Region, - inits: ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr], indexing_maps: ArrayAttr[AffineMapAttr], iterator_types: ArrayAttr[Attribute], bounds: ArrayAttr[IntAttr], + init_indices: ArrayAttr[IntAttr], ) -> None: for m in indexing_maps: if m.data.num_symbols: @@ -358,12 +358,12 @@ def __init__( f"Symbols currently not implemented in {self.name} indexing maps" ) super().__init__( - operands=[inputs, outputs], + operands=[inputs, outputs, inits], properties={ "bounds": bounds, - "inits": inits, - "indexing_maps": ArrayAttr(indexing_maps), - "iterator_types": ArrayAttr(iterator_types), + "init_indices": init_indices, + "indexing_maps": indexing_maps, + "iterator_types": iterator_types, }, regions=[body], ) @@ -386,6 +386,14 @@ def get_static_loop_ranges( tuple(bound.data for bound in self.bounds.data[min_dims:]), ) + def _print_init(self, printer: Printer, init: SSAValue | None): + if init is None: + printer.print_string("None") + else: + printer.print_ssa_value(init) + printer.print_string(" : ") + printer.print_attribute(init.type) + def print(self, printer: Printer): printer.print_string(" {bounds = ") printer.print_attribute(self.bounds) @@ -397,12 +405,6 @@ def print(self, printer: Printer): lambda iterator_type: printer.print_string_literal(iterator_type.data), ) printer.print_string("]") - printer.print_string(", inits = [") - printer.print_list( - self.inits, - lambda val: printer.print_attribute(val), - ) - printer.print_string("]") printer.print_string("}") if self.inputs: @@ -419,6 +421,18 @@ def print(self, printer: Printer): printer.print_list((o.type for o in self.outputs), printer.print_attribute) printer.print_string(")") + if self.inits: + printer.print_string(" inits(") + init_indices = set(attr.data for attr in self.init_indices) + inits = [ + val if i in 
init_indices else None for i, val in enumerate(self.inits) + ] + printer.print_list( + inits, + lambda val: self._print_init(printer, val), + ) + printer.print_string(")") + extra_attrs = self.attributes.copy() if "indexing_maps" in extra_attrs: del extra_attrs["indexing_maps"] @@ -438,6 +452,35 @@ def print(self, printer: Printer): printer.print_string(" ") printer.print_region(self.body) + @classmethod + def _parse_init(cls, parser: Parser) -> SSAValue | None: + if parser.parse_optional_characters("None"): + return None + unresolved = parser.parse_unresolved_operand() + parser.parse_punctuation(":") + type = parser.parse_type() + return parser.resolve_operand(unresolved, type) + + @classmethod + def _parse_inits( + cls, parser: Parser + ) -> tuple[tuple[SSAValue, ...], tuple[int, ...]]: + if not parser.parse_optional_characters("inits"): + return ((), ()) + + parser.parse_punctuation("(") + optional_inits = parser.parse_comma_separated_list( + Parser.Delimiter.NONE, lambda: cls._parse_init(parser) + ) + parser.parse_punctuation(")") + enumerated_inits = tuple( + (i, val) for i, val in enumerate(optional_inits) if val is not None + ) + inits = tuple(init for _, init in enumerated_inits) + init_indices = tuple(i for i, _ in enumerated_inits) + + return (tuple(inits), init_indices) + @classmethod def parse(cls, parser: Parser) -> Self: attrs_start_pos = parser.pos @@ -499,18 +542,6 @@ def parse(cls, parser: Parser) -> Self: attrs_end_pos, ) - if "inits" in attrs: - inits = attrs["inits"] - if not isa(inits, ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr]): - parser.raise_error("Expected inits for memref_stream.generic") - del attrs["inits"] - else: - parser.raise_error( - "Expected inits for memref_stream.generic", - attrs_start_pos, - attrs_end_pos, - ) - if "doc" in attrs: doc = attrs["doc"] assert isinstance(doc, StringAttr) @@ -553,8 +584,11 @@ def parse(cls, parser: Parser) -> Self: parser.parse_punctuation(")") outs = 
parser.resolve_operands(unresolved_outs, outs_types, pos) else: + outs_types = () outs = () + inits, init_indices = cls._parse_inits(parser) + if parser.parse_optional_keyword("attrs"): parser.parse_punctuation("=") extra_attrs = parser.expect( @@ -568,11 +602,12 @@ def parse(cls, parser: Parser) -> Self: generic = cls( ins, outs, - body, inits, + body, indexing_maps, ArrayAttr(iterator_types), bounds, + ArrayAttr(IntAttr(index) for index in init_indices), ) generic.attributes |= attrs generic.attributes |= extra_attrs @@ -580,11 +615,9 @@ def parse(cls, parser: Parser) -> Self: return generic def verify_(self) -> None: - # Verify that the number of initial values for outputs is the same as the number - # of outputs - if len(self.inits) != len(self.outputs): + if len(self.inits) != len(self.init_indices): raise VerifyException( - f"Mismatching number of outputs and initial values: {len(self.outputs)} != {self.inits}" + f"Mismatching number of inits and init indices: {len(self.inits)} != {self.init_indices}" ) # Parallel iterator types must preceed reduction iterators @@ -595,9 +628,9 @@ def verify_(self) -> None: f"Unexpected order of iterator types: {[it.data.value for it in iterator_types]}" ) - if len(self.operands) != len(self.indexing_maps): + if len(self.inputs) + len(self.outputs) != len(self.indexing_maps): raise VerifyException( - "The number of affine maps must match the number of operands" + "The number of affine maps must match the number of inputs and outputs" ) # Whether or not the operation represents an imperfect loop nest, verify that the @@ -631,16 +664,25 @@ def verify_(self) -> None: f"{len(iterator_types)} or {num_parallel}" ) - # The non-None values of the inits must correspond to inputs where the domain - # of the affine map has the same number of dimensions as the number of parallel - # iterators - for i, (m, init) in enumerate(zip(output_maps, self.inits, strict=True)): - if init != UnitAttr(): - if m.data.num_dims != num_parallel: - raise 
VerifyException( - "Incompatible affine map and initial value for output at index " - f"{i}" - ) + if len(self.init_indices) != len(self.inits): + raise VerifyException( + "The number of inits and init_indices must be the same" + ) + + # The values of the inits must correspond to outputs where the domain of the + # affine map has the same number of dimensions as the number of parallel + # iterators. + num_outputs = len(self.outputs) + output_maps = self.indexing_maps.data[-num_outputs:] + for index in self.init_indices: + if not (0 <= index.data <= num_outputs): + raise VerifyException(f"Init index out of bounds: {index.data}") + m = output_maps[index.data] + if m.data.num_dims != num_parallel: + raise VerifyException( + "Incompatible affine map and initial value for output at index " + f"{index}" + ) @irdl_op_definition diff --git a/xdsl/interpreters/memref_stream.py b/xdsl/interpreters/memref_stream.py index 7ba522856d..ce1174cedd 100644 --- a/xdsl/interpreters/memref_stream.py +++ b/xdsl/interpreters/memref_stream.py @@ -2,7 +2,6 @@ from typing import Any, cast from xdsl.dialects import memref_stream -from xdsl.dialects.builtin import UnitAttr from xdsl.interpreter import ( Interpreter, InterpreterFunctions, @@ -26,15 +25,21 @@ def run_generic( ) -> PythonValues: inputs_count = len(op.inputs) + outputs_count = len(op.outputs) - outputs: tuple[ShapedArray[float], ...] = args[inputs_count:] + outputs: tuple[ShapedArray[int | float], ...] = args[ + inputs_count : inputs_count + outputs_count + ] + init_values: tuple[int | float, ...] 
= args[inputs_count + outputs_count :] indexing_maps = tuple(attr.data for attr in op.indexing_maps) output_indexing_maps = indexing_maps[inputs_count:] outer_ubs, inner_ubs = op.get_static_loop_ranges() - inits = op.inits.data + inits: list[None | int | float] = [None] * len(op.outputs) + for index, init in zip(op.init_indices, init_values, strict=True): + inits[index.data] = init if inner_ubs: inputs: tuple[ShapedArray[float] | float, ...] = args[:inputs_count] @@ -42,14 +47,15 @@ def run_generic( for outer_indices in product(*(range(outer_ub) for outer_ub in outer_ubs)): output_loop_args = tuple( ( - (cast(ShapedArray[int | float], o)).load( - indexing_map.eval(outer_indices, ()) - ) - if isinstance(init, UnitAttr) - else init.value.data + o.load(indexing_map.eval(outer_indices, ())) + if init is None + else init ) for o, indexing_map, init in zip( - outputs, output_indexing_maps, inits, strict=True + outputs, + output_indexing_maps, + inits, + strict=True, ) ) for inner_indices in product( @@ -72,7 +78,6 @@ def run_generic( op.body, input_loop_args + output_loop_args, "for_loop" ) output_loop_args = loop_results - print(output_loop_args, output_indexing_maps, outputs) for res, indexing_map, output in zip( output_loop_args, output_indexing_maps, outputs, strict=True ): diff --git a/xdsl/transforms/convert_linalg_to_memref_stream.py b/xdsl/transforms/convert_linalg_to_memref_stream.py index 8a59ac5c37..e1e9ade527 100644 --- a/xdsl/transforms/convert_linalg_to_memref_stream.py +++ b/xdsl/transforms/convert_linalg_to_memref_stream.py @@ -1,6 +1,6 @@ from xdsl.context import MLContext from xdsl.dialects import linalg, memref_stream -from xdsl.dialects.builtin import ArrayAttr, IntAttr, ModuleOp, UnitAttr +from xdsl.dialects.builtin import ArrayAttr, IntAttr, ModuleOp from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -42,11 +42,12 @@ def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter) -> No 
memref_stream.GenericOp( op.inputs, op.outputs, + (), rewriter.move_region_contents_to_new_regions(op.body), - ArrayAttr(tuple(UnitAttr() for _ in range(len(op.outputs)))), op.indexing_maps, iterator_types, bounds, + ArrayAttr(()), ) ) diff --git a/xdsl/transforms/convert_memref_stream_to_loops.py b/xdsl/transforms/convert_memref_stream_to_loops.py index 6ea1cbc79f..d5f44449ba 100644 --- a/xdsl/transforms/convert_memref_stream_to_loops.py +++ b/xdsl/transforms/convert_memref_stream_to_loops.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from xdsl.context import MLContext -from xdsl.dialects import arith, memref, memref_stream, stream +from xdsl.dialects import memref, memref_stream, stream from xdsl.dialects.builtin import AffineMapAttr, ModuleOp, UnitAttr from xdsl.ir import Operation, SSAValue from xdsl.passes import ModulePass @@ -67,18 +67,9 @@ def match_and_rewrite( ) -> None: ins_count = len(op.inputs) if any(not isinstance(init, UnitAttr) for init in op.inits): - insertion_point = InsertPoint.before(op) - constant_ops = tuple( - None if isinstance(attr, UnitAttr) else arith.Constant(attr) - for attr in op.inits - ) - for constant_op in constant_ops: - if constant_op is not None: - rewriter.insert_op(constant_op, insertion_point) - constant_vals = tuple( - None if constant_op is None else constant_op.result - for constant_op in constant_ops - ) + constant_vals: list[SSAValue | None] = [None] * len(op.outputs) + for index, val in zip(op.init_indices, op.inits, strict=True): + constant_vals[index.data] = val def insert_load( source_index: int, diff --git a/xdsl/transforms/memref_streamify.py b/xdsl/transforms/memref_streamify.py index 1a657c1709..769537a91e 100644 --- a/xdsl/transforms/memref_streamify.py +++ b/xdsl/transforms/memref_streamify.py @@ -3,7 +3,7 @@ from xdsl.context import MLContext from xdsl.dialects import memref, memref_stream, stream -from xdsl.dialects.builtin import ArrayAttr, ModuleOp, UnitAttr +from xdsl.dialects.builtin import 
ArrayAttr, ModuleOp from xdsl.ir import Attribute, Block, Region from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -27,7 +27,7 @@ def match_and_rewrite( # Already streamified return - init_values = tuple(not isinstance(init, UnitAttr) for init in op.inits) + init_indices = set(index.data for index in op.init_indices) # Currently can only stream memrefs that are not inout streamable_input_indices = tuple( @@ -40,7 +40,7 @@ def match_and_rewrite( (index, cast(memref.MemRefType[Attribute], value_type).element_type) for index, value in enumerate(op.outputs) if isinstance(value_type := value.type, memref.MemRefType) - if init_values[index] or not op.body.block.args[index + input_count].uses + if index in init_indices or not op.body.block.args[index + input_count].uses ) if not streamable_input_indices and not streamable_output_indices: # No memrefs to convert to streams @@ -81,7 +81,7 @@ def match_and_rewrite( ) ) new_body = streaming_region_op.body.block - new_operands = list(op.operands) + new_operands = list(op.operands[: len(op.inputs) + len(op.outputs)]) for stream_index, (index, _) in enumerate(streamed_operand_indices): new_operands[index] = new_body.args[stream_index] @@ -89,11 +89,12 @@ def match_and_rewrite( memref_stream.GenericOp( new_operands[:input_count], new_operands[input_count:], - rewriter.move_region_contents_to_new_regions(op.body), op.inits, + rewriter.move_region_contents_to_new_regions(op.body), op.indexing_maps, op.iterator_types, op.bounds, + op.init_indices, ), InsertPoint.at_end(new_body), ) From a2076276f4bf88534e3d42ba17fb8ea7ac218cc7 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 24 Jun 2024 10:14:01 +0100 Subject: [PATCH 08/14] inits now values not attributes --- .../filecheck/dialects/memref_stream/ops.mlir | 31 ++-- .../dialects/memref_stream/verify.mlir | 17 +-- .../riscv-backend-paper/bottom_up.mlir | 15 +- .../convert_linalg_to_memref_stream.mlir | 6 +- .../memref_stream_unnest_out_parameters.mlir 
| 5 +- .../transforms/memref_streamify.mlir | 15 +- .../test_memref_stream_interpreter.py | 19 ++- xdsl/dialects/memref_stream.py | 138 ++++++++++++------ xdsl/interpreters/memref_stream.py | 25 ++-- .../convert_linalg_to_memref_stream.py | 5 +- xdsl/transforms/memref_streamify.py | 9 +- 11 files changed, 162 insertions(+), 123 deletions(-) diff --git a/tests/filecheck/dialects/memref_stream/ops.mlir b/tests/filecheck/dialects/memref_stream/ops.mlir index 3185961daf..b1fef4c430 100644 --- a/tests/filecheck/dialects/memref_stream/ops.mlir +++ b/tests/filecheck/dialects/memref_stream/ops.mlir @@ -50,19 +50,18 @@ memref_stream.generic { affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {hello = "world"} { ^bb0(%arg3: f32, %arg4: f32): memref_stream.yield %arg3 : f32 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {"hello" = "world"} { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%A, %B : memref<2xf32>, memref<3xf32>) outs(%C : memref<3x2xf64>) attrs = {"hello" = "world"} { // CHECK-NEXT: ^1(%arg3 : f32, %arg4 : f32): // CHECK-NEXT: memref_stream.yield %arg3 : f32 // CHECK-NEXT: } -// CHECK-GENERIC-NEXT: "memref_stream.generic"(%A, %B, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "inits" = [unit], "indexing_maps" = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, 
affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%A, %B, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "init_indices" = [], "indexing_maps" = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^1(%arg3 : f32, %arg4 : f32): // CHECK-GENERIC-NEXT: "memref_stream.yield"(%arg3) : (f32) -> () // CHECK-GENERIC-NEXT: }) {"hello" = "world"} : (memref<2xf32>, memref<3xf32>, memref<3x2xf64>) -> () @@ -73,26 +72,25 @@ memref_stream.generic { affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%D : f64) outs(%C : memref<3x2xf64>) { ^bb0(%d : f64, %c : f64): memref_stream.yield %d : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%D : f64) outs(%C : memref<3x2xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%D : f64) outs(%C : memref<3x2xf64>) { // CHECK-NEXT: ^2(%d : f64, %c_1 : f64): // CHECK-NEXT: memref_stream.yield %d : f64 // CHECK-NEXT: } -// CHECK-GENERIC-NEXT: "memref_stream.generic"(%D, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "inits" = [unit], "indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// 
CHECK-GENERIC-NEXT: "memref_stream.generic"(%D, %C) <{"bounds" = [#builtin.int<3>, #builtin.int<2>], "init_indices" = [], "indexing_maps" = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^2(%d : f64, %c_1 : f64): // CHECK-GENERIC-NEXT: "memref_stream.yield"(%d) : (f64) -> () // CHECK-GENERIC-NEXT: }) : (f64, memref<3x2xf64>) -> () -%E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -// CHECK-NEXT: %E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -// CHECK-GENERIC-NEXT: %E, %F, %G = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) +%E, %F, %G, %H = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) +// CHECK-NEXT: %E, %F, %G, %H = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) +// CHECK-GENERIC-NEXT: %E, %F, %G, %H = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) memref_stream.generic { bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], @@ -101,28 +99,27 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [0.000000e+00 : f64] -} ins(%E, %F : memref<4x2xf64>, memref<2x3xf64>) outs(%G : memref<4x3xf64>) { + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%E, %F : memref<4x2xf64>, memref<2x3xf64>) outs(%G : memref<4x3xf64>) inits(%H : f64) { ^0(%e : f64, %f : f64, %acc_old : f64): %prod = arith.mulf %e, %f : f64 %acc_new = arith.addf %acc_old, %prod : f64 linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> 
(d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], inits = [0.000000e+00 : f64]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) inits(%H : f64) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: linalg.yield %{{.*}} : f64 // CHECK-NEXT: } -// CHECK-GENERIC-NEXT: "memref_stream.generic"(%{{.*}}, %{{.*}}, %{{.*}}) <{"bounds" = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], "inits" = [0.000000e+00 : f64], "indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%{{.*}}, %{{.*}}, %{{.*}}, %H) <{"bounds" = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], "init_indices" = [#builtin.int<0>], "indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ // CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 // CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) 
<{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 // CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f64) -> () -// CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -> () +// CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) -> () // CHECK-NEXT: } // CHECK-GENERIC-NEXT: }) : () -> () diff --git a/tests/filecheck/dialects/memref_stream/verify.mlir b/tests/filecheck/dialects/memref_stream/verify.mlir index 20486c8550..ec5399136f 100644 --- a/tests/filecheck/dialects/memref_stream/verify.mlir +++ b/tests/filecheck/dialects/memref_stream/verify.mlir @@ -9,8 +9,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)> ], - iterator_types = ["parallel", "reduction", "parallel"], - inits = [unit] + iterator_types = ["parallel", "reduction", "parallel"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -30,8 +29,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -39,7 +37,7 @@ memref_stream.generic { memref_stream.yield %acc_new : f64 } -// CHECK: Operation does not verify: The number of affine maps must match the number of operands +// CHECK: Operation does not verify: The number of affine maps must match the number of inputs and outputs // ----- @@ -52,8 +50,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, 
memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -74,8 +71,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d2)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -97,8 +93,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit, unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C, %D : memref<4x3xf64>, memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old0 : f64, %acc_old1 : f64): %prod = arith.mulf %a, %b : f64 diff --git a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir index 5bbd23be07..6062a79958 100644 --- a/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir +++ b/tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir @@ -5,6 +5,7 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( %Y: memref<1x1x3x3xf64>, %Z: memref<1x1x6x6xf64> ) -> () { + %zero_float = arith.constant 0.0 : f64 memref_stream.streaming_region { patterns = [ #memref_stream.stride_pattern (d0, d4, d2 + d5, d3 + d6)>, @@ -18,8 +19,6 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( %c3 = arith.constant 3 : index %c6 = arith.constant 6 : index - %zero_float = arith.constant 0.0 : f64 - scf.for %i0 = %c0 to %c1 step %c1 { scf.for %i1 = %c0 to %c1 step %c1 { scf.for %i2 = %c0 to %c6 step %c1 { @@ -53,6 +52,7 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( // CHECK-NEXT: mv t0, a0 // CHECK-NEXT: mv t1, a1 // CHECK-NEXT: mv t2, a2 +// CHECK-NEXT: fcvt.d.w ft3, zero // 
CHECK-NEXT: li t3, 2 // CHECK-NEXT: scfgwi t3, 64 // CHECK-NEXT: li t3, 2 @@ -89,7 +89,6 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( // CHECK-NEXT: scfgwi t1, 833 // CHECK-NEXT: scfgwi t2, 898 // CHECK-NEXT: csrrsi zero, 1984, 1 -// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t1, 36 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge @@ -406,6 +405,7 @@ func.func public @pooling_nchw_max_d1_s2_3x3( %X: memref<1x1x16x16xf64>, %Y: memref<1x1x7x7xf64> ) -> () { + %min_val = arith.constant -10000.0 : f64 memref_stream.streaming_region { patterns = [ #memref_stream.stride_pattern (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, @@ -419,7 +419,6 @@ func.func public @pooling_nchw_max_d1_s2_3x3( %c7 = arith.constant 7 : index %c512 = arith.constant 512 : index - %min_val = arith.constant -10000.0 : f64 scf.for %i0 = %c0 to %c1 step %c1 { scf.for %i1 = %c0 to %c1 step %c1 { scf.for %i2 = %c0 to %c7 step %c1 { @@ -450,6 +449,8 @@ func.func public @pooling_nchw_max_d1_s2_3x3( // CHECK-NEXT: pooling_nchw_max_d1_s2_3x3: // CHECK-NEXT: mv t1, a0 // CHECK-NEXT: mv t2, a1 +// CHECK-NEXT: li t0, -10000 +// CHECK-NEXT: fcvt.d.w ft3, t0 // CHECK-NEXT: li t0, 2 // CHECK-NEXT: scfgwi t0, 64 // CHECK-NEXT: li t0, 2 @@ -473,8 +474,6 @@ func.func public @pooling_nchw_max_d1_s2_3x3( // CHECK-NEXT: scfgwi t1, 864 // CHECK-NEXT: scfgwi t2, 897 // CHECK-NEXT: csrrsi zero, 1984, 1 -// CHECK-NEXT: li t1, -10000 -// CHECK-NEXT: fcvt.d.w ft3, t1 // CHECK-NEXT: li t1, 49 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge @@ -538,6 +537,7 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( %X: memref<1x1x16x16xf64>, %Y: memref<1x1x7x7xf64> ) -> () { + %zero_float = arith.constant 0.0 : f64 memref_stream.streaming_region { patterns = [ #memref_stream.stride_pattern (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, @@ -551,7 +551,6 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( %c7 = arith.constant 7 : index %c512 = arith.constant 512 : index - %zero_float = arith.constant 
0.0 : f64 scf.for %i0 = %c0 to %c1 step %c1 { scf.for %i1 = %c0 to %c1 step %c1 { scf.for %i2 = %c0 to %c7 step %c1 { @@ -582,6 +581,7 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( // CHECK-NEXT: pooling_nchw_sum_d1_s2_3x3: // CHECK-NEXT: mv t1, a0 // CHECK-NEXT: mv t2, a1 +// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t0, 2 // CHECK-NEXT: scfgwi t0, 64 // CHECK-NEXT: li t0, 2 @@ -605,7 +605,6 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( // CHECK-NEXT: scfgwi t1, 864 // CHECK-NEXT: scfgwi t2, 897 // CHECK-NEXT: csrrsi zero, 1984, 1 -// CHECK-NEXT: fcvt.d.w ft3, zero // CHECK-NEXT: li t1, 49 // CHECK-NEXT: mv t0, zero // CHECK-NEXT: # Constant folded riscv_cf.bge diff --git a/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir b/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir index 86bd539b1b..707fd13e6e 100644 --- a/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir +++ b/tests/filecheck/transforms/convert_linalg_to_memref_stream.mlir @@ -22,7 +22,7 @@ linalg.generic { %acc_new = arith.addf %acc_old, %prod : f64 linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [], indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = [], inits = [unit]} ins(%A, %B : memref, memref) outs(%C : memref) { +// CHECK-NEXT: memref_stream.generic {bounds = [], indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%A, %B : memref, memref) outs(%C : memref) { // CHECK-NEXT: ^0(%a : f64, %b : f64, %acc_old : f64): // CHECK-NEXT: %prod = arith.mulf %a, %b : f64 // CHECK-NEXT: %acc_new = arith.addf %acc_old, %prod : f64 @@ -44,7 +44,7 @@ linalg.generic { linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<2>, #builtin.int<3>, #builtin.int<4>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>], 
iterator_types = ["parallel", "parallel", "reduction"], inits = [unit]} ins(%D, %E : memref<2x3xf64>, memref<3x4xf64>) outs(%F : memref<2x4xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<2>, #builtin.int<3>, #builtin.int<4>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d2)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%D, %E : memref<2x3xf64>, memref<3x4xf64>) outs(%F : memref<2x4xf64>) { // CHECK-NEXT: ^1(%d : f64, %e : f64, %acc_old_1 : f64): // CHECK-NEXT: %prod_1 = arith.mulf %d, %e : f64 // CHECK-NEXT: %acc_new_1 = arith.addf %acc_old_1, %prod_1 : f64 @@ -65,7 +65,7 @@ linalg.generic { linalg.yield %acc_new : f64 } -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ((d0 + d1))>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"], inits = [unit]} ins(%G, %H : memref<4xf64>, memref<2xf64>) outs(%I : memref<3xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<3>, #builtin.int<2>], indexing_maps = [affine_map<(d0, d1) -> ((d0 + d1))>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%G, %H : memref<4xf64>, memref<2xf64>) outs(%I : memref<3xf64>) { // CHECK-NEXT: ^2(%g : f64, %h : f64, %acc_old_2 : f64): // CHECK-NEXT: %prod_2 = arith.mulf %g, %h : f64 // CHECK-NEXT: %acc_new_2 = arith.addf %acc_old_2, %prod_2 : f64 diff --git a/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir b/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir index 1878fd9f0e..a585eeac2d 100644 --- a/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir +++ b/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir @@ -10,8 +10,7 @@ memref_stream.generic { affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, 
d2) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) outs(%C : memref<4x3xf64>) { ^0(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -21,7 +20,7 @@ memref_stream.generic { // CHECK: builtin.module { // CHECK-NEXT: %{{.*}}, %{{.*}}, %{{.*}} = "test.op"() : () -> (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>) -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], inits = [unit]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}} : memref<4x3xf64>) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 diff --git a/tests/filecheck/transforms/memref_streamify.mlir b/tests/filecheck/transforms/memref_streamify.mlir index 287a0fb193..454d9f1da5 100644 --- a/tests/filecheck/transforms/memref_streamify.mlir +++ b/tests/filecheck/transforms/memref_streamify.mlir @@ -10,8 +10,7 @@ func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 memref_stream.generic { bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - 
iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { ^0(%in : f64, %in_0 : f64, %out : f64): %0 = arith.addf %in, %in_0 : f64 @@ -23,7 +22,7 @@ func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 // CHECK-NEXT: func.func public @dsum(%arg0 : memref<8x16xf64>, %arg1 : memref<8x16xf64>, %arg2 : memref<8x16xf64>) -> memref<8x16xf64> { // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>]} ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { // CHECK-NEXT: ^0(%0 : !stream.readable, %1 : !stream.readable, %2 : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { // CHECK-NEXT: ^1(%in : f64, %in_1 : f64, %out : f64): // CHECK-NEXT: %3 = arith.addf %in, %in_1 : f64 // CHECK-NEXT: memref_stream.yield %3 : f64 @@ -37,8 +36,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> memref_stream.generic { bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - inits = [unit] + 
iterator_types = ["parallel", "parallel"] } ins(%arg0 : memref<16x16xf64>) outs(%arg1 : memref<16x16xf64>) { ^1(%in_1 : f64, %out_1 : f64): %1 = arith.maximumf %in_1, %cst : f64 @@ -51,7 +49,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>, #memref_stream.stride_pattern (d0, d1)>]} ins(%arg0 : memref<16x16xf64>) outs(%arg1 : memref<16x16xf64>) { // CHECK-NEXT: ^0(%0 : !stream.readable, %1 : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%0 : !stream.readable) outs(%1 : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0 : !stream.readable) outs(%1 : !stream.writable) { // CHECK-NEXT: ^1(%in : f64, %out : f64): // CHECK-NEXT: %2 = arith.maximumf %in, %cst : f64 // CHECK-NEXT: memref_stream.yield %2 : f64 @@ -70,8 +68,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%X : f64) outs(%Y : memref<16x16xf64>) { ^bb0(%d : f64, %c : f64): memref_stream.yield %d : f64 @@ -83,7 +80,7 @@ func.func public @relu(%arg0 : memref<16x16xf64>, %arg1 : memref<16x16xf64>) -> // CHECK-NEXT: func.func @fill(%{{.*}} : f64, %{{.*}} : memref<16x16xf64>) { // CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern (d0, d1)>]} outs(%{{.*}} : memref<16x16xf64>) { // CHECK-NEXT: 
^{{.*}}(%{{.*}} : !stream.writable): -// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%{{.*}} : f64) outs(%{{.*}} : !stream.writable) { +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%{{.*}} : f64) outs(%{{.*}} : !stream.writable) { // CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64): // CHECK-NEXT: memref_stream.yield %{{.*}} : f64 // CHECK-NEXT: } diff --git a/tests/interpreters/test_memref_stream_interpreter.py b/tests/interpreters/test_memref_stream_interpreter.py index a74e96673c..3448e7b78f 100644 --- a/tests/interpreters/test_memref_stream_interpreter.py +++ b/tests/interpreters/test_memref_stream_interpreter.py @@ -4,11 +4,9 @@ AffineMapAttr, ArrayAttr, Float32Type, - FloatAttr, IntAttr, MemRefType, ModuleOp, - UnitAttr, i32, ) from xdsl.interpreter import Interpreter @@ -32,8 +30,8 @@ def test_memref_stream_generic(): TestSSAValue(MemRefType(i32, [3, 2])), ), (TestSSAValue(MemRefType(i32, [1, 6])),), + (), Region(Block(arg_types=(i32, i32))), - ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(2)), @@ -57,6 +55,7 @@ def test_memref_stream_generic(): ) ), ArrayAttr((IntAttr(2), IntAttr(3))), + ArrayAttr(()), ) with ImplicitBuilder(op.body) as (a, b): @@ -85,8 +84,8 @@ def test_memref_stream_generic_scalar(): TestSSAValue(i32), ), (TestSSAValue(MemRefType(i32, [1, 6])),), + (), Region(Block(arg_types=(i32, i32))), - ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(2)), @@ -110,6 +109,7 @@ def test_memref_stream_generic_scalar(): ) ), ArrayAttr((IntAttr(2), IntAttr(3))), + ArrayAttr(()), ) with ImplicitBuilder(op.body) as (a, b): @@ -138,8 +138,8 @@ def 
test_memref_stream_generic_reduction(): TestSSAValue(MemRefType(i32, [3])), ), (TestSSAValue(MemRefType(i32, [])),), + (), Region(Block(arg_types=(i32, i32, i32))), - ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.identity(1)), @@ -149,6 +149,7 @@ def test_memref_stream_generic_reduction(): ), ArrayAttr((memref_stream.IteratorTypeAttr.reduction(),)), ArrayAttr((IntAttr(3),)), + ArrayAttr(()), ) with ImplicitBuilder(op.body) as (lhs, rhs, acc): @@ -180,8 +181,8 @@ def test_memref_stream_generic_imperfect_nesting(): TestSSAValue(MemRefType(f32, [2, 3])), ), (TestSSAValue(MemRefType(f32, [3, 3])),), + (), Region(Block(arg_types=(f32, f32, f32))), - ArrayAttr((UnitAttr(),)), ArrayAttr( ( AffineMapAttr(AffineMap.from_callable(lambda n, m, k: (n, k))), @@ -197,6 +198,7 @@ def test_memref_stream_generic_imperfect_nesting(): ) ), ArrayAttr((IntAttr(3), IntAttr(3), IntAttr(2))), + ArrayAttr(()), ) with ImplicitBuilder(op.body) as (lhs, rhs, acc): @@ -230,8 +232,8 @@ def test_memref_stream_generic_reduction_with_initial_value(): TestSSAValue(MemRefType(f32, [2, 3])), ), (TestSSAValue(MemRefType(f32, [3, 3])),), + (TestSSAValue(f32),), Region(Block(arg_types=(f32, f32, f32))), - ArrayAttr((FloatAttr(0.5, f32),)), ArrayAttr( ( AffineMapAttr(AffineMap.from_callable(lambda n, m, k: (n, k))), @@ -247,6 +249,7 @@ def test_memref_stream_generic_reduction_with_initial_value(): ) ), ArrayAttr((IntAttr(3), IntAttr(3), IntAttr(2))), + ArrayAttr((IntAttr(0),)), ) with ImplicitBuilder(op.body) as (lhs, rhs, acc): @@ -260,7 +263,7 @@ def test_memref_stream_generic_reduction_with_initial_value(): b = ShapedArray(TypedPtr.new_float32([4.0, 3.0, 5.0, 1.0, 2.0, 8.0]), [2, 3]) c = ShapedArray(TypedPtr.new_float32([0.0] * 9), [3, 3]) - interpreter.run_op(op, (a, b, c)) + interpreter.run_op(op, (a, b, c, 0.5)) assert c == ShapedArray( TypedPtr.new_float32([6.5, 7.5, 21.5, 16.5, 17.5, 47.5, 26.5, 27.5, 73.5]), [3, 3], diff --git a/xdsl/dialects/memref_stream.py 
b/xdsl/dialects/memref_stream.py index ba7a02eb1d..c81891045f 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -17,12 +17,9 @@ from xdsl.dialects import memref, stream from xdsl.dialects.builtin import ( AffineMapAttr, - AnyFloatAttr, - AnyIntegerAttr, ArrayAttr, IntAttr, StringAttr, - UnitAttr, ) from xdsl.dialects.utils import AbstractYieldOperation from xdsl.ir import ( @@ -47,7 +44,6 @@ from xdsl.printer import Printer from xdsl.traits import IsTerminator, NoTerminator from xdsl.utils.exceptions import VerifyException -from xdsl.utils.hints import isa from xdsl.utils.str_enum import StrEnum @@ -319,11 +315,10 @@ class GenericOp(IRDLOperation): pattern defines the order in which the elements of the input buffers will be written to. """ - inits = prop_def(ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr]) + inits = var_operand_def() """ - Initial values for outputs. If `NoneAttr`, then the value is read from the output - buffer. Otherwise, the value is created at runtime with an `arith.constant` operation - during lowering. The inits may be set only for the imperfectly nested form. + Initial values for outputs. The outputs are at corresponding `init_indices`. The inits + may be set only for the imperfectly nested form. """ indexing_maps = prop_def(ArrayAttr[AffineMapAttr]) """ @@ -337,6 +332,10 @@ class GenericOp(IRDLOperation): """ iterator_types = prop_def(ArrayAttr[IteratorTypeAttr]) + init_indices = prop_def(ArrayAttr[IntAttr]) + """ + Indices into the `outputs` that correspond to the initial values in `inits`. 
+ """ body: Region = region_def("single_block") @@ -346,11 +345,12 @@ def __init__( self, inputs: Sequence[SSAValue], outputs: Sequence[SSAValue], + inits: Sequence[SSAValue], body: Region, - inits: ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr], indexing_maps: ArrayAttr[AffineMapAttr], iterator_types: ArrayAttr[Attribute], bounds: ArrayAttr[IntAttr], + init_indices: ArrayAttr[IntAttr], ) -> None: for m in indexing_maps: if m.data.num_symbols: @@ -358,12 +358,12 @@ def __init__( f"Symbols currently not implemented in {self.name} indexing maps" ) super().__init__( - operands=[inputs, outputs], + operands=[inputs, outputs, inits], properties={ "bounds": bounds, - "inits": inits, - "indexing_maps": ArrayAttr(indexing_maps), - "iterator_types": ArrayAttr(iterator_types), + "init_indices": init_indices, + "indexing_maps": indexing_maps, + "iterator_types": iterator_types, }, regions=[body], ) @@ -386,6 +386,14 @@ def get_static_loop_ranges( tuple(bound.data for bound in self.bounds.data[min_dims:]), ) + def _print_init(self, printer: Printer, init: SSAValue | None): + if init is None: + printer.print_string("None") + else: + printer.print_ssa_value(init) + printer.print_string(" : ") + printer.print_attribute(init.type) + def print(self, printer: Printer): printer.print_string(" {bounds = ") printer.print_attribute(self.bounds) @@ -397,12 +405,6 @@ def print(self, printer: Printer): lambda iterator_type: printer.print_string_literal(iterator_type.data), ) printer.print_string("]") - printer.print_string(", inits = [") - printer.print_list( - self.inits, - lambda val: printer.print_attribute(val), - ) - printer.print_string("]") printer.print_string("}") if self.inputs: @@ -419,6 +421,18 @@ def print(self, printer: Printer): printer.print_list((o.type for o in self.outputs), printer.print_attribute) printer.print_string(")") + if self.inits: + printer.print_string(" inits(") + init_indices = set(attr.data for attr in self.init_indices) + inits = [ + val if i in 
init_indices else None for i, val in enumerate(self.inits) + ] + printer.print_list( + inits, + lambda val: self._print_init(printer, val), + ) + printer.print_string(")") + extra_attrs = self.attributes.copy() if "indexing_maps" in extra_attrs: del extra_attrs["indexing_maps"] @@ -438,6 +452,35 @@ def print(self, printer: Printer): printer.print_string(" ") printer.print_region(self.body) + @classmethod + def _parse_init(cls, parser: Parser) -> SSAValue | None: + if parser.parse_optional_characters("None"): + return None + unresolved = parser.parse_unresolved_operand() + parser.parse_punctuation(":") + type = parser.parse_type() + return parser.resolve_operand(unresolved, type) + + @classmethod + def _parse_inits( + cls, parser: Parser + ) -> tuple[tuple[SSAValue, ...], tuple[int, ...]]: + if not parser.parse_optional_characters("inits"): + return ((), ()) + + parser.parse_punctuation("(") + optional_inits = parser.parse_comma_separated_list( + Parser.Delimiter.NONE, lambda: cls._parse_init(parser) + ) + parser.parse_punctuation(")") + enumerated_inits = tuple( + (i, val) for i, val in enumerate(optional_inits) if val is not None + ) + inits = tuple(init for _, init in enumerated_inits) + init_indices = tuple(i for i, _ in enumerated_inits) + + return (tuple(inits), init_indices) + @classmethod def parse(cls, parser: Parser) -> Self: attrs_start_pos = parser.pos @@ -499,18 +542,6 @@ def parse(cls, parser: Parser) -> Self: attrs_end_pos, ) - if "inits" in attrs: - inits = attrs["inits"] - if not isa(inits, ArrayAttr[AnyFloatAttr | AnyIntegerAttr | UnitAttr]): - parser.raise_error("Expected inits for memref_stream.generic") - del attrs["inits"] - else: - parser.raise_error( - "Expected inits for memref_stream.generic", - attrs_start_pos, - attrs_end_pos, - ) - if "doc" in attrs: doc = attrs["doc"] assert isinstance(doc, StringAttr) @@ -553,8 +584,11 @@ def parse(cls, parser: Parser) -> Self: parser.parse_punctuation(")") outs = 
parser.resolve_operands(unresolved_outs, outs_types, pos) else: + outs_types = () outs = () + inits, init_indices = cls._parse_inits(parser) + if parser.parse_optional_keyword("attrs"): parser.parse_punctuation("=") extra_attrs = parser.expect( @@ -568,11 +602,12 @@ def parse(cls, parser: Parser) -> Self: generic = cls( ins, outs, - body, inits, + body, indexing_maps, ArrayAttr(iterator_types), bounds, + ArrayAttr(IntAttr(index) for index in init_indices), ) generic.attributes |= attrs generic.attributes |= extra_attrs @@ -580,11 +615,9 @@ def parse(cls, parser: Parser) -> Self: return generic def verify_(self) -> None: - # Verify that the number of initial values for outputs is the same as the number - # of outputs - if len(self.inits) != len(self.outputs): + if len(self.inits) != len(self.init_indices): raise VerifyException( - f"Mismatching number of outputs and initial values: {len(self.outputs)} != {self.inits}" + f"Mismatching number of inits and init indices: {len(self.inits)} != {self.init_indices}" ) # Parallel iterator types must preceed reduction iterators @@ -595,9 +628,9 @@ def verify_(self) -> None: f"Unexpected order of iterator types: {[it.data.value for it in iterator_types]}" ) - if len(self.operands) != len(self.indexing_maps): + if len(self.inputs) + len(self.outputs) != len(self.indexing_maps): raise VerifyException( - "The number of affine maps must match the number of operands" + "The number of affine maps must match the number of inputs and outputs" ) # Whether or not the operation represents an imperfect loop nest, verify that the @@ -631,16 +664,25 @@ def verify_(self) -> None: f"{len(iterator_types)} or {num_parallel}" ) - # The non-None values of the inits must correspond to inputs where the domain - # of the affine map has the same number of dimensions as the number of parallel - # iterators - for i, (m, init) in enumerate(zip(output_maps, self.inits, strict=True)): - if init != UnitAttr(): - if m.data.num_dims != num_parallel: - raise 
VerifyException( - "Incompatible affine map and initial value for output at index " - f"{i}" - ) + if len(self.init_indices) != len(self.inits): + raise VerifyException( + "The number of inits and init_indices must be the same" + ) + + # The values of the inits must correspond to outputs where the domain of the + # affine map has the same number of dimensions as the number of parallel + # iterators. + num_outputs = len(self.outputs) + output_maps = self.indexing_maps.data[-num_outputs:] + for index in self.init_indices: + if not (0 <= index.data <= num_outputs): + raise VerifyException(f"Init index out of bounds: {index.data}") + m = output_maps[index.data] + if m.data.num_dims != num_parallel: + raise VerifyException( + "Incompatible affine map and initial value for output at index " + f"{index}" + ) @irdl_op_definition diff --git a/xdsl/interpreters/memref_stream.py b/xdsl/interpreters/memref_stream.py index 7ba522856d..ce1174cedd 100644 --- a/xdsl/interpreters/memref_stream.py +++ b/xdsl/interpreters/memref_stream.py @@ -2,7 +2,6 @@ from typing import Any, cast from xdsl.dialects import memref_stream -from xdsl.dialects.builtin import UnitAttr from xdsl.interpreter import ( Interpreter, InterpreterFunctions, @@ -26,15 +25,21 @@ def run_generic( ) -> PythonValues: inputs_count = len(op.inputs) + outputs_count = len(op.outputs) - outputs: tuple[ShapedArray[float], ...] = args[inputs_count:] + outputs: tuple[ShapedArray[int | float], ...] = args[ + inputs_count : inputs_count + outputs_count + ] + init_values: tuple[int | float, ...] 
= args[inputs_count + outputs_count :] indexing_maps = tuple(attr.data for attr in op.indexing_maps) output_indexing_maps = indexing_maps[inputs_count:] outer_ubs, inner_ubs = op.get_static_loop_ranges() - inits = op.inits.data + inits: list[None | int | float] = [None] * len(op.outputs) + for index, init in zip(op.init_indices, init_values, strict=True): + inits[index.data] = init if inner_ubs: inputs: tuple[ShapedArray[float] | float, ...] = args[:inputs_count] @@ -42,14 +47,15 @@ def run_generic( for outer_indices in product(*(range(outer_ub) for outer_ub in outer_ubs)): output_loop_args = tuple( ( - (cast(ShapedArray[int | float], o)).load( - indexing_map.eval(outer_indices, ()) - ) - if isinstance(init, UnitAttr) - else init.value.data + o.load(indexing_map.eval(outer_indices, ())) + if init is None + else init ) for o, indexing_map, init in zip( - outputs, output_indexing_maps, inits, strict=True + outputs, + output_indexing_maps, + inits, + strict=True, ) ) for inner_indices in product( @@ -72,7 +78,6 @@ def run_generic( op.body, input_loop_args + output_loop_args, "for_loop" ) output_loop_args = loop_results - print(output_loop_args, output_indexing_maps, outputs) for res, indexing_map, output in zip( output_loop_args, output_indexing_maps, outputs, strict=True ): diff --git a/xdsl/transforms/convert_linalg_to_memref_stream.py b/xdsl/transforms/convert_linalg_to_memref_stream.py index 8a59ac5c37..e1e9ade527 100644 --- a/xdsl/transforms/convert_linalg_to_memref_stream.py +++ b/xdsl/transforms/convert_linalg_to_memref_stream.py @@ -1,6 +1,6 @@ from xdsl.context import MLContext from xdsl.dialects import linalg, memref_stream -from xdsl.dialects.builtin import ArrayAttr, IntAttr, ModuleOp, UnitAttr +from xdsl.dialects.builtin import ArrayAttr, IntAttr, ModuleOp from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -42,11 +42,12 @@ def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter) -> No 
memref_stream.GenericOp( op.inputs, op.outputs, + (), rewriter.move_region_contents_to_new_regions(op.body), - ArrayAttr(tuple(UnitAttr() for _ in range(len(op.outputs)))), op.indexing_maps, iterator_types, bounds, + ArrayAttr(()), ) ) diff --git a/xdsl/transforms/memref_streamify.py b/xdsl/transforms/memref_streamify.py index 0997e5f6fa..de9601c1fa 100644 --- a/xdsl/transforms/memref_streamify.py +++ b/xdsl/transforms/memref_streamify.py @@ -3,7 +3,7 @@ from xdsl.context import MLContext from xdsl.dialects import memref, memref_stream, stream -from xdsl.dialects.builtin import ArrayAttr, ModuleOp, UnitAttr +from xdsl.dialects.builtin import ArrayAttr, ModuleOp from xdsl.ir import Attribute, Block, Region from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -27,7 +27,7 @@ def match_and_rewrite( # Already streamified return - if any(not isinstance(init, UnitAttr) for init in op.inits): + if op.inits: raise NotImplementedError( "Cannot streamify operation that has inits that are not UnitAttr" ) @@ -84,7 +84,7 @@ def match_and_rewrite( ) ) new_body = streaming_region_op.body.block - new_operands = list(op.operands) + new_operands = list(op.operands[: len(op.inputs) + len(op.outputs)]) for stream_index, (index, _) in enumerate(streamed_operand_indices): new_operands[index] = new_body.args[stream_index] @@ -92,11 +92,12 @@ def match_and_rewrite( memref_stream.GenericOp( new_operands[:input_count], new_operands[input_count:], - rewriter.move_region_contents_to_new_regions(op.body), op.inits, + rewriter.move_region_contents_to_new_regions(op.body), op.indexing_maps, op.iterator_types, op.bounds, + op.init_indices, ), InsertPoint.at_end(new_body), ) From aed0cd46d80cd7a583b4a83027f04c14177dd0ee Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 24 Jun 2024 10:25:44 +0100 Subject: [PATCH 09/14] straggling inits --- .../convert_memref_stream_to_loops.mlir | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git 
a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir index 84c993bc19..9844dd49e5 100644 --- a/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir +++ b/tests/filecheck/transforms/convert_memref_stream_to_loops.mlir @@ -11,7 +11,7 @@ ] } ins(%arg0, %arg1 : memref<8x16xf64>, memref<8x16xf64>) outs(%arg2 : memref<8x16xf64>) { ^0(%0 : !stream.readable, %1 : !stream.readable, %2 : !stream.writable): - memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { + memref_stream.generic {bounds = [#builtin.int<8>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%0, %1 : !stream.readable, !stream.readable) outs(%2 : !stream.writable) { ^1(%in : f64, %in_0 : f64, %out : f64): %3 = arith.addf %in, %in_0 : f64 memref_stream.yield %3 : f64 @@ -48,7 +48,7 @@ ] } ins(%arg0_1 : memref<16x16xf64>) outs(%arg1_1 : memref<16x16xf64>) { ^2(%4 : !stream.readable, %5 : !stream.writable): - memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"], inits = [unit]} ins(%4 : !stream.readable) outs(%5 : !stream.writable) { + memref_stream.generic {bounds = [#builtin.int<16>, #builtin.int<16>], indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%4 : !stream.readable) outs(%5 : !stream.writable) { ^3(%in_1 : f64, %out_1 : f64): %6 = arith.maximumf %in_1, %cst : f64 memref_stream.yield %6 : 
f64 @@ -90,8 +90,7 @@ func.func public @fill(%arg0 : memref<16x16xf64>) -> memref<16x16xf64> { affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel"], - inits = [unit] + iterator_types = ["parallel", "parallel"] } ins(%zero : f64) outs(%7 : !stream.writable) { ^4(%in: f64, %out: f64): memref_stream.yield %in : f64 @@ -132,8 +131,7 @@ func.func @main(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64> affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)> ], - iterator_types = ["parallel", "parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "parallel", "reduction"] } ins(%0, %1 : !stream.readable, !stream.readable) outs(%C : memref<4x3xf64>) { ^1(%a : f64, %b : f64, %acc_old : f64): %prod = arith.mulf %a, %b : f64 @@ -181,8 +179,7 @@ func.func @elide_affine(%A : memref<6xf64>, %B : memref) -> memref { affine_map<(d0, d1) -> (d0 * 3 + d1)>, affine_map<(d0, d1) -> ()> ], - iterator_types = ["parallel", "reduction"], - inits = [unit] + iterator_types = ["parallel", "reduction"] } ins(%0 : !stream.readable) outs(%B : memref) { ^1(%a : f64, %acc_old : f64): %acc_new = arith.addf %acc_old, %a : f64 @@ -223,8 +220,7 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref) -> memref< affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<() -> ()> ], - iterator_types = ["reduction", "reduction", "reduction"], - inits = [unit] + iterator_types = ["reduction", "reduction", "reduction"] } ins(%0 : !stream.readable) outs(%B : memref) { ^1(%a : f64, %acc_old : f64): %acc_new = arith.addf %acc_old, %a : f64 From 968c53d26d5fcf4ff6f1495fe494385e799307ea Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 24 Jun 2024 17:06:28 +0100 Subject: [PATCH 10/14] fix printing --- .../filecheck/dialects/memref_stream/ops.mlir | 38 +++++++++++++++++++ xdsl/dialects/memref_stream.py | 9 ++--- 2 files changed, 41 insertions(+), 6 deletions(-) diff --git 
a/tests/filecheck/dialects/memref_stream/ops.mlir b/tests/filecheck/dialects/memref_stream/ops.mlir index b1fef4c430..dc3736508a 100644 --- a/tests/filecheck/dialects/memref_stream/ops.mlir +++ b/tests/filecheck/dialects/memref_stream/ops.mlir @@ -121,5 +121,43 @@ memref_stream.generic { // CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f64) -> () // CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) -> () + +%I = "test.op"() : () -> (memref<4x3xf64>) +// CHECK-NEXT: %I = "test.op"() : () -> (memref<4x3xf64>) +// CHECK-GENERIC-NEXT: %I = "test.op"() : () -> (memref<4x3xf64>) + +memref_stream.generic { + bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2) -> (d0, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)> + ], + iterator_types = ["parallel", "parallel", "reduction"] +} ins(%E, %F : memref<4x2xf64>, memref<2x3xf64>) outs(%G, %I : memref<4x3xf64>, memref<4x3xf64>) inits(%H : f64, None) { +^0(%e : f64, %f : f64, %acc_old_0 : f64, %acc_old_1 : f64): + %prod = arith.mulf %e, %f : f64 + %acc_new_0 = arith.addf %acc_old_0, %prod : f64 + %acc_new_1 = arith.addf %acc_old_1, %prod : f64 + linalg.yield %acc_new_0, %acc_new_1 : f64, f64 +} + +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) outs(%{{.*}}, %{{.*}} : memref<4x3xf64>, memref<4x3xf64>) inits(%{{.*}} : f64, None) { +// CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): +// CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: %{{.*}} 
= arith.addf %{{.*}}, %{{.*}} : f64 +// CHECK-NEXT: linalg.yield %{{.*}}, %{{.*}} : f64, f64 +// CHECK-NEXT: } + +// CHECK-NEXT: "memref_stream.generic"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{"bounds" = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], "init_indices" = [#builtin.int<0>], "indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): +// CHECK-NEXT: %{{.*}} = "arith.mulf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 +// CHECK-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 +// CHECK-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 +// CHECK-NEXT: "linalg.yield"(%{{.*}}, %{{.*}}) : (f64, f64) -> () +// CHECK-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, memref<4x3xf64>, f64) -> () + // CHECK-NEXT: } // CHECK-GENERIC-NEXT: }) : () -> () diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index c81891045f..0614d1ef1f 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -423,10 +423,9 @@ def print(self, printer: Printer): if self.inits: printer.print_string(" inits(") - init_indices = set(attr.data for attr in self.init_indices) - inits = [ - val if i in init_indices else None for i, val in enumerate(self.inits) - ] + inits: list[SSAValue | None] = [None] * len(self.outputs) + for i, val in zip(self.init_indices, self.inits): + inits[i.data] = val printer.print_list( inits, lambda val: self._print_init(printer, val), @@ -442,8 +441,6 @@ def print(self, printer: Printer): del extra_attrs["doc"] if "library_call" in 
extra_attrs: del extra_attrs["library_call"] - if "inits" in extra_attrs: - del extra_attrs["inits"] if extra_attrs: printer.print(" attrs = ") From f114076eaf7d850a1bf10d36d6631d47b4093a5e Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 24 Jun 2024 17:09:56 +0100 Subject: [PATCH 11/14] fix fix --- .../filecheck/dialects/memref_stream/ops.mlir | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/tests/filecheck/dialects/memref_stream/ops.mlir b/tests/filecheck/dialects/memref_stream/ops.mlir index dc3736508a..10e5db4c58 100644 --- a/tests/filecheck/dialects/memref_stream/ops.mlir +++ b/tests/filecheck/dialects/memref_stream/ops.mlir @@ -121,10 +121,9 @@ memref_stream.generic { // CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}) : (f64) -> () // CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, f64) -> () - -%I = "test.op"() : () -> (memref<4x3xf64>) -// CHECK-NEXT: %I = "test.op"() : () -> (memref<4x3xf64>) -// CHECK-GENERIC-NEXT: %I = "test.op"() : () -> (memref<4x3xf64>) +%I = "test.op"() : () -> memref<4x3xf64> +// CHECK-NEXT: %I = "test.op"() : () -> memref<4x3xf64> +// CHECK-GENERIC-NEXT: %I = "test.op"() : () -> memref<4x3xf64> memref_stream.generic { bounds = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], @@ -151,13 +150,20 @@ memref_stream.generic { // CHECK-NEXT: linalg.yield %{{.*}}, %{{.*}} : f64, f64 // CHECK-NEXT: } -// CHECK-NEXT: "memref_stream.generic"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{"bounds" = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], "init_indices" = [#builtin.int<0>], "indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ -// CHECK-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} 
: f64): -// CHECK-NEXT: %{{.*}} = "arith.mulf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 -// CHECK-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 -// CHECK-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 -// CHECK-NEXT: "linalg.yield"(%{{.*}}, %{{.*}}) : (f64, f64) -> () -// CHECK-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, memref<4x3xf64>, f64) -> () +// CHECK-GENERIC-NEXT: "memref_stream.generic"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) <{"bounds" = [#builtin.int<4>, #builtin.int<2>, #builtin.int<3>], "init_indices" = [#builtin.int<0>], "indexing_maps" = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], "iterator_types" = [#memref_stream.iterator_type, #memref_stream.iterator_type, #memref_stream.iterator_type], "operandSegmentSizes" = array}> ({ +// CHECK-GENERIC-NEXT: ^{{.*}}(%{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64, %{{.*}} : f64): +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.mulf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 +// CHECK-GENERIC-NEXT: %{{.*}} = "arith.addf"(%{{.*}}, %{{.*}}) <{"fastmath" = #arith.fastmath}> : (f64, f64) -> f64 +// CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}, %{{.*}}) : (f64, f64) -> () +// CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, memref<4x3xf64>, f64) -> () + + +memref_stream.fill %C with %D : memref<3x2xf64> + +// CHECK-NEXT: memref_stream.fill %C with %D : memref<3x2xf64> +// CHECK-GENERIC-NEXT: "memref_stream.fill"(%C, %D) : (memref<3x2xf64>, f64) -> () + // CHECK-NEXT: } // CHECK-GENERIC-NEXT: }) : () -> () From 9ecbc328b4b46f56e7f348effe59b6661cb135f2 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine 
Date: Mon, 24 Jun 2024 17:10:24 +0100 Subject: [PATCH 12/14] fix fix fix --- tests/filecheck/dialects/memref_stream/ops.mlir | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/filecheck/dialects/memref_stream/ops.mlir b/tests/filecheck/dialects/memref_stream/ops.mlir index 10e5db4c58..b3bb6cfa3b 100644 --- a/tests/filecheck/dialects/memref_stream/ops.mlir +++ b/tests/filecheck/dialects/memref_stream/ops.mlir @@ -158,12 +158,5 @@ memref_stream.generic { // CHECK-GENERIC-NEXT: "linalg.yield"(%{{.*}}, %{{.*}}) : (f64, f64) -> () // CHECK-GENERIC-NEXT: }) : (memref<4x2xf64>, memref<2x3xf64>, memref<4x3xf64>, memref<4x3xf64>, f64) -> () - -memref_stream.fill %C with %D : memref<3x2xf64> - -// CHECK-NEXT: memref_stream.fill %C with %D : memref<3x2xf64> -// CHECK-GENERIC-NEXT: "memref_stream.fill"(%C, %D) : (memref<3x2xf64>, f64) -> () - - // CHECK-NEXT: } // CHECK-GENERIC-NEXT: }) : () -> () From 9bd3f61cb0f9159ca9b8909651a8c099156a7355 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 25 Jun 2024 10:35:06 +0100 Subject: [PATCH 13/14] transformations: Fix bug in memref-stream-unnest-out-parameters --- .../memref_stream_unnest_out_parameters.mlir | 26 +++++++++++++++++++ .../memref_stream_unnest_out_parameters.py | 10 +++++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir b/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir index a585eeac2d..66afadb6ad 100644 --- a/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir +++ b/tests/filecheck/transforms/memref_stream_unnest_out_parameters.mlir @@ -26,4 +26,30 @@ memref_stream.generic { // CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64 // CHECK-NEXT: memref_stream.yield %{{.*}} : f64 // CHECK-NEXT: } + +%X, %Y, %Z = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>, memref<1x1x6x6xf64>) + +memref_stream.generic { + bounds = [#builtin.int<1>, #builtin.int<1>, 
#builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], + indexing_maps = [ + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, + affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)> + ], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] +} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { +^0(%x : f64, %y : f64, %acc : f64): + %prod = arith.mulf %x, %y fastmath : f64 + %res = arith.addf %prod, %acc fastmath : f64 + memref_stream.yield %res : f64 +} + +// CHECK-NEXT: %X, %Y, %Z = "test.op"() : () -> (memref<1x1x8x8xf64>, memref<1x1x3x3xf64>, memref<1x1x6x6xf64>) +// CHECK-NEXT: memref_stream.generic {bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, (d2 + d5), (d3 + d6))>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { +// CHECK-NEXT: ^1(%x : f64, %y : f64, %acc : f64): +// CHECK-NEXT: %prod_1 = arith.mulf %x, %y fastmath : f64 +// CHECK-NEXT: %res = arith.addf %prod_1, %acc fastmath : f64 +// CHECK-NEXT: memref_stream.yield %res : f64 +// CHECK-NEXT: } + // CHECK-NEXT: } diff --git a/xdsl/transforms/memref_stream_unnest_out_parameters.py b/xdsl/transforms/memref_stream_unnest_out_parameters.py index 6acf5901c5..d33f700847 100644 --- a/xdsl/transforms/memref_stream_unnest_out_parameters.py +++ b/xdsl/transforms/memref_stream_unnest_out_parameters.py @@ -22,10 +22,16 @@ class UnnestOutParametersPattern(RewritePattern): def match_and_rewrite( self, 
op: memref_stream.GenericOp, rewriter: PatternRewriter ) -> None: + if op.is_imperfectly_nested: + # Already unnested + return + num_outputs = len(op.outputs) if not num_outputs: return + num_inputs = len(op.inputs) + num_parallel = sum( i == memref_stream.IteratorTypeAttr.parallel() for i in op.iterator_types ) @@ -37,10 +43,10 @@ def match_and_rewrite( parallel_dims = (True,) * num_parallel + (False,) * num_reduction - maps = op.indexing_maps.data[num_parallel:] + maps = op.indexing_maps.data[num_inputs:] new_maps = ArrayAttr( ( - *op.indexing_maps.data[:num_parallel], + *op.indexing_maps.data[:num_inputs], *(AffineMapAttr(m.data.compress_dims(parallel_dims)) for m in maps), ) ) From 4d15a54aa46ff629c30394e2cbdd3a611d8b05d1 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Tue, 25 Jun 2024 14:15:09 +0100 Subject: [PATCH 14/14] add helper --- xdsl/dialects/memref_stream.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/xdsl/dialects/memref_stream.py b/xdsl/dialects/memref_stream.py index 0614d1ef1f..88b56dc261 100644 --- a/xdsl/dialects/memref_stream.py +++ b/xdsl/dialects/memref_stream.py @@ -386,6 +386,10 @@ def get_static_loop_ranges( tuple(bound.data for bound in self.bounds.data[min_dims:]), ) + @property + def is_imperfectly_nested(self) -> bool: + return bool(self.get_static_loop_ranges()[1]) + def _print_init(self, printer: Printer, init: SSAValue | None): if init is None: printer.print_string("None")