Skip to content

Commit

Permalink
transformations: (convert-stencil-to-csl-stencil) Support multiple in…
Browse files Browse the repository at this point in the history
…put buffers (#3575)

Co-authored-by: n-io <n-io@users.noreply.github.com>
  • Loading branch information
n-io and n-io authored Dec 6, 2024
1 parent b5acf31 commit 8e58c88
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 48 deletions.
45 changes: 43 additions & 2 deletions tests/filecheck/transforms/convert-stencil-to-csl-stencil.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ builtin.module {

// CHECK-NEXT: func.func @coefficients(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>]}> ({
// CHECK-NEXT: %1 = csl_stencil.apply(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %0 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 0>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 2.345678e-01 : f32>, #csl_stencil.coeff<#stencil.index<[0, -1]>, 3.141500e-01 : f32>]}> ({
// CHECK-NEXT: ^0(%2 : tensor<4x255xf32>, %3 : index, %4 : tensor<510xf32>):
// CHECK-NEXT: %5 = arith.constant dense<1.234500e-01> : tensor<510xf32>
// CHECK-NEXT: %6 = arith.constant dense<2.345678e-01> : tensor<510xf32>
Expand Down Expand Up @@ -191,7 +191,7 @@ builtin.module {
// CHECK-NEXT: %2 = arith.constant 1 : index
// CHECK-NEXT: %3, %4 = scf.for %arg2 = %1 to %0 step %2 iter_args(%arg3 = %arg0, %arg4 = %arg1) -> (!stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) {
// CHECK-NEXT: %5 = tensor.empty() : tensor<600xf32>
// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [2, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [-2, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, 2]>, #csl_stencil.exchange<to [0, -1]>, #csl_stencil.exchange<to [0, -2]>], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>]}> ({
// CHECK-NEXT: csl_stencil.apply(%arg3 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>, %5 : tensor<600xf32>) outs (%arg4 : !stencil.field<[-2,3]x[-2,3]xtensor<604xf32>>) <{"swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [2, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [-2, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, 2]>, #csl_stencil.exchange<to [0, -1]>, #csl_stencil.exchange<to [0, -2]>], "topo" = #dmp.topo<600x600>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "coeffs" = [#csl_stencil.coeff<#stencil.index<[-1, 0]>, 1.196003e+05 : f32>, #csl_stencil.coeff<#stencil.index<[1, 0]>, 1.196003e+05 : f32>]}> ({
// CHECK-NEXT: ^0(%6 : tensor<8x300xf32>, %7 : index, %8 : tensor<600xf32>):
// CHECK-NEXT: %9 = arith.constant dense<1.287158e+09> : tensor<600xf32>
// CHECK-NEXT: %10 = arith.constant dense<1.196003e+05> : tensor<600xf32>
Expand All @@ -214,6 +214,47 @@ builtin.module {
// CHECK-NEXT: func.return
// CHECK-NEXT: }

// Test input: a stencil.apply that reads two different input fields (%arg0 and %arg1),
// each of which is communicated by its own dmp.swap, and writes its result to %arg4.
func.func @uvbke(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) {
// Swap on %arg0: a single exchange towards [0, -1, 0].
"dmp.swap"(%arg0) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<64x64>, false>, "swaps" = [#dmp.exchange<at [0, -1, 0] size [1, 1, 64] source offset [0, 1, 0] to [0, -1, 0]>]} : (!stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) -> ()
// Swap on %arg1: a single exchange towards [-1, 0, 0].
"dmp.swap"(%arg1) {"strategy" = #dmp.grid_slice_2d<#dmp.topo<64x64>, false>, "swaps" = [#dmp.exchange<at [-1, 0, 0] size [1, 1, 64] source offset [1, 0, 0] to [-1, 0, 0]>]} : (!stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) -> ()
// Inside the apply, %arg6 aliases %arg0 and %arg7 aliases %arg1.
stencil.apply(%arg6 = %arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg7 = %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) outs (%arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) {
// Neighbour access [-1, 0] and own-cell access [0, 0] on %arg7 (i.e. %arg1).
// Each access is followed by a full-size extract_slice (offset 0, size 64, stride 1).
%0 = stencil.access %arg7[-1, 0] : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>
%1 = "tensor.extract_slice"(%0) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 64>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<64xf32>) -> tensor<64xf32>
%2 = stencil.access %arg7[0, 0] : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>
%3 = "tensor.extract_slice"(%2) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 64>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<64xf32>) -> tensor<64xf32>
// Neighbour access [0, -1] and own-cell access [0, 0] on %arg6 (i.e. %arg0).
%4 = stencil.access %arg6[0, -1] : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>
%5 = "tensor.extract_slice"(%4) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 64>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<64xf32>) -> tensor<64xf32>
%6 = stencil.access %arg6[0, 0] : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>
%7 = "tensor.extract_slice"(%6) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 64>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 0, 0, 0>}> : (tensor<64xf32>) -> tensor<64xf32>
// Sum the four extracted slices and return the result.
%8 = arith.addf %1, %3 : tensor<64xf32>
%9 = arith.addf %8, %5 : tensor<64xf32>
%10 = arith.addf %9, %7 : tensor<64xf32>
stencil.return %10 : tensor<64xf32>
} to <[0, 0], [1, 1]>
func.return
}

// CHECK-NEXT: func.func @uvbke(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) {
// CHECK-NEXT: %0 = "csl_stencil.prefetch"(%arg1) <{"topo" = #dmp.topo<64x64>, "swaps" = [#csl_stencil.exchange<to [-1, 0]>]}> : (!stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) -> tensor<1x64xf32>
// CHECK-NEXT: %1 = tensor.empty() : tensor<64xf32>
// CHECK-NEXT: csl_stencil.apply(%arg0 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %1 : tensor<64xf32>, %arg1 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %0 : tensor<1x64xf32>) outs (%arg4 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>) <{"swaps" = [#csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<64x64>, "num_chunks" = 2 : i64, "bounds" = #stencil.bounds<[0, 0], [1, 1]>, "operandSegmentSizes" = array<i32: 1, 1, 0, 2, 1>}> ({
// CHECK-NEXT: ^0(%2 : tensor<1x32xf32>, %3 : index, %4 : tensor<64xf32>):
// CHECK-NEXT: %5 = csl_stencil.access %2[0, -1] : tensor<1x32xf32>
// CHECK-NEXT: %6 = "tensor.insert_slice"(%5, %4, %3) <{"static_offsets" = array<i64: -9223372036854775808>, "static_sizes" = array<i64: 32>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<32xf32>, tensor<64xf32>, index) -> tensor<64xf32>
// CHECK-NEXT: csl_stencil.yield %6 : tensor<64xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%7 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %8 : tensor<64xf32>, %9 : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>, %10 : tensor<1x64xf32>):
// CHECK-NEXT: %11 = csl_stencil.access %10[-1, 0] : tensor<1x64xf32>
// CHECK-NEXT: %12 = csl_stencil.access %9[0, 0] : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>
// CHECK-NEXT: %13 = csl_stencil.access %7[0, 0] : !stencil.field<[-1,1]x[-1,1]xtensor<64xf32>>
// CHECK-NEXT: %14 = arith.addf %8, %11 : tensor<64xf32>
// CHECK-NEXT: %15 = arith.addf %14, %12 : tensor<64xf32>
// CHECK-NEXT: %16 = arith.addf %15, %13 : tensor<64xf32>
// CHECK-NEXT: csl_stencil.yield %16 : tensor<64xf32>
// CHECK-NEXT: }) to <[0, 0], [1, 1]>
// CHECK-NEXT: func.return
// CHECK-NEXT: }


}
// CHECK-NEXT: }
88 changes: 42 additions & 46 deletions xdsl/transforms/convert_stencil_to_csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
Region,
SSAValue,
)
from xdsl.irdl import Operand, base
from xdsl.irdl import Operand
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand Down Expand Up @@ -82,7 +82,7 @@ def get_prefetch_overhead(o: OpResult):
return

# select the prefetch with the biggest communication overhead to be fused with matched stencil.apply
prefetch = max(candidate_prefetches)[1]
prefetch = max(candidate_prefetches, key=lambda x: x[0])[1]
return op.operands.index(prefetch)


Expand All @@ -99,7 +99,7 @@ def _get_apply_op(op: Operation) -> stencil.ApplyOp | None:


@dataclass(frozen=True)
class ConvertAccessOpFromPrefetchPattern(RewritePattern):
class ConvertAccessOpPattern(RewritePattern):
"""
Replaces stencil.access with csl_stencil.access, which operates on prefetched accesses.
Expand All @@ -109,51 +109,40 @@ class ConvertAccessOpFromPrefetchPattern(RewritePattern):
Note: This is intended to be called in a nested pattern rewriter, such that the above precondition is met.
"""

arg_index: int

@op_type_rewrite_pattern
def match_and_rewrite(self, op: stencil.AccessOp, rewriter: PatternRewriter, /):
assert len(op.offset) == 2
# translate access to own data or non-prefetch data, which operates on stencil.TempType
if op.temp != op.get_apply().region.block.args[self.arg_index] or tuple(
op.offset
) == (0, 0):
assert isattr(op.res.type, base(AnyTensorType))
rewriter.replace_matched_op(
csl_stencil.AccessOp(
op=op.temp,
offset=op.offset,
offset_mapping=op.offset_mapping,
result_type=op.res.type,
)
if isa(op.temp.type, AnyTensorType):
res_type = TensorType(
op.temp.type.get_element_type(), op.temp.type.get_shape()[1:]
)
else:
assert isa(op.res.type, AnyTensorType)
res_type = op.res.type
rewriter.replace_matched_op(
new_access_op := csl_stencil.AccessOp(
op=op.temp,
offset=op.offset,
offset_mapping=op.offset_mapping,
result_type=res_type,
)
return

prefetched_arg = op.get_apply().region.block.args[-1]
assert isa(t_type := prefetched_arg.type, TensorType[Attribute])

csl_access_op = csl_stencil.AccessOp(
op=prefetched_arg,
offset=op.offset,
offset_mapping=op.offset_mapping,
result_type=TensorType(t_type.get_element_type(), t_type.get_shape()[1:]),
)

# The stencil-tensorize-z-dimension pass inserts tensor.ExtractSliceOps after stencil.access to remove ghost cells.
# Since ghost cells are not prefetched, these ops can be removed again. Check if the ExtractSliceOp
# has no other effect and if so, remove both.
if (
len(op.res.uses) == 1
and isinstance(use := list(op.res.uses)[0].operation, tensor.ExtractSliceOp)
and use.static_sizes.get_values() == t_type.get_shape()[1:]
len(new_access_op.result.uses) == 1
and isinstance(
use := list(new_access_op.result.uses)[0].operation,
tensor.ExtractSliceOp,
)
and use.static_sizes.get_values() == res_type.get_shape()
and len(use.offsets) == 0
and len(use.sizes) == 0
and len(use.strides) == 0
):
rewriter.replace_op(use, csl_access_op)
rewriter.erase_op(op)
else:
rewriter.replace_matched_op(csl_access_op)
rewriter.replace_op(use, [], new_results=[new_access_op.result])


@dataclass(frozen=True)
Expand Down Expand Up @@ -228,11 +217,17 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):

# arg_idx points to the stencil.temp type whose data is prefetched in a separate buffer
arg_idx = apply_op.args.index(op.input_stencil)
field_block_arg = apply_op.region.block.args[arg_idx]

# add the prefetched buffer as the last block arg of the stencil.apply region
apply_op.region.block.insert_arg(
prefetch_block_arg = apply_op.region.block.insert_arg(
prefetch_op.result.type, len(apply_op.args)
)
field_block_arg.replace_by_if(
prefetch_block_arg,
lambda use: isinstance(use.operation, stencil.AccessOp)
and tuple(use.operation.offset) != (0, 0),
)

# rebuild stencil.apply op
r_types = apply_op.result_types
Expand All @@ -246,14 +241,6 @@ def match_and_rewrite(self, op: dmp.SwapOp, rewriter: PatternRewriter, /):
)
rewriter.replace_op(apply_op, new_apply_op)

# replace stencil.access (operating on stencil.temp at arg_index)
# with csl_stencil.access (operating on memref at last arg index)
nested_rewriter = PatternRewriteWalker(
ConvertAccessOpFromPrefetchPattern(arg_idx), listener=rewriter
)

nested_rewriter.rewrite_region(new_apply_op.region)


def split_ops(
ops: Sequence[Operation], buf: BlockArgument
Expand Down Expand Up @@ -506,7 +493,6 @@ def match_and_rewrite(self, op: stencil.ApplyOp, rewriter: PatternRewriter, /):
# add operations from list to receive_chunk, use translation table to rebuild operands
for o in chunk_region_ops:
if isinstance(o, stencil.ReturnOp | csl_stencil.YieldOp):
rewriter.erase_op(o)
break
o.operands = [chunk_region_oprnd_table.get(x, x) for x in o.operands]
rewriter.insert_op(o, InsertPoint.at_end(receive_chunk.block))
Expand Down Expand Up @@ -604,13 +590,23 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
GreedyRewritePatternApplier(
[
ConvertSwapToPrefetchPattern(),
ConvertApplyOpPattern(num_chunks=self.num_chunks),
PromoteCoefficients(),
ConvertAccessOpPattern(),
]
),
walk_reverse=False,
apply_recursively=False,
)
module_pass.rewrite_module(op)
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
ConvertApplyOpPattern(num_chunks=self.num_chunks),
PromoteCoefficients(),
]
),
apply_recursively=False,
).rewrite_module(op)

ConvertVarithToArithPass().apply(ctx, op)

if self.num_chunks > 1:
Expand Down

0 comments on commit 8e58c88

Please sign in to comment.