From dac420ec4b66da4beac575abffc0979c7890a9f9 Mon Sep 17 00:00:00 2001 From: n-io Date: Wed, 9 Oct 2024 21:56:23 +0200 Subject: [PATCH 1/5] transformations: (lower-csl-stencil) Optimise full-stencil access --- .../transforms/lower-csl-stencil.mlir | 211 ++++++++---------- xdsl/transforms/lower_csl_stencil.py | 121 ++++++++++ 2 files changed, 220 insertions(+), 112 deletions(-) diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index 822a3f72ac..421c35b330 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -104,35 +104,22 @@ builtin.module { // CHECK-NEXT: } // CHECK-NEXT: csl.func @receive_chunk_cb0(%offset : i16) { // CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index -// CHECK-NEXT: %41 = arith.constant 1 : i16 -// CHECK-NEXT: %42 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %43 = "csl.member_call"(%34, %42, %41) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %44 = builtin.unrealized_conversion_cast %43 : !csl to memref<255xf32> -// CHECK-NEXT: %45 = arith.constant 1 : i16 -// CHECK-NEXT: %46 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %47 = "csl.member_call"(%34, %46, %45) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %48 = builtin.unrealized_conversion_cast %47 : !csl to memref<255xf32> -// CHECK-NEXT: %49 = arith.constant 1 : i16 -// CHECK-NEXT: %50 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %51 = "csl.member_call"(%34, %50, %49) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %52 = builtin.unrealized_conversion_cast %51 : !csl to memref<255xf32> -// CHECK-NEXT: %53 = arith.constant 1 : i16 -// CHECK-NEXT: %54 = "csl.get_dir"() <{"dir" = 
#csl}> : () -> !csl.direction -// CHECK-NEXT: %55 = "csl.member_call"(%34, %54, %53) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %56 = builtin.unrealized_conversion_cast %55 : !csl to memref<255xf32> -// CHECK-NEXT: %57 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> -// CHECK-NEXT: "csl.fadds"(%57, %56, %52) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> () -// CHECK-NEXT: "csl.fadds"(%57, %57, %48) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> () -// CHECK-NEXT: "csl.fadds"(%57, %57, %44) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> () +// CHECK-NEXT: %41 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> +// CHECK-NEXT: %42 = arith.constant 4 : i16 +// CHECK-NEXT: %43 = "csl.get_mem_dsd"(%accumulator, %42, %29, %31) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl +// CHECK-NEXT: %44 = arith.index_cast %offset_1 : index to si16 +// CHECK-NEXT: %45 = "csl.increment_dsd_offset"(%43, %44) <{"elem_type" = f32}> : (!csl, si16) -> !csl +// CHECK-NEXT: %46 = "csl.member_call"(%34) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl +// CHECK-NEXT: "csl.fadds"(%45, %45, %46) : (!csl, !csl, !csl) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @done_exchange_cb0() { -// CHECK-NEXT: %58 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> -// CHECK-NEXT: %59 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> -// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %59) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () -// CHECK-NEXT: 
"csl.fadds"(%accumulator, %accumulator, %58) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () -// CHECK-NEXT: %60 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: "csl.fmuls"(%accumulator, %accumulator, %60) : (memref<510xf32>, memref<510xf32>, f32) -> () +// CHECK-NEXT: %47 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> +// CHECK-NEXT: %48 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %48) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () +// CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %47) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () +// CHECK-NEXT: %49 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%accumulator, %accumulator, %49) : (memref<510xf32>, memref<510xf32>, f32) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () @@ -249,52 +236,52 @@ builtin.module { // CHECK-NEXT: "csl_wrapper.module"() <{"height" = 512 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=1 : i16>, #csl_wrapper.param<"chunk_size" default=510 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "loop", "width" = 1024 : i16}> ({ // CHECK-NEXT: ^2(%arg0_1 : i16, %arg1_1 : i16, %arg2 : i16, %arg3 : i16, %arg4 : i16, %arg5 : i16, %arg6 : i16, %arg7 : i16, %arg8 : i16): -// CHECK-NEXT: %61 = arith.constant 0 : i16 -// CHECK-NEXT: %62 = "csl.get_color"(%61) : (i16) -> !csl.color -// CHECK-NEXT: %63 = "csl_wrapper.import"(%arg2, %arg3, %62) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module -// CHECK-NEXT: %64 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", 
"peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module -// CHECK-NEXT: %65 = "csl.member_call"(%64, %arg0_1, %arg1_1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct -// CHECK-NEXT: %66 = "csl.member_call"(%63, %arg0_1) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct -// CHECK-NEXT: %67 = arith.constant 1 : i16 -// CHECK-NEXT: %68 = arith.subi %arg5, %67 : i16 -// CHECK-NEXT: %69 = arith.subi %arg2, %arg0_1 : i16 -// CHECK-NEXT: %70 = arith.subi %arg3, %arg1_1 : i16 -// CHECK-NEXT: %71 = arith.cmpi slt, %arg0_1, %68 : i16 -// CHECK-NEXT: %72 = arith.cmpi slt, %arg1_1, %68 : i16 -// CHECK-NEXT: %73 = arith.cmpi slt, %69, %arg5 : i16 -// CHECK-NEXT: %74 = arith.cmpi slt, %70, %arg5 : i16 -// CHECK-NEXT: %75 = arith.ori %71, %72 : i1 -// CHECK-NEXT: %76 = arith.ori %75, %73 : i1 -// CHECK-NEXT: %77 = arith.ori %76, %74 : i1 -// CHECK-NEXT: "csl_wrapper.yield"(%66, %65, %77) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () +// CHECK-NEXT: %50 = arith.constant 0 : i16 +// CHECK-NEXT: %51 = "csl.get_color"(%50) : (i16) -> !csl.color +// CHECK-NEXT: %52 = "csl_wrapper.import"(%arg2, %arg3, %51) <{"fields" = ["width", "height", "LAUNCH"], "module" = ""}> : (i16, i16, !csl.color) -> !csl.imported_module +// CHECK-NEXT: %53 = "csl_wrapper.import"(%arg5, %arg2, %arg3) <{"fields" = ["pattern", "peWidth", "peHeight"], "module" = "routes.csl"}> : (i16, i16, i16) -> !csl.imported_module +// CHECK-NEXT: %54 = "csl.member_call"(%53, %arg0_1, %arg1_1, %arg2, %arg3, %arg5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct +// CHECK-NEXT: %55 = "csl.member_call"(%52, %arg0_1) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct +// CHECK-NEXT: %56 = arith.constant 1 : 
i16 +// CHECK-NEXT: %57 = arith.subi %arg5, %56 : i16 +// CHECK-NEXT: %58 = arith.subi %arg2, %arg0_1 : i16 +// CHECK-NEXT: %59 = arith.subi %arg3, %arg1_1 : i16 +// CHECK-NEXT: %60 = arith.cmpi slt, %arg0_1, %57 : i16 +// CHECK-NEXT: %61 = arith.cmpi slt, %arg1_1, %57 : i16 +// CHECK-NEXT: %62 = arith.cmpi slt, %58, %arg5 : i16 +// CHECK-NEXT: %63 = arith.cmpi slt, %59, %arg5 : i16 +// CHECK-NEXT: %64 = arith.ori %60, %61 : i1 +// CHECK-NEXT: %65 = arith.ori %64, %62 : i1 +// CHECK-NEXT: %66 = arith.ori %65, %63 : i1 +// CHECK-NEXT: "csl_wrapper.yield"(%55, %54, %66) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () // CHECK-NEXT: }, { // CHECK-NEXT: ^3(%arg0_2 : i16, %arg1_2 : i16, %arg2_1 : i16, %arg3_1 : i16, %arg4_1 : i16, %arg5_1 : i16, %arg6_1 : i16, %arg7_1 : !csl.comptime_struct, %arg8_1 : !csl.comptime_struct, %arg9 : i1): -// CHECK-NEXT: %78 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module -// CHECK-NEXT: %79 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", "chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module -// CHECK-NEXT: %80 = memref.alloc() : memref<512xf32> -// CHECK-NEXT: %81 = memref.alloc() : memref<512xf32> -// CHECK-NEXT: %82 = "csl.addressof"(%80) : (memref<512xf32>) -> !csl.ptr, #csl> -// CHECK-NEXT: %83 = "csl.addressof"(%81) : (memref<512xf32>) -> !csl.ptr, #csl> -// CHECK-NEXT: "csl.export"(%82) <{"type" = !csl.ptr, #csl>, "var_name" = "a"}> : (!csl.ptr, #csl>) -> () -// CHECK-NEXT: "csl.export"(%83) <{"type" = !csl.ptr, #csl>, "var_name" = "b"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: %67 = "csl_wrapper.import"(%arg7_1) <{"fields" = [""], "module" = ""}> : (!csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %68 = "csl_wrapper.import"(%arg3_1, %arg5_1, %arg8_1) <{"fields" = ["pattern", 
"chunkSize", ""], "module" = "stencil_comms.csl"}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %69 = memref.alloc() : memref<512xf32> +// CHECK-NEXT: %70 = memref.alloc() : memref<512xf32> +// CHECK-NEXT: %71 = "csl.addressof"(%69) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: %72 = "csl.addressof"(%70) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: "csl.export"(%71) <{"type" = !csl.ptr, #csl>, "var_name" = "a"}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"(%72) <{"type" = !csl.ptr, #csl>, "var_name" = "b"}> : (!csl.ptr, #csl>) -> () // CHECK-NEXT: "csl.export"() <{"type" = () -> (), "var_name" = @gauss_seidel_func}> : () -> () -// CHECK-NEXT: %84 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var -// CHECK-NEXT: %85 = "csl.variable"() : () -> !csl.var> -// CHECK-NEXT: %86 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %73 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var +// CHECK-NEXT: %74 = "csl.variable"() : () -> !csl.var> +// CHECK-NEXT: %75 = "csl.variable"() : () -> !csl.var> // CHECK-NEXT: csl.func @loop() { -// CHECK-NEXT: %87 = arith.constant 0 : index -// CHECK-NEXT: %88 = arith.constant 1000 : index -// CHECK-NEXT: %89 = arith.constant 1 : index -// CHECK-NEXT: "csl.store_var"(%85, %80) : (!csl.var>, memref<512xf32>) -> () -// CHECK-NEXT: "csl.store_var"(%86, %81) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: %76 = arith.constant 0 : index +// CHECK-NEXT: %77 = arith.constant 1000 : index +// CHECK-NEXT: %78 = arith.constant 1 : index +// CHECK-NEXT: "csl.store_var"(%74, %69) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%75, %70) : (!csl.var>, memref<512xf32>) -> () // CHECK-NEXT: csl.activate local, 1 : i32 // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.task @for_cond0() attributes {"kind" = #csl, "id" = 1 : i5}{ -// CHECK-NEXT: %90 = arith.constant 1000 : i16 -// CHECK-NEXT: %91 = "csl.load_var"(%84) : (!csl.var) -> i16 -// 
CHECK-NEXT: %92 = arith.cmpi slt, %91, %90 : i16 -// CHECK-NEXT: scf.if %92 { +// CHECK-NEXT: %79 = arith.constant 1000 : i16 +// CHECK-NEXT: %80 = "csl.load_var"(%73) : (!csl.var) -> i16 +// CHECK-NEXT: %81 = arith.cmpi slt, %80, %79 : i16 +// CHECK-NEXT: scf.if %81 { // CHECK-NEXT: "csl.call"() <{"callee" = @for_body0}> : () -> () // CHECK-NEXT: } else { // CHECK-NEXT: "csl.call"() <{"callee" = @for_post0}> : () -> () @@ -302,73 +289,73 @@ builtin.module { // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_body0() { -// CHECK-NEXT: %arg10 = "csl.load_var"(%84) : (!csl.var) -> i16 -// CHECK-NEXT: %arg11 = "csl.load_var"(%85) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: %arg12 = "csl.load_var"(%86) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg10 = "csl.load_var"(%73) : (!csl.var) -> i16 +// CHECK-NEXT: %arg11 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg12 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %accumulator_1 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> -// CHECK-NEXT: %93 = arith.constant 1 : i16 -// CHECK-NEXT: %94 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %95 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> -// CHECK-NEXT: %96 = memref.subview %arg11[1] [510] [1] : memref<512xf32> to memref<510xf32> -// CHECK-NEXT: "csl.member_call"(%79, %96, %93, %94, %95) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () +// CHECK-NEXT: %82 = arith.constant 1 : i16 +// CHECK-NEXT: %83 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %84 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %85 = memref.subview %arg11[1] [510] [1] 
: memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "csl.member_call"(%68, %85, %82, %83, %84) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @receive_chunk_cb1(%offset_2 : i16) { // CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index -// CHECK-NEXT: %97 = arith.constant 1 : i16 -// CHECK-NEXT: %98 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %99 = "csl.member_call"(%79, %98, %97) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %100 = builtin.unrealized_conversion_cast %99 : !csl to memref<510xf32> -// CHECK-NEXT: %101 = arith.constant 1 : i16 -// CHECK-NEXT: %102 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %103 = "csl.member_call"(%79, %102, %101) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %104 = builtin.unrealized_conversion_cast %103 : !csl to memref<510xf32> -// CHECK-NEXT: %105 = arith.constant 1 : i16 -// CHECK-NEXT: %106 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %107 = "csl.member_call"(%79, %106, %105) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %108 = builtin.unrealized_conversion_cast %107 : !csl to memref<510xf32> -// CHECK-NEXT: %109 = arith.constant 1 : i16 -// CHECK-NEXT: %110 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %111 = "csl.member_call"(%79, %110, %109) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %112 = builtin.unrealized_conversion_cast %111 : !csl to memref<510xf32> -// CHECK-NEXT: %113 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], 
offset: ?>> -// CHECK-NEXT: "csl.fadds"(%113, %112, %108) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>, memref<510xf32>) -> () -// CHECK-NEXT: "csl.fadds"(%113, %113, %104) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () -// CHECK-NEXT: "csl.fadds"(%113, %113, %100) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () -// CHECK-NEXT: "memref.copy"(%113, %113) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () +// CHECK-NEXT: %86 = arith.constant 1 : i16 +// CHECK-NEXT: %87 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %88 = "csl.member_call"(%68, %87, %86) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %89 = builtin.unrealized_conversion_cast %88 : !csl to memref<510xf32> +// CHECK-NEXT: %90 = arith.constant 1 : i16 +// CHECK-NEXT: %91 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %92 = "csl.member_call"(%68, %91, %90) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %93 = builtin.unrealized_conversion_cast %92 : !csl to memref<510xf32> +// CHECK-NEXT: %94 = arith.constant 1 : i16 +// CHECK-NEXT: %95 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %96 = "csl.member_call"(%68, %95, %94) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %97 = builtin.unrealized_conversion_cast %96 : !csl to memref<510xf32> +// CHECK-NEXT: %98 = arith.constant 1 : i16 +// CHECK-NEXT: %99 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %100 = "csl.member_call"(%68, %99, %98) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %101 = builtin.unrealized_conversion_cast 
%100 : !csl to memref<510xf32> +// CHECK-NEXT: %102 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> +// CHECK-NEXT: "csl.fadds"(%102, %101, %97) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: "csl.fadds"(%102, %102, %93) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () +// CHECK-NEXT: "csl.fadds"(%102, %102, %89) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () +// CHECK-NEXT: "memref.copy"(%102, %102) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @done_exchange_cb1() { -// CHECK-NEXT: %arg12_1 = "csl.load_var"(%86) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: %arg11_1 = "csl.load_var"(%85) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg12_1 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %arg11_1 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: scf.if %arg9 { // CHECK-NEXT: } else { -// CHECK-NEXT: %114 = memref.subview %arg11_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> -// CHECK-NEXT: %115 = memref.subview %arg11_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> -// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %115) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () -// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %114) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () -// CHECK-NEXT: %116 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: "csl.fmuls"(%accumulator_1, %accumulator_1, %116) : (memref<510xf32>, memref<510xf32>, f32) -> () -// CHECK-NEXT: %117 = memref.subview %arg12_1[1] [510] [1] : memref<512xf32> to memref<510xf32> 
-// CHECK-NEXT: "memref.copy"(%accumulator_1, %117) : (memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: %103 = memref.subview %arg11_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> +// CHECK-NEXT: %104 = memref.subview %arg11_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %104) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () +// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %103) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () +// CHECK-NEXT: %105 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%accumulator_1, %accumulator_1, %105) : (memref<510xf32>, memref<510xf32>, f32) -> () +// CHECK-NEXT: %106 = memref.subview %arg12_1[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "memref.copy"(%accumulator_1, %106) : (memref<510xf32>, memref<510xf32>) -> () // CHECK-NEXT: } // CHECK-NEXT: "csl.call"() <{"callee" = @for_inc0}> : () -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_inc0() { -// CHECK-NEXT: %118 = arith.constant 1 : i16 -// CHECK-NEXT: %119 = "csl.load_var"(%84) : (!csl.var) -> i16 -// CHECK-NEXT: %120 = arith.addi %119, %118 : i16 -// CHECK-NEXT: "csl.store_var"(%84, %120) : (!csl.var, i16) -> () -// CHECK-NEXT: %121 = "csl.load_var"(%85) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: %122 = "csl.load_var"(%86) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: "csl.store_var"(%85, %122) : (!csl.var>, memref<512xf32>) -> () -// CHECK-NEXT: "csl.store_var"(%86, %121) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: %107 = arith.constant 1 : i16 +// CHECK-NEXT: %108 = "csl.load_var"(%73) : (!csl.var) -> i16 +// CHECK-NEXT: %109 = arith.addi %108, %107 : i16 +// CHECK-NEXT: "csl.store_var"(%73, %109) : (!csl.var, i16) -> () +// CHECK-NEXT: %110 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> +// 
CHECK-NEXT: %111 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: "csl.store_var"(%74, %111) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%75, %110) : (!csl.var>, memref<512xf32>) -> () // CHECK-NEXT: csl.activate local, 1 : i32 // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_post0() { -// CHECK-NEXT: "csl.member_call"(%78) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () +// CHECK-NEXT: "csl.member_call"(%67) <{"field" = "unblock_cmd_stream"}> : (!csl.imported_module) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 42ee90a17e..3e1d105dc8 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -1,8 +1,10 @@ from dataclasses import dataclass +from typing import cast from xdsl.context import MLContext from xdsl.dialects import arith, func, memref from xdsl.dialects.builtin import ( + ArrayAttr, FunctionType, IndexType, IntegerAttr, @@ -23,6 +25,7 @@ from xdsl.rewriter import InsertPoint from xdsl.traits import is_side_effect_free from xdsl.utils.hints import isa +from xdsl.utils.isattr import isattr def get_dir_and_distance_ops( @@ -252,6 +255,123 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, block_arg.replace_by(SSAValue.get(new_arg)) +@dataclass(frozen=True) +class FullStencilAccessImmediateReductionOptimization(RewritePattern): + """ + If an apply op accesses all points in the stencil shape *and* immediately performs a reduction, + lower to an API call that iterates over all receive buffers at once. This requires setting up a + 4d dsd that disregards all but one dimensions. 
+ """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, /): + # check that apply is inside a csl_wrapper and retreive `pattern` (stencil arm length + self) + if (wrapper := _get_module_wrapper(op)) is None: + return + pattern = wrapper.get_param_value("pattern").value.data + + # get csl_stencil.access ops and offsets + access_ops: list[csl_stencil.AccessOp] = [ + a for a in op.receive_chunk.walk() if isinstance(a, csl_stencil.AccessOp) + ] + offsets = set(tuple(a.offset) for a in access_ops) + + # this rewrite only works if all points in the stencil shape are accessed + if not self.is_full_2d_starshaped_access(offsets, pattern - 1): + return + + # find potential 'reduction' ops + reduction_ops: set[Operation] = set( + u.operation for a in access_ops for u in a.result.uses + ) + + # check if reduction ops are of the same type + red_op_ts = set(type(r) for r in reduction_ops) + if len(red_op_ts) > 1 or (red_op_t := red_op_ts.pop()) not in [ + csl.FaddsOp, + csl.FmulsOp, + ]: + return + red_ops = cast(set[csl.BuiltinDsdOp], reduction_ops) + + # safety check 1: each access has one use + acc_ops_uses = set(len(a.result.uses) for a in access_ops) + if len(acc_ops_uses) > 1 or acc_ops_uses.pop() != 1: + return + + # safety check 2: reduction ops use only access ops + # note, we have already checked that each access op is only consumed once + red_args = set(arg for r in red_ops for arg in r.ops) + nonaccess_args = red_args - set(a.result for a in access_ops) + if len(nonaccess_args) > 1: + return + + # safety check 3: the non-access op is an accumulator, used as the result param of all reduction ops + accumulator = nonaccess_args.pop() + if any(accumulator != r.ops[0] for r in red_ops): + return + + if not isattr(accumulator.type, memref.MemRefType): + raise ValueError("Pass needs to be run on memref types") + + # op.accumulator needs to be a memref alloc for this pass to work + if not isinstance(op.accumulator, 
OpResult) or not isinstance( + alloc := op.accumulator.op, memref.Alloc + ): + return + + dsd_t = csl.DsdType(csl.DsdKind.mem4d_dsd) + direction_count = arith.Constant.from_int_and_width(4, 16) + pattern = wrapper.get_program_param("pattern") + chunk_size = wrapper.get_program_param("chunk_size") + acc_dsd = csl.GetMemDsdOp.build( + operands=[alloc, [direction_count, pattern, chunk_size]], + result_types=[dsd_t], + properties={"strides": ArrayAttr([IntegerAttr(i, 16) for i in [0, 0, 1]])}, + ) + + new_ops: list[Operation] = [direction_count, acc_dsd] + if ( + isinstance(accumulator, OpResult) + and isinstance(subview := accumulator.op, memref.Subview) + and subview.source == op.receive_chunk.block.args[2] + ): + assert isa(subview.source.type, memref.MemRefType[Attribute]) + new_ops.append( + cast_op := arith.IndexCastOp(subview.offsets[0], csl.i16_value) + ) + new_ops.append( + csl.IncrementDsdOffsetOp.build( + operands=[acc_dsd, cast_op], + properties={"elem_type": subview.source.type.get_element_type()}, + result_types=[dsd_t], + ) + ) + new_acc = new_ops[-1] + + api_call = csl.MemberCallOp( + "getRecvBufDsd", dsd_t, wrapper.get_program_import("stencil_comms.csl"), [] + ) + + reduction_func = red_op_t.build(operands=[[new_acc, new_acc, api_call]]) + + rewriter.insert_op( + [*new_ops, api_call, reduction_func], InsertPoint.after(list(red_ops)[-1]) + ) + + for e in [*access_ops, *red_ops]: + rewriter.erase_op(e, safe_erase=False) + + @staticmethod + def is_full_2d_starshaped_access( + offsets: set[tuple[int, ...]], max_offset: int + ) -> bool: + """Returns iff the offsets cover all points in a 2d star-shape without the (0,0) point.""" + x_set = set((x, 0) for x in range(-max_offset, max_offset + 1)) + y_set = set((0, y) for y in range(-max_offset, max_offset + 1)) + return offsets == x_set ^ y_set + + @dataclass(frozen=True) class LowerCslStencil(ModulePass): """ @@ -274,6 +394,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: [ LowerYieldOp(), 
InlineApplyOpArgs(), + FullStencilAccessImmediateReductionOptimization(), ] ), apply_recursively=False, From 7e1aba26712e0b34505a1d47d7071896e4969661 Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 10 Oct 2024 11:56:13 +0200 Subject: [PATCH 2/5] Moving opt as first in apply-recursively pass --- xdsl/transforms/lower_csl_stencil.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 3e1d105dc8..461e7d5d13 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -394,7 +394,6 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: [ LowerYieldOp(), InlineApplyOpArgs(), - FullStencilAccessImmediateReductionOptimization(), ] ), apply_recursively=False, @@ -402,6 +401,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: module_pass = PatternRewriteWalker( GreedyRewritePatternApplier( [ + FullStencilAccessImmediateReductionOptimization(), LowerAccessOp(), LowerApplyOp(), ] From 197c2b5cf53755063a6c9ae0d6fd441a0339424e Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 10 Oct 2024 12:33:56 +0200 Subject: [PATCH 3/5] fix filecheck --- .../transforms/lower-csl-stencil.mlir | 61 ++++++++----------- 1 file changed, 24 insertions(+), 37 deletions(-) diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index 421c35b330..dfec7bda2e 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -302,27 +302,14 @@ builtin.module { // CHECK-NEXT: } // CHECK-NEXT: csl.func @receive_chunk_cb1(%offset_2 : i16) { // CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index -// CHECK-NEXT: %86 = arith.constant 1 : i16 -// CHECK-NEXT: %87 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %88 = "csl.member_call"(%68, %87, %86) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, 
!csl.direction, i16) -> !csl -// CHECK-NEXT: %89 = builtin.unrealized_conversion_cast %88 : !csl to memref<510xf32> -// CHECK-NEXT: %90 = arith.constant 1 : i16 -// CHECK-NEXT: %91 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %92 = "csl.member_call"(%68, %91, %90) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %93 = builtin.unrealized_conversion_cast %92 : !csl to memref<510xf32> -// CHECK-NEXT: %94 = arith.constant 1 : i16 -// CHECK-NEXT: %95 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %96 = "csl.member_call"(%68, %95, %94) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %97 = builtin.unrealized_conversion_cast %96 : !csl to memref<510xf32> -// CHECK-NEXT: %98 = arith.constant 1 : i16 -// CHECK-NEXT: %99 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction -// CHECK-NEXT: %100 = "csl.member_call"(%68, %99, %98) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl -// CHECK-NEXT: %101 = builtin.unrealized_conversion_cast %100 : !csl to memref<510xf32> -// CHECK-NEXT: %102 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> -// CHECK-NEXT: "csl.fadds"(%102, %101, %97) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>, memref<510xf32>) -> () -// CHECK-NEXT: "csl.fadds"(%102, %102, %93) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () -// CHECK-NEXT: "csl.fadds"(%102, %102, %89) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>, memref<510xf32>) -> () -// CHECK-NEXT: "memref.copy"(%102, %102) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () +// CHECK-NEXT: %86 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> 
to memref<510xf32, strided<[1], offset: ?>> +// CHECK-NEXT: %87 = arith.constant 4 : i16 +// CHECK-NEXT: %88 = "csl.get_mem_dsd"(%accumulator_1, %87, %arg3_1, %arg5_1) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl +// CHECK-NEXT: %89 = arith.index_cast %offset_3 : index to si16 +// CHECK-NEXT: %90 = "csl.increment_dsd_offset"(%88, %89) <{"elem_type" = f32}> : (!csl, si16) -> !csl +// CHECK-NEXT: %91 = "csl.member_call"(%68) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl +// CHECK-NEXT: "csl.fadds"(%90, %90, %91) : (!csl, !csl, !csl) -> () +// CHECK-NEXT: "memref.copy"(%86, %86) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @done_exchange_cb1() { @@ -330,27 +317,27 @@ builtin.module { // CHECK-NEXT: %arg11_1 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: scf.if %arg9 { // CHECK-NEXT: } else { -// CHECK-NEXT: %103 = memref.subview %arg11_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> -// CHECK-NEXT: %104 = memref.subview %arg11_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> -// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %104) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () -// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %103) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () -// CHECK-NEXT: %105 = arith.constant 1.666600e-01 : f32 -// CHECK-NEXT: "csl.fmuls"(%accumulator_1, %accumulator_1, %105) : (memref<510xf32>, memref<510xf32>, f32) -> () -// CHECK-NEXT: %106 = memref.subview %arg12_1[1] [510] [1] : memref<512xf32> to memref<510xf32> -// CHECK-NEXT: "memref.copy"(%accumulator_1, %106) : (memref<510xf32>, memref<510xf32>) -> () +// CHECK-NEXT: %92 = memref.subview %arg11_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], 
offset: 2>> +// CHECK-NEXT: %93 = memref.subview %arg11_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %93) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () +// CHECK-NEXT: "csl.fadds"(%accumulator_1, %accumulator_1, %92) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () +// CHECK-NEXT: %94 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%accumulator_1, %accumulator_1, %94) : (memref<510xf32>, memref<510xf32>, f32) -> () +// CHECK-NEXT: %95 = memref.subview %arg12_1[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "memref.copy"(%accumulator_1, %95) : (memref<510xf32>, memref<510xf32>) -> () // CHECK-NEXT: } // CHECK-NEXT: "csl.call"() <{"callee" = @for_inc0}> : () -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_inc0() { -// CHECK-NEXT: %107 = arith.constant 1 : i16 -// CHECK-NEXT: %108 = "csl.load_var"(%73) : (!csl.var) -> i16 -// CHECK-NEXT: %109 = arith.addi %108, %107 : i16 -// CHECK-NEXT: "csl.store_var"(%73, %109) : (!csl.var, i16) -> () -// CHECK-NEXT: %110 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: %111 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> -// CHECK-NEXT: "csl.store_var"(%74, %111) : (!csl.var>, memref<512xf32>) -> () -// CHECK-NEXT: "csl.store_var"(%75, %110) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: %96 = arith.constant 1 : i16 +// CHECK-NEXT: %97 = "csl.load_var"(%73) : (!csl.var) -> i16 +// CHECK-NEXT: %98 = arith.addi %97, %96 : i16 +// CHECK-NEXT: "csl.store_var"(%73, %98) : (!csl.var, i16) -> () +// CHECK-NEXT: %99 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: %100 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> +// CHECK-NEXT: "csl.store_var"(%74, %100) : (!csl.var>, memref<512xf32>) -> () +// CHECK-NEXT: "csl.store_var"(%75, %99) : (!csl.var>, memref<512xf32>) -> 
() // CHECK-NEXT: csl.activate local, 1 : i32 // CHECK-NEXT: csl.return // CHECK-NEXT: } From 42af94b49f4c87d3fd83741606149e350dae43d5 Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 10 Oct 2024 15:22:47 +0200 Subject: [PATCH 4/5] small fixes, updating comments and docstrings --- xdsl/transforms/lower_csl_stencil.py | 57 ++++++++++++++++++---------- 1 file changed, 38 insertions(+), 19 deletions(-) diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index 461e7d5d13..231a318152 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -260,7 +260,23 @@ class FullStencilAccessImmediateReductionOptimization(RewritePattern): """ If an apply op accesses all points in the stencil shape *and* immediately performs a reduction, lower to an API call that iterates over all receive buffers at once. This requires setting up a - 4d dsd that disregards all but one dimensions. + 4d dsd that disregards all but one dimension. + + The optimisation checks if it can be applied, and if so, sets up a new mem4d_dsd accumulator, lowers all + relevant `csl_stencil.access` calls to a single mem4d_dsd API call, and replaces all relevant reduction ops + with a single reduction op over the two mem4d_dsds. + + Note, if the optimisation is not applied, `csl_stencil.access` calls are left untouched to be handled by + the `LowerAccessOp` pass instead and translated to individual mem1d_dsd API calls. 
+
+    The optimisation is applied on the `csl_stencil.apply.receive_chunk` region iff:
+      * each point in the stencil shape is accessed
+      * each `csl_stencil.access` has exactly one use
+      * each access is immediately processed by the same (type of) reduction op
+      * each reduction op uses the same accumulator to store a result
+      * each reduction op uses no inputs except from the above access ops
+      * todo: the data of the accumulator is not itself an input of the reduction
+      * todo: no other ops modify the accumulator in-between reduction ops
     """
 
     @op_type_rewrite_pattern
@@ -294,7 +310,31 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
             return
         red_ops = cast(set[csl.BuiltinDsdOp], reduction_ops)
 
-        # safety check 1: each access has one use
-        acc_ops_uses = set(len(a.result.uses) for a in access_ops)
-        if len(acc_ops_uses) > 1 or acc_ops_uses.pop() != 1:
+        # check: only apply rewrite if each access has exactly one use
+        if any(len(a.result.uses) != 1 for a in access_ops):
             return
 
-        # safety check 2: reduction ops use only access ops
-        # note, we have already checked that each access op is only consumed once
+        # check: only apply rewrite if reduction ops use `access` ops only (plus one other, checked below)
+        # note, we have already checked that each access op is only consumed once, which by implication holds here
         red_args = set(arg for r in red_ops for arg in r.ops)
         nonaccess_args = red_args - set(a.result for a in access_ops)
         if len(nonaccess_args) > 1:
             return
 
-        # safety check 3: the non-access op is an accumulator, used as the result param of all reduction ops
+        # check: only apply rewrite if the non-`access` op is an accumulator and the result param in all reduction ops
         accumulator = nonaccess_args.pop()
         if any(accumulator != r.ops[0] for r in red_ops):
             return
-        if not isattr(accumulator.type, memref.MemRefType):
-            raise ValueError("Pass needs to be run on memref types")
-
-        # op.accumulator needs to be a memref alloc for this pass to 
work - if not isinstance(op.accumulator, OpResult) or not isinstance( - alloc := op.accumulator.op, memref.Alloc + if ( + not isattr(accumulator.type, memref.MemRefType) + or not isinstance(op.accumulator, OpResult) + or not isinstance(alloc := op.accumulator.op, memref.Alloc) ): - return + raise ValueError("Pass needs to be run on memref types") + # Set up new accumulator GetMemDsd, with 0-stride in `direction` and `distance` dimensions. + # Effectively, this activates only the z-value dimension. dsd_t = csl.DsdType(csl.DsdKind.mem4d_dsd) direction_count = arith.Constant.from_int_and_width(4, 16) pattern = wrapper.get_program_param("pattern") @@ -329,7 +344,9 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, result_types=[dsd_t], properties={"strides": ArrayAttr([IntegerAttr(i, 16) for i in [0, 0, 1]])}, ) + new_acc = acc_dsd + # If the accumulator is a subview at an offset, generate IncrementDsdOffset op (and index_cast). new_ops: list[Operation] = [direction_count, acc_dsd] if ( isinstance(accumulator, OpResult) @@ -341,22 +358,24 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, cast_op := arith.IndexCastOp(subview.offsets[0], csl.i16_value) ) new_ops.append( - csl.IncrementDsdOffsetOp.build( + new_acc := csl.IncrementDsdOffsetOp.build( operands=[acc_dsd, cast_op], properties={"elem_type": subview.source.type.get_element_type()}, result_types=[dsd_t], ) ) - new_acc = new_ops[-1] - api_call = csl.MemberCallOp( + # get dsd iterator over all points in stencil + full_stencil_dsd = csl.MemberCallOp( "getRecvBufDsd", dsd_t, wrapper.get_program_import("stencil_comms.csl"), [] ) - reduction_func = red_op_t.build(operands=[[new_acc, new_acc, api_call]]) + # rebuild compute func + reduction_op = red_op_t.build(operands=[[new_acc, new_acc, full_stencil_dsd]]) rewriter.insert_op( - [*new_ops, api_call, reduction_func], InsertPoint.after(list(red_ops)[-1]) + [*new_ops, full_stencil_dsd, reduction_op], + 
InsertPoint.after(list(red_ops)[-1]), ) for e in [*access_ops, *red_ops]: From 367c52dc62244ca49a273222c57de006f2bcba65 Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 10 Oct 2024 15:45:15 +0200 Subject: [PATCH 5/5] add filecheck for partial access generating getRecvBufDsdByNeighbor API calls --- .../transforms/lower-csl-stencil.mlir | 145 +++++++++++++++++- 1 file changed, 137 insertions(+), 8 deletions(-) diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index dfec7bda2e..552e935005 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -96,13 +96,13 @@ builtin.module { // CHECK-NEXT: csl.func @gauss_seidel_func() { // CHECK-NEXT: %accumulator = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> // CHECK-NEXT: %37 = arith.constant 2 : i16 -// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb0}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb0}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %38 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %39 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> // CHECK-NEXT: %40 = memref.subview %arg0[1] [510] [1] : memref<512xf32> to memref<510xf32> // CHECK-NEXT: "csl.member_call"(%34, %40, %37, %38, %39) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } -// CHECK-NEXT: csl.func @receive_chunk_cb0(%offset : i16) { +// CHECK-NEXT: csl.func @receive_chunk_cb1(%offset : i16) { // CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index // CHECK-NEXT: %41 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, 
strided<[1], offset: ?>> // CHECK-NEXT: %42 = arith.constant 4 : i16 @@ -113,7 +113,7 @@ builtin.module { // CHECK-NEXT: "csl.fadds"(%45, %45, %46) : (!csl, !csl, !csl) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } -// CHECK-NEXT: csl.func @done_exchange_cb0() { +// CHECK-NEXT: csl.func @done_exchange_cb1() { // CHECK-NEXT: %47 = memref.subview %arg0[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> // CHECK-NEXT: %48 = memref.subview %arg0[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> // CHECK-NEXT: "csl.fadds"(%accumulator, %accumulator, %48) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () @@ -294,13 +294,13 @@ builtin.module { // CHECK-NEXT: %arg12 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %accumulator_1 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> // CHECK-NEXT: %82 = arith.constant 1 : i16 -// CHECK-NEXT: %83 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb1}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> -// CHECK-NEXT: %84 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb1}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %83 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb2}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %84 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb2}> : () -> !csl.ptr<() -> (), #csl, #csl> // CHECK-NEXT: %85 = memref.subview %arg11[1] [510] [1] : memref<512xf32> to memref<510xf32> // CHECK-NEXT: "csl.member_call"(%68, %85, %82, %83, %84) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } -// CHECK-NEXT: csl.func @receive_chunk_cb1(%offset_2 : i16) { +// CHECK-NEXT: csl.func @receive_chunk_cb2(%offset_2 : i16) { // CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index // CHECK-NEXT: %86 = memref.subview %accumulator_1[%offset_3] [510] 
[1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> // CHECK-NEXT: %87 = arith.constant 4 : i16 @@ -312,7 +312,7 @@ builtin.module { // CHECK-NEXT: "memref.copy"(%86, %86) : (memref<510xf32, strided<[1], offset: ?>>, memref<510xf32, strided<[1], offset: ?>>) -> () // CHECK-NEXT: csl.return // CHECK-NEXT: } -// CHECK-NEXT: csl.func @done_exchange_cb1() { +// CHECK-NEXT: csl.func @done_exchange_cb2() { // CHECK-NEXT: %arg12_1 = "csl.load_var"(%75) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %arg11_1 = "csl.load_var"(%74) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: scf.if %arg9 { @@ -349,6 +349,135 @@ builtin.module { // CHECK-NEXT: }) : () -> () + "csl_wrapper.module"() <{"width" = 1022 : i16, "height" = 510 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=2 : i16>, #csl_wrapper.param<"chunk_size" default=255 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "partial_access"}> ({ + ^0(%0 : i16, %1 : i16, %2 : i16, %3 : i16, %4 : i16, %5 : i16, %6 : i16, %7 : i16, %8 : i16): + %9 = arith.constant 0 : i16 + %10 = "csl.get_color"(%9) : (i16) -> !csl.color + %11 = "csl_wrapper.import"(%2, %3, %10) <{"module" = "", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module + %12 = "csl_wrapper.import"(%5, %2, %3) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module + %13 = "csl.member_call"(%12, %0, %1, %2, %3, %5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct + %14 = "csl.member_call"(%11, %0) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct + %15 = arith.constant 1 : i16 + %16 = arith.subi %5, %15 : i16 + %17 = arith.subi %2, %0 : i16 + %18 = arith.subi %3, %1 : i16 + %19 = arith.cmpi slt, %0, %16 : i16 + %20 = arith.cmpi slt, %1, 
%16 : i16 + %21 = arith.cmpi slt, %17, %5 : i16 + %22 = arith.cmpi slt, %18, %5 : i16 + %23 = arith.ori %19, %20 : i1 + %24 = arith.ori %23, %21 : i1 + %25 = arith.ori %24, %22 : i1 + "csl_wrapper.yield"(%14, %13, %25) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () + }, { + ^1(%26 : i16, %27 : i16, %28 : i16, %29 : i16, %30 : i16, %31 : i16, %32 : i16, %memcpy_params : !csl.comptime_struct, %stencil_comms_params : !csl.comptime_struct, %isBorderRegionPE : i1): + %33 = "csl_wrapper.import"(%memcpy_params) <{"module" = "", "fields" = [""]}> : (!csl.comptime_struct) -> !csl.imported_module + %34 = "csl_wrapper.import"(%29, %31, %stencil_comms_params) <{"module" = "stencil_comms.csl", "fields" = ["pattern", "chunkSize", ""]}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module + %arg0 = memref.alloc() : memref<512xf32> + %arg1 = memref.alloc() : memref<512xf32> + %35 = "csl.addressof"(%arg0) : (memref<512xf32>) -> !csl.ptr, #csl> + %36 = "csl.addressof"(%arg1) : (memref<512xf32>) -> !csl.ptr, #csl> + "csl.export"(%35) <{"var_name" = "arg0", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () + "csl.export"(%36) <{"var_name" = "arg1", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () + "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () + csl.func @partial_access() { + %37 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> + csl_stencil.apply(%arg0 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg1 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ + ^2(%arg2 : memref<4x255xf32>, %arg3 : index, %arg4 : memref<510xf32>): + %38 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32> + %39 = csl_stencil.access %arg2[-1, 0] : 
memref<4x255xf32> + %40 = csl_stencil.access %arg2[0, 1] : memref<4x255xf32> + %42 = memref.subview %arg4[%arg3] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> + "csl.fadds"(%42, %39, %40) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> () + "csl.fadds"(%42, %42, %38) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> () + csl_stencil.yield %arg4 : memref<510xf32> + }, { + ^3(%arg2_1 : memref<512xf32>, %arg3_1 : memref<510xf32>): + %43 = memref.subview %arg2_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> + %44 = memref.subview %arg2_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> + "csl.fadds"(%arg3_1, %arg3_1, %44) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () + "csl.fadds"(%arg3_1, %arg3_1, %43) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () + %45 = arith.constant 1.666600e-01 : f32 + "csl.fmuls"(%arg3_1, %arg3_1, %45) : (memref<510xf32>, memref<510xf32>, f32) -> () + csl_stencil.yield + }) to <[0, 0], [1, 1]> + csl.return + } + "csl_wrapper.yield"() <{"fields" = []}> : () -> () + }) : () -> () + +// CHECK-NEXT: "csl_wrapper.module"() <{"width" = 1022 : i16, "height" = 510 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=2 : i16>, #csl_wrapper.param<"chunk_size" default=255 : i16>, #csl_wrapper.param<"padded_z_dim" default=510 : i16>], "program_name" = "partial_access"}> ({ +// CHECK-NEXT: ^4(%101 : i16, %102 : i16, %103 : i16, %104 : i16, %105 : i16, %106 : i16, %107 : i16, %108 : i16, %109 : i16): +// CHECK-NEXT: %110 = arith.constant 0 : i16 +// CHECK-NEXT: %111 = "csl.get_color"(%110) : (i16) -> !csl.color +// CHECK-NEXT: %112 = "csl_wrapper.import"(%103, %104, %111) <{"module" = "", "fields" = ["width", "height", "LAUNCH"]}> : 
(i16, i16, !csl.color) -> !csl.imported_module +// CHECK-NEXT: %113 = "csl_wrapper.import"(%106, %103, %104) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module +// CHECK-NEXT: %114 = "csl.member_call"(%113, %101, %102, %103, %104, %106) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct +// CHECK-NEXT: %115 = "csl.member_call"(%112, %101) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct +// CHECK-NEXT: %116 = arith.constant 1 : i16 +// CHECK-NEXT: %117 = arith.subi %106, %116 : i16 +// CHECK-NEXT: %118 = arith.subi %103, %101 : i16 +// CHECK-NEXT: %119 = arith.subi %104, %102 : i16 +// CHECK-NEXT: %120 = arith.cmpi slt, %101, %117 : i16 +// CHECK-NEXT: %121 = arith.cmpi slt, %102, %117 : i16 +// CHECK-NEXT: %122 = arith.cmpi slt, %118, %106 : i16 +// CHECK-NEXT: %123 = arith.cmpi slt, %119, %106 : i16 +// CHECK-NEXT: %124 = arith.ori %120, %121 : i1 +// CHECK-NEXT: %125 = arith.ori %124, %122 : i1 +// CHECK-NEXT: %126 = arith.ori %125, %123 : i1 +// CHECK-NEXT: "csl_wrapper.yield"(%115, %114, %126) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> () +// CHECK-NEXT: }, { +// CHECK-NEXT: ^5(%127 : i16, %128 : i16, %129 : i16, %130 : i16, %131 : i16, %132 : i16, %133 : i16, %memcpy_params_1 : !csl.comptime_struct, %stencil_comms_params_1 : !csl.comptime_struct, %isBorderRegionPE_1 : i1): +// CHECK-NEXT: %134 = "csl_wrapper.import"(%memcpy_params_1) <{"module" = "", "fields" = [""]}> : (!csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %135 = "csl_wrapper.import"(%130, %132, %stencil_comms_params_1) <{"module" = "stencil_comms.csl", "fields" = ["pattern", "chunkSize", ""]}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module +// CHECK-NEXT: %arg0_3 = memref.alloc() : memref<512xf32> +// CHECK-NEXT: %arg1_3 = 
memref.alloc() : memref<512xf32> +// CHECK-NEXT: %136 = "csl.addressof"(%arg0_3) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: %137 = "csl.addressof"(%arg1_3) : (memref<512xf32>) -> !csl.ptr, #csl> +// CHECK-NEXT: "csl.export"(%136) <{"var_name" = "arg0", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"(%137) <{"var_name" = "arg1", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () +// CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () +// CHECK-NEXT: csl.func @partial_access() { +// CHECK-NEXT: %accumulator_2 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> +// CHECK-NEXT: %138 = arith.constant 2 : i16 +// CHECK-NEXT: %139 = "csl.addressof_fn"() <{"fn_name" = @receive_chunk_cb0}> : () -> !csl.ptr<(i16) -> (), #csl, #csl> +// CHECK-NEXT: %140 = "csl.addressof_fn"() <{"fn_name" = @done_exchange_cb0}> : () -> !csl.ptr<() -> (), #csl, #csl> +// CHECK-NEXT: %141 = memref.subview %arg0_3[1] [510] [1] : memref<512xf32> to memref<510xf32> +// CHECK-NEXT: "csl.member_call"(%135, %141, %138, %139, %140) <{"field" = "communicate"}> : (!csl.imported_module, memref<510xf32>, i16, !csl.ptr<(i16) -> (), #csl, #csl>, !csl.ptr<() -> (), #csl, #csl>) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @receive_chunk_cb0(%offset_4 : i16) { +// CHECK-NEXT: %offset_5 = arith.index_cast %offset_4 : i16 to index +// CHECK-NEXT: %142 = arith.constant 1 : i16 +// CHECK-NEXT: %143 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %144 = "csl.member_call"(%135, %143, %142) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %145 = builtin.unrealized_conversion_cast %144 : !csl to memref<255xf32> +// CHECK-NEXT: %146 = arith.constant 1 : i16 +// CHECK-NEXT: %147 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %148 = "csl.member_call"(%135, %147, %146) <{"field" = 
"getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %149 = builtin.unrealized_conversion_cast %148 : !csl to memref<255xf32> +// CHECK-NEXT: %150 = arith.constant 1 : i16 +// CHECK-NEXT: %151 = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction +// CHECK-NEXT: %152 = "csl.member_call"(%135, %151, %150) <{"field" = "getRecvBufDsdByNeighbor"}> : (!csl.imported_module, !csl.direction, i16) -> !csl +// CHECK-NEXT: %153 = builtin.unrealized_conversion_cast %152 : !csl to memref<255xf32> +// CHECK-NEXT: %154 = memref.subview %accumulator_2[%offset_5] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> +// CHECK-NEXT: "csl.fadds"(%154, %149, %153) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> () +// CHECK-NEXT: "csl.fadds"(%154, %154, %145) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: csl.func @done_exchange_cb0() { +// CHECK-NEXT: %155 = memref.subview %arg0_3[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>> +// CHECK-NEXT: %156 = memref.subview %arg0_3[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>> +// CHECK-NEXT: "csl.fadds"(%accumulator_2, %accumulator_2, %156) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> () +// CHECK-NEXT: "csl.fadds"(%accumulator_2, %accumulator_2, %155) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> () +// CHECK-NEXT: %157 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: "csl.fmuls"(%accumulator_2, %accumulator_2, %157) : (memref<510xf32>, memref<510xf32>, f32) -> () +// CHECK-NEXT: csl.return +// CHECK-NEXT: } +// CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> () +// CHECK-NEXT: }) : () -> () + + } // CHECK-NEXT: }