diff --git a/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir b/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir new file mode 100644 index 0000000000..3b99518a69 --- /dev/null +++ b/tests/filecheck/dialects/csl/csl-stencil-canonicalize.mlir @@ -0,0 +1,84 @@ +// RUN: xdsl-opt %s -p canonicalize --split-input-file | filecheck %s + + +builtin.module { + func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { + %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> + + %1 = tensor.empty() : tensor<510xf32> + %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({ + ^0(%3 : memref<4xtensor<255xf32>>, %4 : index, %5 : tensor<510xf32>): + %6 = csl_stencil.access %3[1, 0] : memref<4xtensor<255xf32>> + %7 = "tensor.insert_slice"(%6, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> + csl_stencil.yield %7 : tensor<510xf32> + }, { + ^0(%8 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %9 : tensor<510xf32>): + csl_stencil.yield %9 : tensor<510xf32> + }) + stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + + %10 = tensor.empty() : tensor<510xf32> + %11 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({ + ^0(%12 : memref<4xtensor<255xf32>>, %13 : index, %14 : tensor<510xf32>): + %15 = csl_stencil.access %12[1, 0] : memref<4xtensor<255xf32>> + %16 = "tensor.insert_slice"(%15, %14, %13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> + csl_stencil.yield %16 : tensor<510xf32> + }, { + ^0(%17 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %18 : tensor<510xf32>): + csl_stencil.yield %18 : tensor<510xf32> + }) + stencil.store %11 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + + %19 = tensor.empty() : tensor<510xf32> + %20 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %19 : tensor<510xf32>) <{"num_chunks" = 2, "topo" = #dmp.topo<1022x510>, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange]}> -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({ + ^0(%21 : memref<4xtensor<255xf32>>, %22 : index, %23 : tensor<510xf32>): + %24 = csl_stencil.access %21[1, 0] : memref<4xtensor<255xf32>> + %25 = "tensor.insert_slice"(%24, %23, %22) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> + csl_stencil.yield %25 : tensor<510xf32> + }, { + ^0(%26 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %27 : tensor<510xf32>): + csl_stencil.yield %27 : tensor<510xf32> + }) + stencil.store %20 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> + func.return + } +} + + +// CHECK-NEXT: builtin.module { +// CHECK-NEXT: func.func @gauss_seidel(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { +// CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>> +// CHECK-NEXT: %1 = tensor.empty() : tensor<510xf32> +// CHECK-NEXT: %2 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({ +// CHECK-NEXT: ^0(%3 : memref<4xtensor<255xf32>>, %4 : index, %5 : tensor<510xf32>): +// CHECK-NEXT: %6 = csl_stencil.access %3[1, 0] : memref<4xtensor<255xf32>> +// CHECK-NEXT: %7 = "tensor.insert_slice"(%6, %5, %4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %7 : tensor<510xf32> +// CHECK-NEXT: }, { +// CHECK-NEXT: ^1(%8 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %9 : tensor<510xf32>): +// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32> +// CHECK-NEXT: }) +// CHECK-NEXT: stencil.store %2 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %3 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({ +// CHECK-NEXT: ^0(%4 : memref<4xtensor<255xf32>>, %5 : index, %6 : tensor<510xf32>): +// CHECK-NEXT: %7 = csl_stencil.access %4[1, 0] : memref<4xtensor<255xf32>> +// CHECK-NEXT: %8 = "tensor.insert_slice"(%7, %6, %5) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %8 : tensor<510xf32> +// CHECK-NEXT: }, { +// CHECK-NEXT: ^1(%9 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %10 : tensor<510xf32>): +// CHECK-NEXT: csl_stencil.yield %10 : tensor<510xf32> +// CHECK-NEXT: }) +// CHECK-NEXT: stencil.store %3 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %4 = csl_stencil.apply(%0 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %1 : tensor<510xf32>) -> (!stencil.temp<[0,1]x[0,1]xtensor<510xf32>>) ({ +// CHECK-NEXT: ^0(%5 : memref<4xtensor<255xf32>>, %6 : index, %7 : tensor<510xf32>): +// CHECK-NEXT: %8 = csl_stencil.access %5[1, 0] : memref<4xtensor<255xf32>> +// CHECK-NEXT: %9 = "tensor.insert_slice"(%8, %7, %6) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<255xf32>, tensor<510xf32>, index) -> tensor<510xf32> +// CHECK-NEXT: csl_stencil.yield %9 : tensor<510xf32> +// CHECK-NEXT: }, { +// CHECK-NEXT: ^1(%10 : !stencil.temp<[-1,2]x[-1,2]xtensor<512xf32>>, %11 : tensor<510xf32>): +// CHECK-NEXT: csl_stencil.yield %11 : tensor<510xf32> +// CHECK-NEXT: }) +// CHECK-NEXT: stencil.store %4 to %b ([0, 0] : [1, 1]) : !stencil.temp<[0,1]x[0,1]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: func.return +// CHECK-NEXT: } +// CHECK-NEXT: } diff --git a/xdsl/dialects/csl/csl_stencil.py b/xdsl/dialects/csl/csl_stencil.py index ef5ecd59d7..ef0c2bc1f5 100644 --- a/xdsl/dialects/csl/csl_stencil.py +++ b/xdsl/dialects/csl/csl_stencil.py @@ -35,9 +35,11 @@ var_result_def, ) from xdsl.parser import AttrParser, Parser +from xdsl.pattern_rewriter import RewritePattern from xdsl.printer import Printer from xdsl.traits import ( HasAncestor, + HasCanonicalisationPatternsTrait, HasParent, IsolatedFromAbove, IsTerminator, @@ -146,6 +148,16 @@ def __init__( ) +class ApplyOpHasCanonicalizationPatternsTrait(HasCanonicalisationPatternsTrait): + @classmethod + def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]: + from xdsl.transforms.canonicalization_patterns.csl_stencil import ( + RedundantIterArgInitialisation, + ) + + return (RedundantIterArgInitialisation(),) + + @irdl_op_definition class ApplyOp(IRDLOperation): """ @@ -205,7 +217,13 @@ class ApplyOp(IRDLOperation): res = var_result_def(stencil.TempType) - traits = frozenset([IsolatedFromAbove(), RecursiveMemoryEffect()]) + traits = frozenset( + [ + IsolatedFromAbove(), + ApplyOpHasCanonicalizationPatternsTrait(), + RecursiveMemoryEffect(), + ] + ) def print(self, printer: Printer): def print_arg(arg: tuple[SSAValue, Attribute]): diff --git a/xdsl/transforms/canonicalization_patterns/csl_stencil.py b/xdsl/transforms/canonicalization_patterns/csl_stencil.py new file mode 100644 index 0000000000..4fe5896d85 --- /dev/null +++ b/xdsl/transforms/canonicalization_patterns/csl_stencil.py @@ -0,0 +1,33 @@ +from xdsl.dialects import tensor +from xdsl.dialects.csl import csl_stencil +from xdsl.ir import OpResult +from xdsl.pattern_rewriter import ( + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) + + +class RedundantIterArgInitialisation(RewritePattern): + """ + Removes redundant allocations of empty tensors with no uses other than passed + as `iter_arg` to `csl_stencil.apply`. Prefer re-use where possible. + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter + ) -> None: + if len(op.iter_arg.uses) > 1: + return + + next_apply = op + while (next_apply := next_apply.next_op) is not None: + if ( + isinstance(next_apply, csl_stencil.ApplyOp) + and len(next_apply.iter_arg.uses) == 1 + and isinstance(next_apply.iter_arg, OpResult) + and isinstance(next_apply.iter_arg.op, tensor.EmptyOp) + and op.iter_arg.type == next_apply.iter_arg.type + ): + rewriter.replace_op(next_apply.iter_arg.op, [], [op.iter_arg])