Skip to content

Commit

Permalink
(transform): csl_stencil canonicalization pass (#2814)
Browse files Browse the repository at this point in the history
Adds a canonicalisation pass for `csl_stencil.apply`. The op takes an
empty tensor as `iter_arg`, which it does not manage itself. The
conversion pass in #2803 initialises an `iter_arg` for each instance of
`apply`. This canonicalisation pass identifies where this can be
re-used, effectively removing redundant allocations.

---------

Co-authored-by: n-io <n-io@users.noreply.github.com>
  • Loading branch information
n-io and n-io authored Jul 26, 2024
1 parent 2d67e7c commit 7f09044
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 1 deletion.
84 changes: 84 additions & 0 deletions tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// RUN: xdsl-opt %s -p canonicalize --split-input-file | filecheck %s


// Test input: three consecutive csl_stencil.apply ops that all read %0 and
// write to %b. Each apply is handed its own freshly created tensor.empty as
// its tensor operand. The expected output further down keeps only the first
// tensor.empty (%1) and passes it to all three applies.
builtin.module {
func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
%0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>

// First apply: its tensor.empty (%1) is the one expected to survive.
%1 = tensor.empty() : tensor<510xf32>
%2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
^0(%3 : memref<4xtensor<255xf32>>, %4 : index, %5 : tensor<510xf32>):
%6 = csl_stencil.access %3[1, 0] : memref<4xtensor<255xf32>>
%7 = "tensor.insert_slice"(%6, %5, %4) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
csl_stencil.yield %7 : tensor<510xf32>
}, {
^0(%8 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %9 : tensor<510xf32>):
csl_stencil.yield %9 : tensor<510xf32>
})
stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>

// Second apply: %10 has the same type as %1 and no other user, so it is
// expected to be replaced by %1.
%10 = tensor.empty() : tensor<510xf32>
%11 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
^0(%12 : memref<4xtensor<255xf32>>, %13 : index, %14 : tensor<510xf32>):
%15 = csl_stencil.access %12[1, 0] : memref<4xtensor<255xf32>>
%16 = "tensor.insert_slice"(%15, %14, %13) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
csl_stencil.yield %16 : tensor<510xf32>
}, {
^0(%17 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %18 : tensor<510xf32>):
csl_stencil.yield %18 : tensor<510xf32>
})
stencil.store %11 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>

// Third apply: same situation as the second; %19 is expected to be replaced
// by %1 as well.
%19 = tensor.empty() : tensor<510xf32>
%20 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %19 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
^0(%21 : memref<4xtensor<255xf32>>, %22 : index, %23 : tensor<510xf32>):
%24 = csl_stencil.access %21[1, 0] : memref<4xtensor<255xf32>>
%25 = "tensor.insert_slice"(%24, %23, %22) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
csl_stencil.yield %25 : tensor<510xf32>
}, {
^0(%26 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %27 : tensor<510xf32>):
csl_stencil.yield %27 : tensor<510xf32>
})
stencil.store %20 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
func.return
}
}


// CHECK-NEXT: builtin.module {
// CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) {
// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>
// CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32>
// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
// CHECK-NEXT: ^0(%3 : memref<4xtensor<255xf32>>, %4 : index, %5 : tensor<510xf32>):
// CHECK-NEXT: %6 = csl_stencil.access %3[1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %7 = "tensor.insert_slice"(%6, %5, %4) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %7 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%8 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %9 : tensor<510xf32>):
// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32>
// CHECK-NEXT: })
// CHECK-NEXT: stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %3 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
// CHECK-NEXT: ^0(%4 : memref<4xtensor<255xf32>>, %5 : index, %6 : tensor<510xf32>):
// CHECK-NEXT: %7 = csl_stencil.access %4[1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %8 = "tensor.insert_slice"(%7, %6, %5) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %8 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%9 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>):
// CHECK-NEXT: csl_stencil.yield %10 : tensor<510xf32>
// CHECK-NEXT: })
// CHECK-NEXT: stencil.store %3 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: %4 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({
// CHECK-NEXT: ^0(%5 : memref<4xtensor<255xf32>>, %6 : index, %7 : tensor<510xf32>):
// CHECK-NEXT: %8 = csl_stencil.access %5[1, 0] : memref<4xtensor<255xf32>>
// CHECK-NEXT: %9 = "tensor.insert_slice"(%8, %7, %6) <{"static_offsets" = array<i64: 0>, "static_sizes" = array<i64: 255>, "static_strides" = array<i64: 1>, "operandSegmentSizes" = array<i32: 1, 1, 1, 0, 0>}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32>
// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32>
// CHECK-NEXT: }, {
// CHECK-NEXT: ^1(%10 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %11 : tensor<510xf32>):
// CHECK-NEXT: csl_stencil.yield %11 : tensor<510xf32>
// CHECK-NEXT: })
// CHECK-NEXT: stencil.store %4 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>
// CHECK-NEXT: func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
20 changes: 19 additions & 1 deletion xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@
var_result_def,
)
from xdsl.parser import AttrParser, Parser
from xdsl.pattern_rewriter import RewritePattern
from xdsl.printer import Printer
from xdsl.traits import (
HasAncestor,
HasCanonicalisationPatternsTrait,
HasParent,
IsolatedFromAbove,
IsTerminator,
Expand Down Expand Up @@ -146,6 +148,16 @@ def __init__(
)


class ApplyOpHasCanonicalizationPatternsTrait(HasCanonicalisationPatternsTrait):
    """
    Trait that supplies the canonicalization patterns attached to `ApplyOp`.
    """

    @classmethod
    def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
        """
        Return a one-element tuple holding a fresh
        `RedundantIterArgInitialisation` pattern instance.
        """
        # Imported inside the method rather than at module level; presumably
        # this avoids a circular import, since the pattern module itself
        # imports this dialect module.
        from xdsl.transforms.canonicalization_patterns.csl_stencil import (
            RedundantIterArgInitialisation as IterArgPattern,
        )

        patterns: tuple[RewritePattern, ...] = (IterArgPattern(),)
        return patterns


@irdl_op_definition
class ApplyOp(IRDLOperation):
"""
Expand Down Expand Up @@ -205,7 +217,13 @@ class ApplyOp(IRDLOperation):

res = var_result_def(stencil.TempType)

traits = frozenset([IsolatedFromAbove(), RecursiveMemoryEffect()])
traits = frozenset(
[
IsolatedFromAbove(),
ApplyOpHasCanonicalizationPatternsTrait(),
RecursiveMemoryEffect(),
]
)

def print(self, printer: Printer):
def print_arg(arg: tuple[SSAValue, Attribute]):
Expand Down
33 changes: 33 additions & 0 deletions xdsl/transforms/canonicalization_patterns/csl_stencil.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from xdsl.dialects import tensor
from xdsl.dialects.csl import csl_stencil
from xdsl.ir import OpResult
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)


class RedundantIterArgInitialisation(RewritePattern):
    """
    Lets later `csl_stencil.apply` ops share the `iter_arg` of an earlier one.

    When the matched apply's `iter_arg` has at most one use, the ops reached
    by following `next_op` from it are scanned. For every `csl_stencil.apply`
    found whose own `iter_arg` is the result of a `tensor.empty`, has exactly
    one use, and has the same type as the matched op's `iter_arg`, that
    `tensor.empty` is replaced by the matched op's `iter_arg`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter
    ) -> None:
        shared = op.iter_arg

        # Nothing to do when the matched op's `iter_arg` already has more
        # than one use.
        if len(shared.uses) > 1:
            return

        following = op.next_op
        while following is not None:
            if isinstance(following, csl_stencil.ApplyOp):
                candidate = following.iter_arg
                if (
                    len(candidate.uses) == 1
                    and isinstance(candidate, OpResult)
                    and isinstance(candidate.op, tensor.EmptyOp)
                    and op.iter_arg.type == candidate.type
                ):
                    # Drop the later allocation and route its single user to
                    # the buffer of the matched op instead.
                    rewriter.replace_op(candidate.op, [], [shared])
            # Advance only after any rewrite, so the scan continues past
            # the apply that was just updated.
            following = following.next_op

0 comments on commit 7f09044

Please sign in to comment.