Skip to content

Commit

Permalink
transformations: (lower-csl-wrapper) Add params_as_consts flag (#3324)
Browse files Browse the repository at this point in the history
CSL module wrapper params are by default lowered to CSL params, which
can have a value specified in code that can be overridden manually by
providing command line inputs. If this is not desired, module wrapper
params can be lowered to constants instead by setting the
`params_as_consts` flag. For this to happen, the module wrapper params
must be numerical and have a specified default value.

Use this flag for programs if the generated program does not support
arbitrary param values.

---------

Co-authored-by: n-io <n-io@users.noreply.github.com>
  • Loading branch information
n-io and n-io authored Oct 25, 2024
1 parent a3e2e7c commit 3f81f3e
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 51 deletions.
162 changes: 134 additions & 28 deletions tests/filecheck/transforms/lower-csl-wrapper.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,112 @@
// RUN: xdsl-opt -p lower-csl-wrapper %s | filecheck --match-full-lines %s
// RUN: xdsl-opt -p lower-csl-wrapper{params_as_consts=true} %s | filecheck --match-full-lines %s --check-prefix=CONST

builtin.module {
// CHECK: builtin.module {
// CONST: builtin.module {

"csl_wrapper.module"() <{
"width" = 256 : i16,
"height" = 128 : i16,
"params" = [
#csl_wrapper.param<"param_with_value" default=512 : i16>,
#csl_wrapper.param<"param_without_value" : i16>
],
"program_name" = "params_as_consts_func"
}> ({
^0(%xDim : i16, %yDim : i16, %width : i16, %height : i16, %param_with_value : i16, %param_without_value : i16):
%memparams = "test.op"() : () -> !csl.comptime_struct
"csl_wrapper.yield"(%memparams) <{"fields" = ["memcpy_params"]}> : (!csl.comptime_struct) -> ()
}, {
^1(%width : i16, %height : i16, %param_with_value : i16, %param_without_value : i16, %memcpy_params : !csl.comptime_struct):
%memcpyMod = "csl_wrapper.import"(%memcpy_params) <{"module" = "<memcpy/memcpy>", "fields" = [""]}> : (!csl.comptime_struct) -> !csl.imported_module
"csl.export"() <{"var_name" = @params_as_consts_func, "type" = () -> ()}> : () -> ()
csl.func @params_as_consts_func() {
"test.op"() : () -> ()
csl.return
}
"csl_wrapper.yield"() <{"fields" = []}> : () -> ()
}) : () -> ()

// CHECK: "csl.module"() <{"kind" = #csl<module_kind layout>}> ({
// CHECK-NEXT: %0 = arith.constant 0 : i16
// CHECK-NEXT: %1 = arith.constant 1 : i16
// CHECK-NEXT: %2 = arith.constant 256 : i16
// CHECK-NEXT: %3 = arith.constant 128 : i16
// CHECK-NEXT: %width = "csl.param"(%2) <{"param_name" = "width"}> : (i16) -> i16
// CHECK-NEXT: %height = "csl.param"(%3) <{"param_name" = "height"}> : (i16) -> i16
// CHECK-NEXT: %4 = arith.constant 512 : i16
// CHECK-NEXT: %param_with_value = "csl.param"(%4) <{"param_name" = "param_with_value"}> : (i16) -> i16
// CHECK-NEXT: %param_without_value = "csl.param"() <{"param_name" = "param_without_value"}> : () -> i16
// CHECK-NEXT: csl.layout {
// CHECK-NEXT: "csl.set_rectangle"(%width, %height) : (i16, i16) -> ()
// CHECK-NEXT: scf.for %xDim = %0 to %width step %1 : i16 {
// CHECK-NEXT: scf.for %yDim = %0 to %height step %1 : i16 {
// CHECK-NEXT: %memparams = "test.op"() : () -> !csl.comptime_struct
// CHECK-NEXT: %5 = "csl.const_struct"(%width, %height, %memparams) <{"ssa_fields" = ["width", "height", "memcpy_params"]}> : (i16, i16, !csl.comptime_struct) -> !csl.comptime_struct
// CHECK-NEXT: "csl.set_tile_code"(%xDim, %yDim, %5) <{"file" = "params_as_consts_func.csl"}> : (i16, i16, !csl.comptime_struct) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }) {"sym_name" = "params_as_consts_func_layout"} : () -> ()
// CHECK-NEXT: "csl.module"() <{"kind" = #csl<module_kind program>}> ({
// CHECK-NEXT: %memcpy_params = "csl.param"() <{"param_name" = "memcpy_params"}> : () -> !csl.comptime_struct
// CHECK-NEXT: %memcpyMod = "csl.const_struct"() <{"ssa_fields" = []}> : () -> !csl.comptime_struct
// CHECK-NEXT: %memcpyMod_1 = "csl.concat_structs"(%memcpyMod, %memcpy_params) : (!csl.comptime_struct, !csl.comptime_struct) -> !csl.comptime_struct
// CHECK-NEXT: %memcpyMod_2 = "csl.import_module"(%memcpyMod_1) <{"module" = "<memcpy/memcpy>"}> : (!csl.comptime_struct) -> !csl.imported_module
// CHECK-NEXT: %0 = arith.constant 256 : i16
// CHECK-NEXT: %1 = arith.constant 128 : i16
// CHECK-NEXT: %width = "csl.param"(%0) <{"param_name" = "width"}> : (i16) -> i16
// CHECK-NEXT: %height = "csl.param"(%1) <{"param_name" = "height"}> : (i16) -> i16
// CHECK-NEXT: %2 = arith.constant 512 : i16
// CHECK-NEXT: %param_with_value = "csl.param"(%2) <{"param_name" = "param_with_value"}> : (i16) -> i16
// CHECK-NEXT: %param_without_value = "csl.param"() <{"param_name" = "param_without_value"}> : () -> i16
// CHECK-NEXT: "csl.export"() <{"var_name" = @params_as_consts_func, "type" = () -> ()}> : () -> ()
// CHECK-NEXT: csl.func @params_as_consts_func() {
// CHECK-NEXT: "test.op"() : () -> ()
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: %3 = "csl.member_access"(%memcpyMod_2) <{"field" = "LAUNCH"}> : (!csl.imported_module) -> !csl.color
// CHECK-NEXT: "csl.rpc"(%3) : (!csl.color) -> ()
// CHECK-NEXT: }) {"sym_name" = "params_as_consts_func_program"} : () -> ()

// CONST: "csl.module"() <{"kind" = #csl<module_kind layout>}> ({
// CONST-NEXT: %0 = arith.constant 0 : i16
// CONST-NEXT: %1 = arith.constant 1 : i16
// CONST-NEXT: %width = arith.constant 256 : i16
// CONST-NEXT: %height = arith.constant 128 : i16
// CONST-NEXT: %param_with_value = arith.constant 512 : i16
// CONST-NEXT: %param_without_value = "csl.param"() <{"param_name" = "param_without_value"}> : () -> i16
// CONST-NEXT: csl.layout {
// CONST-NEXT: "csl.set_rectangle"(%width, %height) : (i16, i16) -> ()
// CONST-NEXT: scf.for %xDim = %0 to %width step %1 : i16 {
// CONST-NEXT: scf.for %yDim = %0 to %height step %1 : i16 {
// CONST-NEXT: %memparams = "test.op"() : () -> !csl.comptime_struct
// CONST-NEXT: %2 = "csl.const_struct"(%width, %height, %memparams) <{"ssa_fields" = ["width", "height", "memcpy_params"]}> : (i16, i16, !csl.comptime_struct) -> !csl.comptime_struct
// CONST-NEXT: "csl.set_tile_code"(%xDim, %yDim, %2) <{"file" = "params_as_consts_func.csl"}> : (i16, i16, !csl.comptime_struct) -> ()
// CONST-NEXT: }
// CONST-NEXT: }
// CONST-NEXT: }
// CONST-NEXT: }) {"sym_name" = "params_as_consts_func_layout"} : () -> ()
// CONST-NEXT: "csl.module"() <{"kind" = #csl<module_kind program>}> ({
// CONST-NEXT: %memcpy_params = "csl.param"() <{"param_name" = "memcpy_params"}> : () -> !csl.comptime_struct
// CONST-NEXT: %memcpyMod = "csl.const_struct"() <{"ssa_fields" = []}> : () -> !csl.comptime_struct
// CONST-NEXT: %memcpyMod_1 = "csl.concat_structs"(%memcpyMod, %memcpy_params) : (!csl.comptime_struct, !csl.comptime_struct) -> !csl.comptime_struct
// CONST-NEXT: %memcpyMod_2 = "csl.import_module"(%memcpyMod_1) <{"module" = "<memcpy/memcpy>"}> : (!csl.comptime_struct) -> !csl.imported_module
// CONST-NEXT: %width = arith.constant 256 : i16
// CONST-NEXT: %height = arith.constant 128 : i16
// CONST-NEXT: %param_with_value = arith.constant 512 : i16
// CONST-NEXT: %param_without_value = "csl.param"() <{"param_name" = "param_without_value"}> : () -> i16
// CONST-NEXT: "csl.export"() <{"var_name" = @params_as_consts_func, "type" = () -> ()}> : () -> ()
// CONST-NEXT: csl.func @params_as_consts_func() {
// CONST-NEXT: "test.op"() : () -> ()
// CONST-NEXT: csl.return
// CONST-NEXT: }
// CONST-NEXT: %0 = "csl.member_access"(%memcpyMod_2) <{"field" = "LAUNCH"}> : (!csl.imported_module) -> !csl.color
// CONST-NEXT: "csl.rpc"(%0) : (!csl.color) -> ()
// CONST-NEXT: }) {"sym_name" = "params_as_consts_func_program"} : () -> ()


"csl_wrapper.module"() <{
"width" = 1022 : i16,
"height" = 510 : i16,
Expand Down Expand Up @@ -70,10 +176,7 @@ builtin.module {
}
"csl_wrapper.yield"() <{"fields" = []}> : () -> ()
}) : () -> ()
}


// CHECK: builtin.module {
// CHECK-NEXT: "csl.module"() <{"kind" = #csl<module_kind layout>}> ({
// CHECK-NEXT: %0 = arith.constant 2 : i16
// CHECK-NEXT: %pattern = "csl.param"(%0) <{"param_name" = "pattern"}> : (i16) -> i16
Expand Down Expand Up @@ -133,14 +236,16 @@ builtin.module {
// CHECK-NEXT: %memcpyMod = "csl.const_struct"() <{"ssa_fields" = []}> : () -> !csl.comptime_struct
// CHECK-NEXT: %memcpyMod_1 = "csl.concat_structs"(%memcpyMod, %memcpy_params) : (!csl.comptime_struct, !csl.comptime_struct) -> !csl.comptime_struct
// CHECK-NEXT: %memcpyMod_2 = "csl.import_module"(%memcpyMod_1) <{"module" = "<memcpy/memcpy>"}> : (!csl.comptime_struct) -> !csl.imported_module
// CHECK-NEXT: %width = "csl.param"() <{"param_name" = "width"}> : () -> i16
// CHECK-NEXT: %height = "csl.param"() <{"param_name" = "height"}> : () -> i16
// CHECK-NEXT: %2 = arith.constant 512 : i16
// CHECK-NEXT: %zDim = "csl.param"(%2) <{"param_name" = "z_dim"}> : (i16) -> i16
// CHECK-NEXT: %3 = arith.constant 2 : i16
// CHECK-NEXT: %num_chunks = "csl.param"(%3) <{"param_name" = "num_chunks"}> : (i16) -> i16
// CHECK-NEXT: %4 = arith.constant 510 : i16
// CHECK-NEXT: %padded_z_dim = "csl.param"(%4) <{"param_name" = "padded_z_dim"}> : (i16) -> i16
// CHECK-NEXT: %2 = arith.constant 1022 : i16
// CHECK-NEXT: %3 = arith.constant 510 : i16
// CHECK-NEXT: %width = "csl.param"(%2) <{"param_name" = "width"}> : (i16) -> i16
// CHECK-NEXT: %height = "csl.param"(%3) <{"param_name" = "height"}> : (i16) -> i16
// CHECK-NEXT: %4 = arith.constant 512 : i16
// CHECK-NEXT: %zDim = "csl.param"(%4) <{"param_name" = "z_dim"}> : (i16) -> i16
// CHECK-NEXT: %5 = arith.constant 2 : i16
// CHECK-NEXT: %num_chunks = "csl.param"(%5) <{"param_name" = "num_chunks"}> : (i16) -> i16
// CHECK-NEXT: %6 = arith.constant 510 : i16
// CHECK-NEXT: %padded_z_dim = "csl.param"(%6) <{"param_name" = "padded_z_dim"}> : (i16) -> i16
// CHECK-NEXT: %isBorderRegionPE = "csl.param"() <{"param_name" = "isBorderRegionPE"}> : () -> i1
// CHECK-NEXT: %inputArr = memref.alloc() : memref<512xf32>
// CHECK-NEXT: %outputArr = memref.alloc() : memref<512xf32>
Expand All @@ -153,29 +258,30 @@ builtin.module {
// CHECK-NEXT: %scratchBuffer = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
// CHECK-NEXT: csl_stencil.apply(%inputArr : memref<512xf32>, %scratchBuffer : memref<510xf32>) outs (%outputArr : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>}> ({
// CHECK-NEXT: ^0(%arg2 : memref<4x255xf32>, %arg3 : index, %arg4 : memref<510xf32>):
// CHECK-NEXT: %5 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32>
// CHECK-NEXT: %6 = csl_stencil.access %arg2[-1, 0] : memref<4x255xf32>
// CHECK-NEXT: %7 = csl_stencil.access %arg2[0, 1] : memref<4x255xf32>
// CHECK-NEXT: %8 = csl_stencil.access %arg2[0, -1] : memref<4x255xf32>
// CHECK-NEXT: %9 = memref.subview %arg4[%arg3] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
// CHECK-NEXT: "csl.fadds"(%9, %8, %7) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%9, %9, %6) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%9, %9, %5) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
// CHECK-NEXT: %7 = csl_stencil.access %arg2[1, 0] : memref<4x255xf32>
// CHECK-NEXT: %8 = csl_stencil.access %arg2[-1, 0] : memref<4x255xf32>
// CHECK-NEXT: %9 = csl_stencil.access %arg2[0, 1] : memref<4x255xf32>
// CHECK-NEXT: %10 = csl_stencil.access %arg2[0, -1] : memref<4x255xf32>
// CHECK-NEXT: %11 = memref.subview %arg4[%arg3] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
// CHECK-NEXT: "csl.fadds"(%11, %10, %9) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>, memref<255xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%11, %11, %8) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
// CHECK-NEXT: "csl.fadds"(%11, %11, %7) : (memref<255xf32, strided<[1], offset: ?>>, memref<255xf32, strided<[1], offset: ?>>, memref<255xf32>) -> ()
// CHECK-NEXT: csl_stencil.yield %arg4 : memref<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%arg2_1 : memref<512xf32>, %arg3_1 : memref<510xf32>):
// CHECK-NEXT: %10 = memref.subview %arg2_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>>
// CHECK-NEXT: %11 = memref.subview %arg2_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>>
// CHECK-NEXT: "csl.fadds"(%arg3_1, %arg3_1, %11) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> ()
// CHECK-NEXT: "csl.fadds"(%arg3_1, %arg3_1, %10) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> ()
// CHECK-NEXT: %12 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: "csl.fmuls"(%arg3_1, %arg3_1, %12) : (memref<510xf32>, memref<510xf32>, f32) -> ()
// CHECK-NEXT: %12 = memref.subview %arg2_1[2] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1], offset: 2>>
// CHECK-NEXT: %13 = memref.subview %arg2_1[0] [510] [1] : memref<512xf32> to memref<510xf32, strided<[1]>>
// CHECK-NEXT: "csl.fadds"(%arg3_1, %arg3_1, %13) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1]>>) -> ()
// CHECK-NEXT: "csl.fadds"(%arg3_1, %arg3_1, %12) : (memref<510xf32>, memref<510xf32>, memref<510xf32, strided<[1], offset: 2>>) -> ()
// CHECK-NEXT: %14 = arith.constant 1.666600e-01 : f32
// CHECK-NEXT: "csl.fmuls"(%arg3_1, %arg3_1, %14) : (memref<510xf32>, memref<510xf32>, f32) -> ()
// CHECK-NEXT: csl_stencil.yield %arg3_1 : memref<510xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: %5 = "csl.member_access"(%memcpyMod_2) <{"field" = "LAUNCH"}> : (!csl.imported_module) -> !csl.color
// CHECK-NEXT: "csl.rpc"(%5) : (!csl.color) -> ()
// CHECK-NEXT: %7 = "csl.member_access"(%memcpyMod_2) <{"field" = "LAUNCH"}> : (!csl.imported_module) -> !csl.color
// CHECK-NEXT: "csl.rpc"(%7) : (!csl.color) -> ()
// CHECK-NEXT: }) {"sym_name" = "gauss_seidel_func_program"} : () -> ()

}
// CHECK-NEXT: }
// CHECK-EMPTY:
65 changes: 42 additions & 23 deletions xdsl/transforms/lower_csl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,36 +26,51 @@

@dataclass(frozen=True)
class ExtractCslModules(RewritePattern):
params_as_consts: bool

@op_type_rewrite_pattern
def match_and_rewrite(self, op: csl_wrapper.ModuleOp, rewriter: PatternRewriter, /):
program_module = self.lower_program_module(op, rewriter)
layout_module = self.lower_layout_module(op, rewriter)
rewriter.replace_matched_op([layout_module, program_module])

@staticmethod
def _collect_params(op: csl_wrapper.ModuleOp) -> list[SSAValue]:
def _collect_params(
self, op: csl_wrapper.ModuleOp
) -> tuple[SSAValue, SSAValue, list[SSAValue]]:
"""
Creates a list of `csl.param`s which should replace the block arguments in the
layout and program regions of the wrapper.
To be called in an `ImplicitBuilder`
Returns width, height, and a list of other params as SSAValues.
Params can alternatively be lowered to constants via the `params_as_consts` flag.
"""
width = arith.Constant(op.width).result
height = arith.Constant(op.height).result
if not self.params_as_consts:
width = csl.ParamOp("width", op.width.type, width).res
height = csl.ParamOp("height", op.height.type, height).res

params = list[SSAValue]()
for param in op.params:
if isattr(param.value, builtin.AnyIntegerAttrConstr):
value = arith.Constant(param.value)
else:
value = None
p = csl.ParamOp(param.key.data, param.type, value)
params.append(p.res)
return params
if value and self.params_as_consts:
params.append(value.result)
else:
p = csl.ParamOp(param.key.data, param.type, value)
params.append(p.res)
return width, height, params

def add_tile_code(
self,
x: SSAValue,
y: SSAValue,
width: csl.ParamOp,
height: csl.ParamOp,
width: SSAValue,
height: SSAValue,
yield_op: csl_wrapper.YieldOp,
prog_name: str,
) -> tuple[csl.ConstStructOp, csl.SetTileCodeOp]:
Expand Down Expand Up @@ -115,13 +130,7 @@ def lower_layout_module(
const_0 = arith.Constant.from_int_and_width(0, builtin.IntegerType(16))
const_1 = arith.Constant.from_int_and_width(1, builtin.IntegerType(16))

const_width = arith.Constant(op.width)
param_width = csl.ParamOp("width", op.width.type, const_width)

const_height = arith.Constant(op.height)
param_height = csl.ParamOp("height", op.height.type, const_height)

params_from_block_args = self._collect_params(op)
param_width, param_height, params_from_block_args = self._collect_params(op)

layout = csl.LayoutOp(Region())
with ImplicitBuilder(layout.body.block):
Expand Down Expand Up @@ -149,8 +158,8 @@ def lower_layout_module(
arg_values=[
SSAValue.get(x),
SSAValue.get(y),
param_width.res,
param_height.res,
param_width,
param_height,
*params_from_block_args,
],
)
Expand Down Expand Up @@ -200,10 +209,7 @@ def lower_program_module(
prog_name = op.program_name.data if op.program_name else __DEFAULT_PROG_NAME
module_block = Block()
with ImplicitBuilder(module_block):
param_width = csl.ParamOp("width", op.width.type)
param_height = csl.ParamOp("height", op.height.type)

params_from_block_args = self._collect_params(op)
param_width, param_height, params_from_block_args = self._collect_params(op)

assert isa(yield_op := op.layout_module.block.last_op, csl_wrapper.YieldOp)
yield_args = self._collect_yield_args(yield_op)
Expand All @@ -215,8 +221,8 @@ def lower_program_module(
op.program_module.block,
InsertPoint.at_end(module_block),
arg_values=[
param_width.res,
param_height.res,
param_width,
param_height,
*params_from_block_args,
*(y.res for y in yield_args),
],
Expand Down Expand Up @@ -316,8 +322,21 @@ class LowerCslWrapperPass(ModulePass):

name = "lower-csl-wrapper"

params_as_consts: bool = False
"""
Set to lower numerical module wrapper params that have a default value to constants,
instead of lowering to csl params. Set flag to disallow command line overrides.
Module wrapper params without a default value will always be lowered to csl params
(hint: consider removing default values in cases where this is desired).
"""

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier([ExtractCslModules(), LowerImport()]),
GreedyRewritePatternApplier(
[
ExtractCslModules(params_as_consts=self.params_as_consts),
LowerImport(),
]
),
apply_recursively=False,
).rewrite_module(op)

0 comments on commit 3f81f3e

Please sign in to comment.