Skip to content

Commit

Permalink
(transform): csl_wrapper program module init (#2891)
Browse files Browse the repository at this point in the history
This PR initialises the program_module of a CSL wrapper. Computes
additional known values to remove any dependency on `utils.csl`.

Minor fixes:
* Add utility function `get_param_value` to csl_wrapper.ModuleOp
* small fix to `get_apply` on csl_stencil.AccessOp to work for both
`csl.Apply` or `csl_stencil.Apply`

---------

Co-authored-by: n-io <n-io@users.noreply.github.com>
  • Loading branch information
n-io and n-io authored Jul 26, 2024
1 parent 7f09044 commit 912a579
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 55 deletions.
96 changes: 49 additions & 47 deletions tests/filecheck/transforms/csl-stencil-to-csl-wrapper.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,57 +35,59 @@ builtin.module {
}

// CHECK-NEXT: builtin.module {
// CHECK-NEXT: "csl_wrapper.module"() <{"width" = 1022 : i16, "height" = 510 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>], "program_name" = "gauss_seidel"}> ({
// CHECK-NEXT: ^0(%0 : i16, %1 : i16, %2 : i16, %3 : i16, %4 : i16, %5 : i16):
// CHECK-NEXT: %6 = arith.constant 0 : i16
// CHECK-NEXT: %7 = "csl.get_color"(%6) : (i16) -> !csl.color
// CHECK-NEXT: %8 = "csl_wrapper.import"(%2, %3, %7) <{"module" = "<memcpy/get_params>", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module
// CHECK-NEXT: %9 = "csl_wrapper.import"(%5, %2, %3) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module
// CHECK-NEXT: %10 = "csl.member_call"(%9, %0, %1, %2, %3, %5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct
// CHECK-NEXT: %11 = "csl.member_call"(%8, %0) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct
// CHECK-NEXT: %12 = arith.constant 1 : i16
// CHECK-NEXT: %13 = arith.subi %12, %5 : i16
// CHECK-NEXT: %14 = arith.subi %2, %0 : i16
// CHECK-NEXT: %15 = arith.subi %3, %1 : i16
// CHECK-NEXT: %16 = arith.cmpi slt, %0, %13 : i16
// CHECK-NEXT: %17 = arith.cmpi slt, %1, %13 : i16
// CHECK-NEXT: %18 = arith.cmpi slt, %14, %5 : i16
// CHECK-NEXT: %19 = arith.cmpi slt, %15, %5 : i16
// CHECK-NEXT: %20 = arith.ori %16, %17 : i1
// CHECK-NEXT: %21 = arith.ori %20, %18 : i1
// CHECK-NEXT: %22 = arith.ori %21, %19 : i1
// CHECK-NEXT: "csl_wrapper.yield"(%11, %10, %22) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> ()
// CHECK-NEXT: "csl_wrapper.module"() <{"width" = 1022 : i16, "height" = 510 : i16, "params" = [#csl_wrapper.param<"z_dim" default=512 : i16>, #csl_wrapper.param<"pattern" default=2 : i16>, #csl_wrapper.param<"num_chunks" default=2 : i16>, #csl_wrapper.param<"chunk_size" default=256 : i16>, #csl_wrapper.param<"padded_z_dim" default=512 : i16>], "program_name" = "gauss_seidel"}> ({
// CHECK-NEXT: ^0(%0 : i16, %1 : i16, %2 : i16, %3 : i16, %4 : i16, %5 : i16, %6 : i16, %7 : i16, %8 : i16):
// CHECK-NEXT: %9 = arith.constant 0 : i16
// CHECK-NEXT: %10 = "csl.get_color"(%9) : (i16) -> !csl.color
// CHECK-NEXT: %11 = "csl_wrapper.import"(%2, %3, %10) <{"module" = "<memcpy/get_params>", "fields" = ["width", "height", "LAUNCH"]}> : (i16, i16, !csl.color) -> !csl.imported_module
// CHECK-NEXT: %12 = "csl_wrapper.import"(%5, %2, %3) <{"module" = "routes.csl", "fields" = ["pattern", "peWidth", "peHeight"]}> : (i16, i16, i16) -> !csl.imported_module
// CHECK-NEXT: %13 = "csl.member_call"(%12, %0, %1, %2, %3, %5) <{"field" = "computeAllRoutes"}> : (!csl.imported_module, i16, i16, i16, i16, i16) -> !csl.comptime_struct
// CHECK-NEXT: %14 = "csl.member_call"(%11, %0) <{"field" = "get_params"}> : (!csl.imported_module, i16) -> !csl.comptime_struct
// CHECK-NEXT: %15 = arith.constant 1 : i16
// CHECK-NEXT: %16 = arith.subi %15, %5 : i16
// CHECK-NEXT: %17 = arith.subi %2, %0 : i16
// CHECK-NEXT: %18 = arith.subi %3, %1 : i16
// CHECK-NEXT: %19 = arith.cmpi slt, %0, %16 : i16
// CHECK-NEXT: %20 = arith.cmpi slt, %1, %16 : i16
// CHECK-NEXT: %21 = arith.cmpi slt, %17, %5 : i16
// CHECK-NEXT: %22 = arith.cmpi slt, %18, %5 : i16
// CHECK-NEXT: %23 = arith.ori %19, %20 : i1
// CHECK-NEXT: %24 = arith.ori %23, %21 : i1
// CHECK-NEXT: %25 = arith.ori %24, %22 : i1
// CHECK-NEXT: "csl_wrapper.yield"(%14, %13, %25) <{"fields" = ["memcpy_params", "stencil_comms_params", "isBorderRegionPE"]}> : (!csl.comptime_struct, !csl.comptime_struct, i1) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%23 : i16, %24 : i16, %25 : i16, %26 : i16, %memcpy_params : !csl.comptime_struct, %stencil_comms_params : !csl.comptime_struct, %isBorderRegionPE : i1, %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
// CHECK-NEXT: ^1(%26 : i16, %27 : i16, %28 : i16, %29 : i16, %30 : i16, %31 : i16, %32 : i16, %memcpy_params : !csl.comptime_struct, %stencil_comms_params : !csl.comptime_struct, %isBorderRegionPE : i1, %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>):
// CHECK-NEXT: %33 = "csl_wrapper.import"(%memcpy_params) <{"module" = "<memcpy/memcpy>", "fields" = [""]}> : (!csl.comptime_struct) -> !csl.imported_module
// CHECK-NEXT: %34 = "csl_wrapper.import"(%29, %31, %stencil_comms_params) <{"module" = "stencil_comms.csl", "fields" = ["pattern", "chunkSize", ""]}> : (i16, i16, !csl.comptime_struct) -> !csl.imported_module
// CHECK-NEXT: csl.func @gauss_seidel() {
// CHECK-NEXT: %27 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %28 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %29 = csl_stencil.apply(%27 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %28 : tensor<510xf32>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
// CHECK-NEXT: ^2(%30 : memref<4xtensor<255xf32>>, %31 : index, %32 : tensor<510xf32>):
// CHECK-NEXT: %33 = csl_stencil.access %30[1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %34 = csl_stencil.access %30[-1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %35 = csl_stencil.access %30[0, 1] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %36 = csl_stencil.access %30[0, -1] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %37 = arith.addf %36, %35 : tensor<255xf32>
// CHECK-NEXT: %38 = arith.addf %37, %34 : tensor<255xf32>
// CHECK-NEXT: %39 = arith.addf %38, %33 : tensor<255xf32>
// CHECK-NEXT: %40 = "tensor.insert_slice"(%39, %32, %31) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %40 : tensor<510xf32>
// CHECK-NEXT: %35 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %36 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %37 = csl_stencil.apply(%35 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %36 : tensor<510xf32>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
// CHECK-NEXT: ^2(%38 : memref<4xtensor<255xf32>>, %39 : index, %40 : tensor<510xf32>):
// CHECK-NEXT: %41 = csl_stencil.access %38[1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %42 = csl_stencil.access %38[-1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %43 = csl_stencil.access %38[0, 1] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %44 = csl_stencil.access %38[0, -1] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %45 = arith.addf %44, %43 : tensor<255xf32>
// CHECK-NEXT: %46 = arith.addf %45, %42 : tensor<255xf32>
// CHECK-NEXT: %47 = arith.addf %46, %41 : tensor<255xf32>
// CHECK-NEXT: %48 = "tensor.insert_slice"(%47, %40, %39) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %48 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^3(%41 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %42 : tensor<510xf32>):
// CHECK-NEXT: %43 = csl_stencil.access %41[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %44 = csl_stencil.access %41[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %45 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %46 = "tensor.extract_slice"(%43) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %47 = "tensor.extract_slice"(%44) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %48 = arith.addf %42, %47 : tensor<510xf32>
// CHECK-NEXT: %49 = arith.addf %48, %46 : tensor<510xf32>
// CHECK-NEXT: %50 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %51 = linalg.fill ins(%45 : f32) outs(%50 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %52 = arith.mulf %49, %51 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %52 : tensor<510xf32>
// CHECK-NEXT: ^3(%49 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %50 : tensor<510xf32>):
// CHECK-NEXT: %51 = csl_stencil.access %49[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %52 = csl_stencil.access %49[0, 0] : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %53 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: %54 = "tensor.extract_slice"(%51) <{"static_offsets" = array<i64: 1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %55 = "tensor.extract_slice"(%52) <{"static_offsets" = array<i64: -1>, "static_sizes" = array<i64: 510>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<512xf32>) -> tensor<510xf32>
// CHECK-NEXT: %56 = arith.addf %50, %55 : tensor<510xf32>
// CHECK-NEXT: %57 = arith.addf %56, %54 : tensor<510xf32>
// CHECK-NEXT: %58 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %59 = linalg.fill ins(%53 : f32) outs(%58 : tensor<510xf32>) -> tensor<510xf32>
// CHECK-NEXT: %60 = arith.mulf %57, %59 : tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %60 : tensor<510xf32>
// CHECK-NEXT: })
// CHECK-NEXT: stencil.store %29 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: stencil.store %37 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: "csl_wrapper.yield"() <{"fields" = []}> : () -> ()
Expand Down
9 changes: 6 additions & 3 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,18 +524,21 @@ def verify_(self) -> None:
f"apply, got {offset} >= {apply.get_rank()}"
)

def get_apply(self):
def get_apply(self) -> stencil.ApplyOp | ApplyOp:
"""
Simple helper to get the parent apply and raise otherwise.
"""
trait = cast(HasAncestor, self.get_trait(HasAncestor, (stencil.ApplyOp,)))
trait = cast(
HasAncestor, self.get_trait(HasAncestor, (stencil.ApplyOp, ApplyOp))
)
ancestor = trait.get_ancestor(self)
if ancestor is None:
raise ValueError(
"stencil.apply not found, this function should be called on"
"verified accesses only."
)
return cast(stencil.ApplyOp, ancestor)
assert isinstance(ancestor, stencil.ApplyOp | ApplyOp)
return ancestor


@irdl_op_definition
Expand Down
25 changes: 23 additions & 2 deletions xdsl/dialects/csl/csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ class ModuleOp(IRDLOperation):

name = "csl_wrapper.module"

width = prop_def(AnyIntegerAttr)
height = prop_def(AnyIntegerAttr)
width = prop_def(IntegerAttr[IntegerType])
height = prop_def(IntegerAttr[IntegerType])
program_name = opt_prop_def(StringAttr)
params: ArrayAttr[ParamAttribute] = prop_def(ArrayAttr[ParamAttribute])

Expand Down Expand Up @@ -323,6 +323,20 @@ def get_program_param(self, name: str) -> BlockArgument:
# not found = value error
raise ValueError(f"{name} does not refer to a block arg of this program_module")

def get_param_value(self, name: str) -> IntegerAttr[IntegerType]:
"""Retrieve the value of a named op param."""
if name == "width":
return self.width
elif name == "height":
return self.height
res = NoneAttr()
for param in self.params.data:
if name == param.key.data:
res = param.value
if isinstance(res, NoneAttr):
raise ValueError(f"Parameter name is unknown or has no value: {name}")
return res

@property
def layout_yield_op(self) -> YieldOp:
"""
Expand All @@ -339,6 +353,13 @@ def exported_symbols(self) -> Sequence[BlockArgument]:
2 + len(self.params) + len(self.layout_yield_op.fields) :
]

def get_program_import(self, name: str) -> ImportOp:
"""Get top-level import op in the program_module"""
for op in self.program_module.ops:
if isinstance(op, ImportOp) and op.module.data == name:
return op
raise ValueError(f"Cannot get program_module import of {name}")


@irdl_op_definition
class YieldOp(IRDLOperation):
Expand Down
35 changes: 32 additions & 3 deletions xdsl/transforms/csl_stencil_to_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from xdsl.dialects import arith, builtin, func, stencil
from xdsl.dialects.builtin import IntegerAttr, TensorType
from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper
from xdsl.ir import Attribute
from xdsl.ir import Attribute, Operation
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand Down Expand Up @@ -41,6 +41,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
height: int = 1
z_dim_no_ghost_cells: int = 1
z_dim: int = 1
num_chunks: int = 1
for apply_op in apply_ops:
# loop over accesses to get max_distance (from which we build `pattern`)
for ap in apply_op.get_accesses():
Expand Down Expand Up @@ -77,13 +78,22 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
if isa(field_t := arg.type, stencil.FieldType[TensorType[Attribute]]):
z_dim = max(z_dim, field_t.get_element_type().get_shape()[0])

num_chunks = max(num_chunks, apply_op.num_chunks.value.data)

# some computations we don't need to do in CSL
chunk_size: int = (z_dim // num_chunks) + (0 if z_dim % num_chunks == 0 else 1)
padded_z_dim: int = chunk_size * num_chunks

# initialise module op
module_op = csl_wrapper.ModuleOp(
width=IntegerAttr(width, 16),
height=IntegerAttr(height, 16),
params={
"z_dim": IntegerAttr(z_dim, 16),
"pattern": IntegerAttr(max_distance + 1, 16),
"num_chunks": IntegerAttr(num_chunks, 16),
"chunk_size": IntegerAttr(chunk_size, 16),
"padded_z_dim": IntegerAttr(padded_z_dim, 16),
},
)

Expand Down Expand Up @@ -111,8 +121,8 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter, /):
module_op.exported_symbols,
)

# add main and empty yield to program_module
module_op.program_module.block.add_ops([main_func, csl_wrapper.YieldOp([], [])])
# initialise program_module and add main func and empty yield op
self.initialise_program_module(module_op, add_ops=[main_func])

# replace (now empty) func by module wrapper
rewriter.replace_matched_op(module_op)
Expand Down Expand Up @@ -207,6 +217,25 @@ def initialise_layout_module(self, module_op: csl_wrapper.ModuleOp):
}
)

def initialise_program_module(
self, module_op: csl_wrapper.ModuleOp, add_ops: Sequence[Operation]
):
with ImplicitBuilder(module_op.program_module.block):
csl_wrapper.ImportOp(
"<memcpy/memcpy>",
field_name_mapping={"": module_op.get_program_param("memcpy_params")},
)
csl_wrapper.ImportOp(
"stencil_comms.csl",
field_name_mapping={
"pattern": module_op.get_program_param("pattern"),
"chunkSize": module_op.get_program_param("chunk_size"),
"": module_op.get_program_param("stencil_comms_params"),
},
)
module_op.program_module.block.add_ops(add_ops)
module_op.program_module.block.add_op(csl_wrapper.YieldOp([], []))


@dataclass(frozen=True)
class CslStencilToCslWrapperPass(ModulePass):
Expand Down

0 comments on commit 912a579

Please sign in to comment.