From c330d8604d68c3315b3ed1b550e408420c8c4d05 Mon Sep 17 00:00:00 2001 From: n-io Date: Fri, 27 Sep 2024 18:11:23 +0200 Subject: [PATCH 1/4] transformations: New csl-stencil-materialize-stores pass --- .../csl-stencil-materialize-stores.mlir | 134 +++++++++++++++ xdsl/tools/command_line_tool.py | 6 + .../csl_stencil_materialize_stores.py | 152 ++++++++++++++++++ 3 files changed, 292 insertions(+) create mode 100644 tests/filecheck/transforms/csl-stencil-materialize-stores.mlir create mode 100644 xdsl/transforms/csl_stencil_materialize_stores.py diff --git a/tests/filecheck/transforms/csl-stencil-materialize-stores.mlir b/tests/filecheck/transforms/csl-stencil-materialize-stores.mlir new file mode 100644 index 0000000000..2a7a7ae88e --- /dev/null +++ b/tests/filecheck/transforms/csl-stencil-materialize-stores.mlir @@ -0,0 +1,134 @@ +// RUN: xdsl-opt -p csl-stencil-materialize-stores %s | filecheck %s + +builtin.module { + "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "gauss_seidel", "width" = 1024 : i16}> ({ + ^0(%arg0 : i16, %arg1 : i16, %arg2 : i16, %arg3 : i16, %arg4 : i16, %arg5 : i16, %arg6 : i16, %arg7 : i16, %arg8 : i16): + %0 = arith.constant 0 : i16 + %1 = "csl.get_color"(%0) : (i16) -> !csl.color + %2 = "csl_wrapper.import"(%arg2, %arg3, %1) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module + %3 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module + %4 = "csl.member_call"(%3, %arg0, %arg1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct + %5 = "csl.member_call"(%2, %arg0) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct + %6 = arith.constant 1 : i16 + %7 = arith.subi %arg5, %6 : i16 + %8 = arith.subi %arg2, %arg0 : i16 + %9 = arith.subi %arg3, %arg1 : i16 + %10 = arith.cmpi slt, %arg0, %7 : i16 + %11 = arith.cmpi slt, %arg1, %7 : i16 + %12 = arith.cmpi slt, %8, %arg5 : i16 + %13 = arith.cmpi slt, %9, %arg5 : i16 + %14 = arith.ori %10, %11 : i1 + %15 = arith.ori %14, %12 : i1 + %16 = arith.ori %15, %13 : i1 + "csl_wrapper.yield"(%5, %4, %16) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () + }, { + ^1(%arg0_1 : i16, %arg1_1 : i16, %arg2_1 : i16, %arg3_1 : i16, %arg4_1 : i16, %arg5_1 : i16, %arg6_1 : i16, %arg7_1 : !csl.comptime_struct, %arg8_1 : !csl.comptime_struct, %arg9 : i1): + %17 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module + %18 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module + %19 = memref.alloc() : memref<512xf32> + %20 = memref.alloc() : memref<512xf32> + %21 = "csl.addressof"(%19) : (memref<512xf32>) -> !csl.ptr, #csl> + %22 = "csl.addressof"(%20) : (memref<512xf32>) -> !csl.ptr, #csl> + "csl.export"(%21) <{"type" = !csl.ptr, #csl>, "var_name" = "a"}> : (!csl.ptr, #csl>) -> () + "csl.export"(%22) <{"type" = !csl.ptr, #csl>, "var_name" = "b"}> : (!csl.ptr, #csl>) -> () + "csl.export"() <{"type" = () -> (), "var_name" = @gauss_seidel}> : () -> () + csl.func @gauss_seidel() { + %23 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> + csl_stencil.apply(%19 : memref<512xf32>, %23 : memref<510xf32>) outs (%20 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + ^2(%arg10 : memref<4x510xf32>, %arg11 : index, %arg12 : memref<510xf32>): + %24 = csl_stencil.access %arg10[1, 0] : memref<4x510xf32> + %25 = csl_stencil.access %arg10[-1, 0] : memref<4x510xf32> + %26 = csl_stencil.access %arg10[0, 1] : memref<4x510xf32> + %27 = csl_stencil.access %arg10[0, -1] : memref<4x510xf32> + %28 = memref.subview %arg12[%arg11] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> + linalg.add ins(%27, %26 : memref<510xf32>, memref<510xf32>) outs(%28 : memref<510xf32, strided<[1], offset: ?>>) + linalg.add ins(%28, %25 : memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) outs(%28 : memref<510xf32, strided<[1], offset: ?>>) + linalg.add ins(%28, %24 : memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) outs(%28 : memref<510xf32, strided<[1], offset: ?>>) + %29 = memref.subview %arg12[%arg11] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> + "memref.copy"(%28, %29) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () + csl_stencil.yield %arg12 : memref<510xf32> + }, { + ^3(%arg10_1 : memref<512xf32>, %arg11_1 : memref<510xf32>): + %30 = arith.constant dense<1.666600e-01> : memref<510xf32> + %31 = memref.subview %arg10_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> + %32 = memref.subview %arg10_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> + linalg.add ins(%arg11_1, %32 : memref<510xf32>, memref<510xf32, strided<[1]>>) outs(%arg11_1 : memref<510xf32>) + linalg.add ins(%arg11_1, %31 : memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) outs(%arg11_1 : memref<510xf32>) + linalg.mul ins(%arg11_1, %30 : memref<510xf32>, memref<510xf32>) outs(%arg11_1 : memref<510xf32>) + csl_stencil.yield %arg11_1 : memref<510xf32> + }) to <[0, 0], [1, 1]> + "csl.member_call"(%17) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () + csl.return + } + "csl_wrapper.yield"() <{"fields" = []}> : () -> () + }) : () -> () +} + +// CHECK-NEXT: builtin.module { +// CHECK-NEXT: "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "gauss_seidel", "width" = 1024 : i16}> ({ +// CHECK-NEXT: ^0(%arg0 : i16, %arg1 : i16, %arg2 : i16, %arg3 : i16, %arg4 : i16, %arg5 : i16, %arg6 : i16, %arg7 : i16, %arg8 : i16): +// CHECK-NEXT: %0 = arith.constant 0 : i16 +// CHECK-NEXT: %1 = "csl.get_color"(%0) : (i16) -> !csl.color +// CHECK-NEXT: %2 = "csl_wrapper.import"(%arg2, %arg3, %1) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module +// CHECK-NEXT: %3 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module +// CHECK-NEXT: %4 = "csl.member_call"(%3, %arg0, %arg1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct +// CHECK-NEXT: %5 = "csl.member_call"(%2, %arg0) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct +// CHECK-NEXT: %6 = arith.constant 1 : i16 +// CHECK-NEXT: %7 = arith.subi %arg5, %6 : i16 +// CHECK-NEXT: %8 = arith.subi %arg2, %arg0 : i16 +// CHECK-NEXT: %9 = arith.subi %arg3, %arg1 : i16 +// CHECK-NEXT: %10 = arith.cmpi slt, %arg0, %7 : i16 +// CHECK-NEXT: %11 = arith.cmpi slt, %arg1, %7 : i16 +// CHECK-NEXT: %12 = arith.cmpi slt, %8, %arg5 : i16 +// CHECK-NEXT: %13 = arith.cmpi slt, %9, %arg5 : i16 +// CHECK-NEXT: %14 = arith.ori %10, %11 : i1 +// CHECK-NEXT: %15 = arith.ori %14, %12 : i1 +// CHECK-NEXT: %16 = arith.ori %15, %13 : i1 +// CHECK-NEXT: "csl_wrapper.yield"(%5, %4, %16) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () +// CHECK-NEXT: }, { +// CHECK-NEXT: ^1(%arg0_1 : i16, %arg1_1 : i16, %arg2_1 : i16, %arg3_1 : i16, %arg4_1 : i16, %arg5_1 : i16, %arg6_1 : i16, %arg7_1 : !csl.comptime_struct, %arg8_1 : !csl.comptime_struct, %arg9 : i1): +// CHECK-NEXT: %17 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %18 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %19 = memref.alloc() : memref<512xf32> +// CHECK-NEXT: %20 = memref.alloc() : memref<512xf32> +// CHECK-NEXT: %21 = "csl.addressof"(%19) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: %22 = "csl.addressof"(%20) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: "csl.export"(%21) <{"type" = !csl.ptr, #csl>, "var_name" = "a"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"(%22) <{"type" = !csl.ptr, #csl>, "var_name" = "b"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"() <{"type" = () -> (), "var_name" = @gauss_seidel}> : () -> () +// CHECK-NEXT: csl.func @gauss_seidel() { +// CHECK-NEXT: %23 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> +// CHECK-NEXT: csl_stencil.apply(%19 : memref<512xf32>, %23 : memref<510xf32>, %20 : memref<512xf32>, %arg9 : i1) outs (%20 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 1 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ +// CHECK-NEXT: ^2(%arg10 : memref<4x510xf32>, %arg11 : index, %arg12 : memref<510xf32>): +// CHECK-NEXT: %24 = csl_stencil.access %arg10[1, 0] : memref<4x510xf32> +// CHECK-NEXT: %25 = csl_stencil.access %arg10[-1, 0] : memref<4x510xf32> +// CHECK-NEXT: %26 = csl_stencil.access %arg10[0, 1] : memref<4x510xf32> +// CHECK-NEXT: %27 = csl_stencil.access %arg10[0, -1] : memref<4x510xf32> +// CHECK-NEXT: %28 = memref.subview %arg12[%arg11] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> +// CHECK-NEXT: linalg.add ins(%27, %26 : memref<510xf32>, memref<510xf32>) outs(%28 : memref<510xf32, strided<[1], offset: ?>>) +// CHECK-NEXT: linalg.add ins(%28, %25 : memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) outs(%28 : memref<510xf32, strided<[1], offset: ?>>) +// CHECK-NEXT: linalg.add ins(%28, %24 : memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) outs(%28 : memref<510xf32, strided<[1], offset: ?>>) +// CHECK-NEXT: %29 = memref.subview %arg12[%arg11] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> +// CHECK-NEXT: "memref.copy"(%28, %29) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () +// CHECK-NEXT: csl_stencil.yield %arg12 : memref<510xf32> +// CHECK-NEXT: }, { +// CHECK-NEXT: ^3(%arg10_1 : memref<512xf32>, %arg11_1 : memref<510xf32>, %30 : memref<512xf32>, %31 : i1): +// CHECK-NEXT: scf.if %31 { +// CHECK-NEXT: } else { +// CHECK-NEXT: %32 = arith.constant dense<1.666600e-01> : memref<510xf32> +// CHECK-NEXT: %33 = memref.subview %arg10_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> +// CHECK-NEXT: %34 = memref.subview %arg10_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: linalg.add ins(%arg11_1, %34 : memref<510xf32>, memref<510xf32, strided<[1]>>) outs(%arg11_1 : memref<510xf32>) +// CHECK-NEXT: linalg.add ins(%arg11_1, %33 : memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) outs(%arg11_1 : memref<510xf32>) +// CHECK-NEXT: linalg.mul ins(%arg11_1, %32 : memref<510xf32>, memref<510xf32>) outs(%arg11_1 : memref<510xf32>) +// CHECK-NEXT: %35 = memref.subview %30[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "memref.copy"(%arg11_1, %35) : (memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: } +// CHECK-NEXT: csl_stencil.yield +// CHECK-NEXT: }) to <[0, 0], [1, 1]> +// CHECK-NEXT: "csl.member_call"(%17) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () +// CHECK-NEXT: }) : () -> () +// CHECK-NEXT: } diff --git a/xdsl/tools/command_line_tool.py b/xdsl/tools/command_line_tool.py index e80c9d6b83..51115344cf 100644 --- a/xdsl/tools/command_line_tool.py +++ b/xdsl/tools/command_line_tool.py @@ -101,6 +101,11 @@ def get_csl_stencil_bufferize(): return csl_stencil_bufferize.CslStencilBufferize + def get_csl_stencil_materialize_stores(): + from xdsl.transforms import csl_stencil_materialize_stores + + return csl_stencil_materialize_stores.CslStencilMaterializeStores + def get_csl_stencil_to_csl_wrapper(): from xdsl.transforms import csl_stencil_to_csl_wrapper @@ -441,6 +446,7 @@ def get_stencil_bufferize(): "convert-stencil-to-ll-mlir": get_convert_stencil_to_ll_mlir, "cse": get_cse, "csl-stencil-bufferize": get_csl_stencil_bufferize, + "csl-stencil-materialize-stores": get_csl_stencil_materialize_stores, "csl-stencil-to-csl-wrapper": get_csl_stencil_to_csl_wrapper, "csl-wrapper-hoist-buffers": get_csl_wrapper_hoist_buffers, "csl-stencil-handle-async-flow": get_csl_stencil_handle_async_flow, diff --git a/xdsl/transforms/csl_stencil_materialize_stores.py b/xdsl/transforms/csl_stencil_materialize_stores.py new file mode 100644 index 0000000000..1099419bf9 --- /dev/null +++ b/xdsl/transforms/csl_stencil_materialize_stores.py @@ -0,0 +1,152 @@ +from dataclasses import dataclass + +from xdsl.context import MLContext +from xdsl.dialects import memref, scf +from xdsl.dialects.builtin import ModuleOp +from xdsl.dialects.csl import csl_stencil, csl_wrapper +from xdsl.ir import Attribute, Block, Operation, Region, SSAValue +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.rewriter import InsertPoint +from xdsl.utils.hints import isa + + +@dataclass(frozen=True) +class MaterializeInApplyDest(RewritePattern): + """ + Lowers csl_stencil.yield to csl.return. + Note, the callbacks generated return no values, whereas the yield op + to be replaced may still report to yield values. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl_stencil.YieldOp, rewriter: PatternRewriter, /): + assert isinstance(apply := op.parent_op(), csl_stencil.ApplyOp) + + # the second callback stores yielded values to dest + if op.parent_region() == apply.done_exchange: + views: list[Operation] = [] + add_args: list[SSAValue] = [] + for src, dst in zip(op.arguments, apply.dest): + assert isa(src.type, memref.MemRefType[Attribute]) + assert isa(dst.type, memref.MemRefType[Attribute]) + dst_arg = apply.done_exchange.block.insert_arg( + dst.type, len(apply.done_exchange.block.args) + ) + views.append( + memref.Subview.get( + dst_arg, + [ + (d - s) // 2 # symmetric offset + for s, d in zip(src.type.get_shape(), dst.type.get_shape()) + ], + src.type.get_shape(), + len(src.type.get_shape()) * [1], + src.type, + ) + ) + add_args.append(dst) + copies = [memref.CopyOp(src, dst) for src, dst in zip(op.arguments, views)] + rewriter.insert_op( + [*views, *copies], + InsertPoint.before(op), + ) + + rewriter.replace_matched_op(csl_stencil.YieldOp()) + rewriter.replace_op( + apply, + csl_stencil.ApplyOp( + operands=[ + apply.field, + apply.accumulator, + [*apply.args, *add_args], + apply.dest, + ], + regions=[apply.detach_region(r) for r in apply.regions], + properties=apply.properties, + result_types=apply.result_types or [[]], + ), + ) + + +@dataclass(frozen=True) +class DisableComputeInBorderRegion(RewritePattern): + """ + Processing elements in the border region do not need to do compute or store their values back to a buffer. + For simplicity, wrap the full `csl_stencil.apply.done_exchange` region in an `scf.if`. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /): + wrapper_op = op.parent_op() + while wrapper_op and not isinstance(wrapper_op, csl_wrapper.ModuleOp): + wrapper_op = wrapper_op.parent_op() + if not wrapper_op: + return + + cond = wrapper_op.get_program_param("isBorderRegionPE") + + op.done_exchange.block.insert_arg(cond.type, len(op.done_exchange.block.args)) + + rewriter.insert_op( + if_op := scf.If( + op.done_exchange.block.args[-1], [], Region(Block()), Region(Block()) + ), + InsertPoint.at_start(op.done_exchange.block), + ) + + assert if_op.next_op, "Block cannot be empty" + + if ( + not isinstance(yld := op.done_exchange.block.last_op, csl_stencil.YieldOp) + or not len(yld.arguments) == 0 + ): + return + + body = op.done_exchange.block.split_before(if_op.next_op) + rewriter.inline_block(body, InsertPoint.at_start(if_op.false_region.block)) + + rewriter.insert_op( + csl_stencil.YieldOp(), InsertPoint.at_end(op.done_exchange.block) + ) + rewriter.replace_op(yld, scf.Yield()) + rewriter.insert_op(scf.Yield(), InsertPoint.at_start(if_op.true_region.block)) + rewriter.replace_matched_op( + csl_stencil.ApplyOp( + operands=[ + op.field, + op.accumulator, + [*op.args, cond], + op.dest, + ], + regions=[op.detach_region(r) for r in op.regions], + properties=op.properties, + result_types=op.result_types or [[]], + ) + ) + + +@dataclass(frozen=True) +class CslStencilMaterializeStores(ModulePass): + """ + This pass creates stores for values yielded from `csl_stencil.apply.done_exchange.yield` + to the buffers in `apply.dest`. + Stores should only be materialised for PEs not in the border region. + + The pass operates on memrefs, run after bufferization. + """ + + name = "csl-stencil-materialize-stores" + + def apply(self, ctx: MLContext, op: ModuleOp) -> None: + PatternRewriteWalker( + MaterializeInApplyDest(), apply_recursively=False + ).rewrite_module(op) + PatternRewriteWalker( + DisableComputeInBorderRegion(), apply_recursively=False + ).rewrite_module(op) From b73d4541f5ad6f6c28237c00bb9e2f06da7900f5 Mon Sep 17 00:00:00 2001 From: n-io Date: Fri, 27 Sep 2024 18:17:38 +0200 Subject: [PATCH 2/4] simplifying lower-csl-stencil --- .../transforms/lower-csl-stencil.mlir | 4 +-- xdsl/transforms/lower_csl_stencil.py | 30 ++----------------- 2 files changed, 3 insertions(+), 31 deletions(-) diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index 0cea8e5dde..1ab5119bb7 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -53,7 +53,7 @@ builtin.module { "csl.fadds"(%arg3_1, %arg3_1, %43) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () %45 = arith.constant 1.666600e-01 : f32 "csl.fmuls"(%arg3_1, %arg3_1, %45) : (memref<510xf32>, memref<510xf32>, f32) -> () - csl_stencil.yield %arg3_1 : memref<510xf32> + csl_stencil.yield }) to <[0, 0], [1, 1]> csl.return } @@ -133,8 +133,6 @@ builtin.module { // CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %57) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () // CHECK-NEXT: %59 = arith.constant 1.666600e-01 : f32 // CHECK-NEXT: "csl.fmuls"(%accumulator, %accumulator, %59) : (memref<510xf32>, memref<510xf32>, f32) -> () -// CHECK-NEXT: %60 = memref.subview %arg1[1] [510] [1] : memref<512xf32> to memref<510xf32> -// CHECK-NEXT: "memref.copy"(%accumulator, %60) : (memref<510xf32>, memref<510xf32>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 1d1d3f1fa4..ee41d3225d 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from xdsl.context import MLContext -from xdsl.dialects import arith, func, memref +from xdsl.dialects import arith, func from xdsl.dialects.builtin import ( FunctionType, IndexType, @@ -11,7 +11,7 @@ i16, ) from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper -from xdsl.ir import Attribute, Block, Operation, Region +from xdsl.ir import Block, Operation, Region from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -21,7 +21,6 @@ op_type_rewrite_pattern, ) from xdsl.rewriter import InsertPoint -from xdsl.utils.hints import isa def get_dir_and_distance_ops( @@ -198,31 +197,6 @@ class LowerYieldOp(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: csl_stencil.YieldOp, rewriter: PatternRewriter, /): - assert isinstance(apply := op.parent_op(), csl_stencil.ApplyOp) - - # the second callback stores yielded values to dest - if op.parent_region() == apply.done_exchange: - views: list[Operation] = [] - for src, dst in zip(op.arguments, apply.dest): - assert isa(src.type, memref.MemRefType[Attribute]) - assert isa(dst.type, memref.MemRefType[Attribute]) - views.append( - memref.Subview.get( - dst, - [ - (d - s) // 2 # symmetric offset - for s, d in zip(src.type.get_shape(), dst.type.get_shape()) - ], - src.type.get_shape(), - len(src.type.get_shape()) * [1], - src.type, - ) - ) - copies = [memref.CopyOp(src, dst) for src, dst in zip(op.arguments, views)] - rewriter.insert_op( - [*views, *copies], - InsertPoint.before(op), - ) rewriter.replace_matched_op(csl.ReturnOp()) From d05f560277bfee39792347ecb2d085622d6b5d09 Mon Sep 17 00:00:00 2001 From: n-io Date: Fri, 27 Sep 2024 18:41:17 +0200 Subject: [PATCH 3/4] one pass only --- .../transforms/csl_stencil_materialize_stores.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/xdsl/transforms/csl_stencil_materialize_stores.py b/xdsl/transforms/csl_stencil_materialize_stores.py index 1099419bf9..589949c92b 100644 --- a/xdsl/transforms/csl_stencil_materialize_stores.py +++ b/xdsl/transforms/csl_stencil_materialize_stores.py @@ -7,6 +7,7 @@ from xdsl.ir import Attribute, Block, Operation, Region, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, PatternRewriter, PatternRewriteWalker, RewritePattern, @@ -26,6 +27,8 @@ class MaterializeInApplyDest(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: csl_stencil.YieldOp, rewriter: PatternRewriter, /): + if not len(op.arguments) > 0: + return assert isinstance(apply := op.parent_op(), csl_stencil.ApplyOp) # the second callback stores yielded values to dest @@ -90,6 +93,8 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, return cond = wrapper_op.get_program_param("isBorderRegionPE") + if cond in op.args: + return op.done_exchange.block.insert_arg(cond.type, len(op.done_exchange.block.args)) @@ -145,8 +150,11 @@ class CslStencilMaterializeStores(ModulePass): def apply(self, ctx: MLContext, op: ModuleOp) -> None: PatternRewriteWalker( - MaterializeInApplyDest(), apply_recursively=False - ).rewrite_module(op) - PatternRewriteWalker( - DisableComputeInBorderRegion(), apply_recursively=False + GreedyRewritePatternApplier( + [ + MaterializeInApplyDest(), + DisableComputeInBorderRegion(), + ] + ), + walk_regions_first=True, ).rewrite_module(op) From 6e94be2dacd7f4bfef83713bd0cf0ee8eb13d797 Mon Sep 17 00:00:00 2001 From: n-io Date: Mon, 30 Sep 2024 18:27:02 +0200 Subject: [PATCH 4/4] updates --- .../csl_stencil_materialize_stores.py | 89 +++++++++---------- xdsl/transforms/lower_csl_stencil.py | 5 +- 2 files changed, 47 insertions(+), 47 deletions(-) diff --git a/xdsl/transforms/csl_stencil_materialize_stores.py b/xdsl/transforms/csl_stencil_materialize_stores.py index 589949c92b..e549bf1d04 100644 --- a/xdsl/transforms/csl_stencil_materialize_stores.py +++ b/xdsl/transforms/csl_stencil_materialize_stores.py @@ -20,9 +20,7 @@ @dataclass(frozen=True) class MaterializeInApplyDest(RewritePattern): """ - Lowers csl_stencil.yield to csl.return. - Note, the callbacks generated return no values, whereas the yield op - to be replaced may still report to yield values. + Stores the yielded values to the buffers specified in `apply.dest` instead of yielding them. """ @op_type_rewrite_pattern @@ -31,50 +29,51 @@ def match_and_rewrite(self, op: csl_stencil.YieldOp, rewriter: PatternRewriter, return assert isinstance(apply := op.parent_op(), csl_stencil.ApplyOp) - # the second callback stores yielded values to dest - if op.parent_region() == apply.done_exchange: - views: list[Operation] = [] - add_args: list[SSAValue] = [] - for src, dst in zip(op.arguments, apply.dest): - assert isa(src.type, memref.MemRefType[Attribute]) - assert isa(dst.type, memref.MemRefType[Attribute]) - dst_arg = apply.done_exchange.block.insert_arg( - dst.type, len(apply.done_exchange.block.args) - ) - views.append( - memref.Subview.get( - dst_arg, - [ - (d - s) // 2 # symmetric offset - for s, d in zip(src.type.get_shape(), dst.type.get_shape()) - ], - src.type.get_shape(), - len(src.type.get_shape()) * [1], - src.type, - ) - ) - add_args.append(dst) - copies = [memref.CopyOp(src, dst) for src, dst in zip(op.arguments, views)] - rewriter.insert_op( - [*views, *copies], - InsertPoint.before(op), - ) + if op.parent_region() != apply.done_exchange: + return - rewriter.replace_matched_op(csl_stencil.YieldOp()) - rewriter.replace_op( - apply, - csl_stencil.ApplyOp( - operands=[ - apply.field, - apply.accumulator, - [*apply.args, *add_args], - apply.dest, + views: list[Operation] = [] + add_args: list[SSAValue] = [] + for src, dst in zip(op.arguments, apply.dest, strict=True): + assert isa(src.type, memref.MemRefType[Attribute]) + assert isa(dst.type, memref.MemRefType[Attribute]) + dst_arg = apply.done_exchange.block.insert_arg( + dst.type, len(apply.done_exchange.block.args) + ) + views.append( + memref.Subview.get( + dst_arg, + [ + (d - s) // 2 # symmetric offset + for s, d in zip(src.type.get_shape(), dst.type.get_shape()) ], - regions=[apply.detach_region(r) for r in apply.regions], - properties=apply.properties, - result_types=apply.result_types or [[]], - ), + src.type.get_shape(), + len(src.type.get_shape()) * [1], + src.type, + ) ) + add_args.append(dst) + copies = [memref.CopyOp(src, dst) for src, dst in zip(op.arguments, views)] + rewriter.insert_op( + [*views, *copies], + InsertPoint.before(op), + ) + + rewriter.replace_matched_op(csl_stencil.YieldOp()) + rewriter.replace_op( + apply, + csl_stencil.ApplyOp( + operands=[ + apply.field, + apply.accumulator, + [*apply.args, *add_args], + apply.dest, + ], + regions=[apply.detach_region(r) for r in apply.regions], + properties=apply.properties, + result_types=apply.result_types or [[]], + ), + ) @dataclass(frozen=True) @@ -109,7 +108,7 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, if ( not isinstance(yld := op.done_exchange.block.last_op, csl_stencil.YieldOp) - or not len(yld.arguments) == 0 + or len(yld.arguments) > 0 ): return diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index ee41d3225d..70ef2f11a8 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -191,8 +191,9 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, class LowerYieldOp(RewritePattern): """ Lowers csl_stencil.yield to csl.return. - Note, the callbacks generated return no values, whereas the yield op - to be replaced may still report to yield values. + Note, the callbacks generated return no values, and the yield op + to be replaced should also yield no values. This should be run + after `--csl-stencil-materialize-stores`. """ @op_type_rewrite_pattern