Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(dialect+transform): stencil to new csl_stencil dialect and transform #2766

Merged
merged 26 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions tests/filecheck/dialects/csl/csl-stencil-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP

// Roundtrip fixture for the csl_stencil dialect: a single function that uses
// "csl_stencil.prefetch" and csl_stencil.access alongside regular stencil ops.
builtin.module {
// %a is the field that is loaded from, %b is the field that is stored to.
func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// The prefetch takes the loaded temp and yields a memref of 4 tensors of 510 elements.
// Presumably one entry per listed exchange direction (4 swaps, 4 entries) - TODO confirm.
// Note that "size" is written without a type here; the expected output prints it as `510 : i64`.
%pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "size" = 510, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (memref<4xtensor<510xf32>>)
// The apply body receives both the original temp (%2) and the prefetched buffer (%3).
%1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %3 = %pref : memref<4xtensor<510xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
// Scalar factor that the summed tensor is multiplied by at the end.
%4 = arith.constant 1.666600e-01 : f32
// Offsets [1, 0] and [-1, 0] are read from the prefetched buffer.
%5 = csl_stencil.access %3[1, 0] : memref<4xtensor<510xf32>>
%6 = csl_stencil.access %3[-1, 0] : memref<4xtensor<510xf32>>
// Offset [0, 0] is read from the regular temp (twice), yielding full 512-element tensors.
%7 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%8 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// Two 510-element slices of the [0, 0] tensor, at static offsets 1 and -1.
// NOTE(review): the second slice uses a negative static offset (-1) - confirm this is intended.
%9 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%10 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// Offsets [0, 1] and [0, -1] are read from the prefetched buffer.
%11 = csl_stencil.access %3[0, 1] : memref<4xtensor<510xf32>>
%12 = csl_stencil.access %3[0, -1] : memref<4xtensor<510xf32>>
// Accumulate all six 510-element tensors: %12 + %11 + %10 + %9 + %6 + %5.
%13 = arith.addf %12, %11 : tensor<510xf32>
%14 = arith.addf %13, %10 : tensor<510xf32>
%15 = arith.addf %14, %9 : tensor<510xf32>
%16 = arith.addf %15, %6 : tensor<510xf32>
%17 = arith.addf %16, %5 : tensor<510xf32>
// Fill a fresh tensor with the scalar %4 and multiply the sum by it element-wise.
%18 = tensor.empty() : tensor<510xf32>
%19 = linalg.fill ins(%4 : f32) outs(%18 : tensor<510xf32>) -> tensor<510xf32>
%20 = arith.mulf %17, %19 : tensor<510xf32>
stencil.return %20 : tensor<510xf32>
}
// Write the apply result into %b over the bounds [0, 0] : [1, 1].
stencil.store %1 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
func.return
}
}


// CHECK: "builtin.module"() ({
// CHECK-NEXT: "func.func"() <{"sym_name" = "gauss_seidel_func", "function_type" = (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()}> ({
// CHECK-NEXT: ^0(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
// CHECK-NEXT: %0 = "stencil.load"(%a) : (!stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %pref = "csl_stencil.prefetch"(%0) <{"topo" = #dmp.topo<1022x510>, "size" = 510 : i64, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> memref<4xtensor<510xf32>>
// CHECK-NEXT: %1 = "stencil.apply"(%0, %pref) ({
// CHECK-NEXT: ^1(%2 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %3 : memref<4xtensor<510xf32>>):
// CHECK-NEXT: %4 = "arith.constant"() <{"value" = 1.666600e-01 : f32}> : () -> f32
// CHECK-NEXT: %5 = "csl_stencil.access"(%3) {"offset" = #stencil.index[1, 0], "offset_mapping" = #stencil.index[0, 1]} : (memref<4xtensor<510xf32>>) -> tensor<510xf32>
// CHECK-NEXT: %6 = "csl_stencil.access"(%3) {"offset" = #stencil.index[-1, 0], "offset_mapping" = #stencil.index[0, 1]} : (memref<4xtensor<510xf32>>) -> tensor<510xf32>
// CHECK-NEXT: %7 = "stencil.access"(%2) {"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
// CHECK-NEXT: %8 = "stencil.access"(%2) {"offset" = #stencil.index[0, 0], "offset_mapping" = #stencil.index[0, 1]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> tensor<512xf32>
// CHECK-NEXT: %9 = "tensor.extract_slice"(%7) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %10 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %11 = "csl_stencil.access"(%3) {"offset" = #stencil.index[0, 1], "offset_mapping" = #stencil.index[0, 1]} : (memref<4xtensor<510xf32>>) -> tensor<510xf32>
// CHECK-NEXT: %12 = "csl_stencil.access"(%3) {"offset" = #stencil.index[0, -1], "offset_mapping" = #stencil.index[0, 1]} : (memref<4xtensor<510xf32>>) -> tensor<510xf32>
// CHECK-NEXT: %13 = "arith.addf"(%12, %11) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %14 = "arith.addf"(%13, %10) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %15 = "arith.addf"(%14, %9) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %16 = "arith.addf"(%15, %6) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %17 = "arith.addf"(%16, %5) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %18 = "tensor.empty"() : () -> tensor<510xf32>
// CHECK-NEXT: %19 = "linalg.fill"(%4, %18) <{"operandSegmentSizes" = array<i32: 1, 1>}> : (f32, tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %20 = "arith.mulf"(%17, %19) <{"fastmath" = #arith.fastmath<none>}> : (tensor<510xf32>, tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: "stencil.return"(%20) : (tensor<510xf32>) -> ()
// CHECK-NEXT: }) : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, memref<4xtensor<510xf32>>) -> !stencil.temp<[0,1]x[0,1]xtensor<510xf32>>
// CHECK-NEXT: "stencil.store"(%1, %b) {"bounds" = #stencil.bounds[0, 0] : [1, 1]} : (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>, !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) -> ()
// CHECK-NEXT: "func.return"() : () -> ()
// CHECK-NEXT: }) : () -> ()
// CHECK-NEXT: }) : () -> ()
63 changes: 63 additions & 0 deletions tests/filecheck/dialects/csl/stencil-to-csl-stencil.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: xdsl-opt %s -p "stencil-to-csl-stencil" | filecheck %s

// Input for the "stencil-to-csl-stencil" pass: a stencil kernel preceded by a dmp.swap.
// The expected output further down shows the dmp.swap replaced by "csl_stencil.prefetch"
// and the four neighbour accesses replaced by csl_stencil.access on the prefetched buffer.
builtin.module {
// %a is the field that is loaded from, %b is the field that is stored to.
func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// Four exchanges, each of size [1, 1, 510], towards [1,0,0], [-1,0,0], [0,1,0] and [0,-1,0].
"dmp.swap"(%0) {"topo" = #dmp.topo<1022x510>, "swaps" = [#dmp.exchange<at [1, 0, 0] size [1, 1, 510] source offset [-1, 0, 0] to [1, 0, 0]>, #dmp.exchange<at [-1, 0, 0] size [1, 1, 510] source offset [1, 0, 0] to [-1, 0, 0]>, #dmp.exchange<at [0, 1, 0] size [1, 1, 510] source offset [0, -1, 0] to [0, 1, 0]>, #dmp.exchange<at [0, -1, 0] size [1, 1, 510] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> ()
%1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
// Scalar factor that the summed tensor is multiplied by at the end.
%3 = arith.constant 1.666600e-01 : f32
// Accesses at [1, 0] and [-1, 0], each sliced to 510 elements starting at offset 0.
%4 = stencil.access %2[1, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%5 = "tensor.extract_slice"(%4) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%6 = stencil.access %2[-1, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%7 = "tensor.extract_slice"(%6) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// Two accesses at [0, 0], sliced at static offsets 1 and -1 respectively.
// NOTE(review): the second slice uses a negative static offset (-1) - confirm this is intended.
%8 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%9 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%10 = stencil.access %2[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%11 = "tensor.extract_slice"(%10) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// Accesses at [0, 1] and [0, -1], each sliced to 510 elements starting at offset 0.
%12 = stencil.access %2[0, 1] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%13 = "tensor.extract_slice"(%12) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
%14 = stencil.access %2[0, -1] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
%15 = "tensor.extract_slice"(%14) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// Accumulate all six slices: %15 + %13 + %11 + %9 + %7 + %5.
%16 = arith.addf %15, %13 : tensor<510xf32>
%17 = arith.addf %16, %11 : tensor<510xf32>
%18 = arith.addf %17, %9 : tensor<510xf32>
%19 = arith.addf %18, %7 : tensor<510xf32>
%20 = arith.addf %19, %5 : tensor<510xf32>
// Fill a fresh tensor with the scalar %3 and multiply the sum by it element-wise.
%21 = tensor.empty() : tensor<510xf32>
%22 = linalg.fill ins(%3 : f32) outs(%21 : tensor<510xf32>) -> tensor<510xf32>
%23 = arith.mulf %20, %22 : tensor<510xf32>
stencil.return %23 : tensor<510xf32>
}
// Write the apply result into %b over the bounds [0, 0] : [1, 1].
stencil.store %1 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
func.return
}
}

// CHECK: builtin.module {
// CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %1 = "csl_stencil.prefetch"(%0) <{"size" = 510 : i64, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> : (!stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>) -> memref<4xtensor<510xf32>>
// CHECK-NEXT: %2 = stencil.apply(%3 = %0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %4 = %1 : memref<4xtensor<510xf32>>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) {
// CHECK-NEXT: %5 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %6 = csl_stencil.access %4[1, 0] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %7 = csl_stencil.access %4[-1, 0] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %8 = stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %9 = "tensor.extract_slice"(%8) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %10 = stencil.access %3[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %12 = csl_stencil.access %4[0, 1] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %13 = csl_stencil.access %4[0, -1] : memref<4xtensor<510xf32>>
// CHECK-NEXT: %14 = arith.addf %13, %12 : tensor<510xf32>
// CHECK-NEXT: %15 = arith.addf %14, %11 : tensor<510xf32>
// CHECK-NEXT: %16 = arith.addf %15, %9 : tensor<510xf32>
// CHECK-NEXT: %17 = arith.addf %16, %7 : tensor<510xf32>
// CHECK-NEXT: %18 = arith.addf %17, %6 : tensor<510xf32>
// CHECK-NEXT: %19 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %20 = linalg.fill ins(%5 : f32) outs(%19 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %21 = arith.mulf %18, %20 : tensor<510xf32>
// CHECK-NEXT: stencil.return %21 : tensor<510xf32>
// CHECK-NEXT: }
// CHECK-NEXT: stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
n-io marked this conversation as resolved.
Show resolved Hide resolved
6 changes: 6 additions & 0 deletions xdsl/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ def get_csl():

return CSL

def get_csl_stencil():
    """Return the ``CSL_STENCIL`` object from ``xdsl.dialects.csl_.csl_stencil``.

    The import happens inside the function, so the module is only loaded
    when this accessor is actually called.
    """
    from xdsl.dialects.csl_ import csl_stencil

    return csl_stencil.CSL_STENCIL

def get_dmp():
from xdsl.dialects.experimental.dmp import DMP

Expand Down Expand Up @@ -273,6 +278,7 @@ def get_x86():
"cmath": get_cmath,
"comb": get_comb,
"csl": get_csl,
"csl_stencil": get_csl_stencil,
"dmp": get_dmp,
"fir": get_fir,
"fsm": get_fsm,
Expand Down
Empty file added xdsl/dialects/csl_/__init__.py
n-io marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
Loading
Loading