From 3317a9d3ba8e419e47570326afc485f3a78240c2 Mon Sep 17 00:00:00 2001 From: n-io Date: Wed, 31 Jul 2024 15:19:08 +0200 Subject: [PATCH 1/7] add arith constant tensors --- .../stencil-tensorize-z-dimension.mlir | 16 +++---- xdsl/dialects/arith.py | 4 +- .../stencil_tensorize_z_dimension.py | 48 +++++++++++++++---- 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir b/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir index eed23ff46d..bee532a828 100644 --- a/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir +++ b/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir @@ -33,7 +33,7 @@ builtin.module { // CHECK-NEXT: %1 = stencil.load %0 : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %2 = stencil.external_load %b : memref<1024x512x512xf32> -> !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %3 = stencil.apply(%4 = %1 : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1022]x[0,510]xtensor<510xf32>>) { -// CHECK-NEXT: %5 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: %5 = arith.constant dense<1.666600e-01> : tensor<510xf32> // CHECK-NEXT: %6 = stencil.access %4[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %7 = "tensor.extract_slice"(%6) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> // CHECK-NEXT: %8 = stencil.access %4[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> @@ -51,10 +51,8 @@ builtin.module { // CHECK-NEXT: %20 = arith.addf %19, %11 : tensor<510xf32> // CHECK-NEXT: %21 = arith.addf %20, %9 : tensor<510xf32> // CHECK-NEXT: %22 = arith.addf %21, %7 : tensor<510xf32> -// CHECK-NEXT: %23 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %24 = linalg.fill ins(%5 : f32) outs(%23 : tensor<510xf32>) -> tensor<510xf32> -// CHECK-NEXT: %25 = arith.mulf %22, %24 : tensor<510xf32> -// CHECK-NEXT: stencil.return %25 : tensor<510xf32> +// CHECK-NEXT: %23 = arith.mulf %22, %5 : tensor<510xf32> +// CHECK-NEXT: stencil.return %23 : tensor<510xf32> // CHECK-NEXT: } // CHECK-NEXT: stencil.store %3 to %2 (<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return @@ -86,7 +84,7 @@ builtin.module { // CHECK: func.func @gauss_seidel_func(%a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>, %b : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>>) { // CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1022]x[0,510]xtensor<510xf32>>) { -// CHECK-NEXT: %3 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: %3 = arith.constant dense<1.666600e-01> : tensor<510xf32> // CHECK-NEXT: %4 = stencil.access %2[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %5 = "tensor.extract_slice"(%4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> // CHECK-NEXT: %6 = stencil.access %2[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> @@ -104,10 +102,8 @@ builtin.module { // CHECK-NEXT: %18 = arith.addf %17, %9 : tensor<510xf32> // 
CHECK-NEXT: %19 = arith.addf %18, %7 : tensor<510xf32> // CHECK-NEXT: %20 = arith.addf %19, %5 : tensor<510xf32> -// CHECK-NEXT: %21 = tensor.empty() : tensor<510xf32> -// CHECK-NEXT: %22 = linalg.fill ins(%3 : f32) outs(%21 : tensor<510xf32>) -> tensor<510xf32> -// CHECK-NEXT: %23 = arith.mulf %20, %22 : tensor<510xf32> -// CHECK-NEXT: stencil.return %23 : tensor<510xf32> +// CHECK-NEXT: %21 = arith.mulf %20, %3 : tensor<510xf32> +// CHECK-NEXT: stencil.return %21 : tensor<510xf32> // CHECK-NEXT: } // CHECK-NEXT: stencil.store %1 to %b (<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 3d7d65d0bc..856686ba38 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -110,7 +110,9 @@ class Constant(IRDLOperation): @overload def __init__( - self, value: AnyIntegerAttr | FloatAttr[AnyFloat], value_type: None = None + self, + value: AnyIntegerAttr | FloatAttr[AnyFloat] | Attribute, + value_type: None = None, ) -> None: ... @overload diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 3f8ae6286b..2406f2a6dd 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -7,6 +7,7 @@ from xdsl.dialects.arith import ( Addf, BinaryOperation, + Constant, Divf, FloatingPointLikeBinaryOp, Mulf, @@ -16,6 +17,7 @@ AnyFloat, ArrayAttr, ContainerType, + DenseIntOrFPElementsAttr, IntAttr, ModuleOp, ShapedType, @@ -40,6 +42,8 @@ from xdsl.ir import ( Attribute, Operation, + OpResult, + SSAValue, ) from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( @@ -140,6 +144,24 @@ def arithBinaryOpTensorize( rewriter: PatternRewriter, /, ): + def rewrite_scalar_operand( + scalar_op: SSAValue, dest_typ: TensorType[Attribute] + ) -> SSAValue: + if isinstance(scalar_op, OpResult) and isinstance(scalar_op.op, Constant): + tens_const = Constant( + DenseIntOrFPElementsAttr([dest_typ, ArrayAttr([scalar_op.op.value])]) + ) + if len(scalar_op.uses) > 1: + rewriter.insert_op(tens_const, InsertPoint.before(scalar_op.op)) + else: + rewriter.replace_op(scalar_op.op, tens_const) + return tens_const.result + emptyop = EmptyOp((), dest_typ) + fillop = FillOp((scalar_op,), (emptyop,), (dest_typ,)) + rewriter.insert_op(emptyop, InsertPoint.before(op)) + rewriter.insert_op(fillop, InsertPoint.before(op)) + return fillop.res[0] + type_constructor = type(op) if is_tensor(op.result.type): return @@ -148,20 +170,14 @@ def arithBinaryOpTensorize( type_constructor(op.lhs, op.rhs, flags=None, result_type=op.lhs.type) ) elif is_tensor(op.lhs.type) and is_scalar(op.rhs.type): - emptyop = EmptyOp((), op.lhs.type) - fillop = FillOp((op.rhs,), (emptyop,), (op.lhs.type,)) - rewriter.insert_op(emptyop, InsertPoint.before(op)) - rewriter.insert_op(fillop, InsertPoint.before(op)) + new_rhs = rewrite_scalar_operand(op.rhs, op.lhs.type) rewriter.replace_matched_op( - type_constructor(op.lhs, fillop, flags=None, result_type=op.lhs.type) + type_constructor(op.lhs, new_rhs, flags=None, result_type=op.lhs.type) ) elif is_scalar(op.lhs.type) and is_tensor(op.rhs.type): - emptyop = EmptyOp((), op.rhs.type) - fillop = FillOp((op.lhs,), (emptyop,), (op.rhs.type,)) - rewriter.insert_op(emptyop, InsertPoint.before(op)) - rewriter.insert_op(fillop, InsertPoint.before(op)) + new_lhs = 
rewrite_scalar_operand(op.lhs, op.rhs.type) rewriter.replace_matched_op( - type_constructor(fillop, op.rhs, flags=None, result_type=op.rhs.type) + type_constructor(new_lhs, op.rhs, flags=None, result_type=op.rhs.type) ) @@ -361,6 +377,17 @@ def match_and_rewrite(self, op: FillOp, rewriter: PatternRewriter, /): ) +class ConstOpUpdateShape(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: Constant, rewriter: PatternRewriter, /): + if typ := get_required_result_type(op): + if needs_update_shape(op.result.type, typ): + assert isinstance(op.value, DenseIntOrFPElementsAttr) + rewriter.replace_matched_op( + Constant(DenseIntOrFPElementsAttr([typ, op.value.data])) + ) + + @dataclass(frozen=True) class BackpropagateStencilShapes(ModulePass): """ @@ -379,6 +406,7 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: EmptyOpUpdateShape(), FillOpUpdateShape(), ArithOpUpdateShape(), + ConstOpUpdateShape(), ] ), walk_reverse=True, From 1fbc583e19688889fa2d725efa2ed4ffa70be900 Mon Sep 17 00:00:00 2001 From: n-io Date: Wed, 31 Jul 2024 15:34:45 +0200 Subject: [PATCH 2/7] fix assertion --- .../experimental/stencil_tensorize_z_dimension.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 2406f2a6dd..1e5ec9f6d6 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -380,12 +380,13 @@ def match_and_rewrite(self, op: FillOp, rewriter: PatternRewriter, /): class ConstOpUpdateShape(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: Constant, rewriter: PatternRewriter, /): - if typ := get_required_result_type(op): - if needs_update_shape(op.result.type, typ): - assert isinstance(op.value, DenseIntOrFPElementsAttr) - rewriter.replace_matched_op( - Constant(DenseIntOrFPElementsAttr([typ, op.value.data])) - ) + if is_tensor(op.result.type): + if typ := get_required_result_type(op): + if needs_update_shape(op.result.type, typ): + assert isinstance(op.value, DenseIntOrFPElementsAttr) + rewriter.replace_matched_op( + Constant(DenseIntOrFPElementsAttr([typ, op.value.data])) + ) @dataclass(frozen=True) From 22bf12443a4ca6949270028b0a7c58ef161ca73d Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 1 Aug 2024 11:51:17 +0200 Subject: [PATCH 3/7] rollback --- xdsl/dialects/arith.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xdsl/dialects/arith.py b/xdsl/dialects/arith.py index 856686ba38..3d7d65d0bc 100644 --- a/xdsl/dialects/arith.py +++ b/xdsl/dialects/arith.py @@ -110,9 +110,7 @@ class Constant(IRDLOperation): @overload def __init__( - self, - value: AnyIntegerAttr | FloatAttr[AnyFloat] | Attribute, - value_type: None = None, + self, value: AnyIntegerAttr | FloatAttr[AnyFloat], value_type: None = None ) -> None: ... 
@overload From 64ebd3bda8fcd8689e8d4dc72b39e02f4bd0e56f Mon Sep 17 00:00:00 2001 From: n-io Date: Thu, 1 Aug 2024 17:47:11 +0200 Subject: [PATCH 4/7] fix --- xdsl/transforms/experimental/stencil_tensorize_z_dimension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 1e5ec9f6d6..858abbabfa 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -157,7 +157,7 @@ def rewrite_scalar_operand( rewriter.replace_op(scalar_op.op, tens_const) return tens_const.result emptyop = EmptyOp((), dest_typ) - fillop = FillOp((scalar_op,), (emptyop,), (dest_typ,)) + fillop = FillOp((scalar_op,), (emptyop.tensor,), (dest_typ,)) rewriter.insert_op(emptyop, InsertPoint.before(op)) rewriter.insert_op(fillop, InsertPoint.before(op)) return fillop.res[0] From 781293eb25ccc9f58319586168dab70f30a77c59 Mon Sep 17 00:00:00 2001 From: n-io Date: Fri, 2 Aug 2024 10:50:30 +0200 Subject: [PATCH 5/7] update --- .../stencil_tensorize_z_dimension.py | 78 +++++++++++-------- 1 file changed, 45 insertions(+), 33 deletions(-) diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 858abbabfa..99ea98ab8c 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -139,14 +139,52 @@ def match_and_rewrite(self, op: AccessOp, rewriter: PatternRewriter, /): rewriter.replace_matched_op(extract) -def arithBinaryOpTensorize( - op: FloatingPointLikeBinaryOp, - rewriter: PatternRewriter, - /, -): - def rewrite_scalar_operand( - scalar_op: SSAValue, dest_typ: TensorType[Attribute] +class ArithOpTensorize(RewritePattern): + """ + Tensorises arith binary ops. + If both operands are tensor types, rebuilds the op with matching result type. + If one operand is scalar and an `arith.constant`, change it to produce a tensor value directly. + If one operand is scalar and not an `arith.constant`, create an empty tensor and fill it with the scalar value. + """ + + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: Addf | Subf | Mulf | Divf, rewriter: PatternRewriter, / + ): + type_constructor = type(op) + if is_tensor(op.result.type): + return + if is_tensor(op.lhs.type) and is_tensor(op.rhs.type): + rewriter.replace_matched_op( + type_constructor(op.lhs, op.rhs, flags=None, result_type=op.lhs.type) + ) + elif is_tensor(op.lhs.type) and is_scalar(op.rhs.type): + new_rhs = ArithOpTensorize._rewrite_scalar_operand( + op.rhs, op.lhs.type, op, rewriter + ) + rewriter.replace_matched_op( + type_constructor(op.lhs, new_rhs, flags=None, result_type=op.lhs.type) + ) + elif is_scalar(op.lhs.type) and is_tensor(op.rhs.type): + new_lhs = ArithOpTensorize._rewrite_scalar_operand( + op.lhs, op.rhs.type, op, rewriter + ) + rewriter.replace_matched_op( + type_constructor(new_lhs, op.rhs, flags=None, result_type=op.rhs.type) + ) + + @staticmethod + def _rewrite_scalar_operand( + scalar_op: SSAValue, + dest_typ: TensorType[Attribute], + op: FloatingPointLikeBinaryOp, + rewriter: PatternRewriter, ) -> SSAValue: + """ + Rewrites a scalar operand into a tensor. + If it is a constant, modify the constant op directly. + If it is not a constant, create an empty tensor and `linalg.fill` it with the scalar value. 
+ """ if isinstance(scalar_op, OpResult) and isinstance(scalar_op.op, Constant): tens_const = Constant( DenseIntOrFPElementsAttr([dest_typ, ArrayAttr([scalar_op.op.value])]) @@ -162,32 +200,6 @@ def rewrite_scalar_operand( rewriter.insert_op(fillop, InsertPoint.before(op)) return fillop.res[0] - type_constructor = type(op) - if is_tensor(op.result.type): - return - if is_tensor(op.lhs.type) and is_tensor(op.rhs.type): - rewriter.replace_matched_op( - type_constructor(op.lhs, op.rhs, flags=None, result_type=op.lhs.type) - ) - elif is_tensor(op.lhs.type) and is_scalar(op.rhs.type): - new_rhs = rewrite_scalar_operand(op.rhs, op.lhs.type) - rewriter.replace_matched_op( - type_constructor(op.lhs, new_rhs, flags=None, result_type=op.lhs.type) - ) - elif is_scalar(op.lhs.type) and is_tensor(op.rhs.type): - new_lhs = rewrite_scalar_operand(op.lhs, op.rhs.type) - rewriter.replace_matched_op( - type_constructor(new_lhs, op.rhs, flags=None, result_type=op.rhs.type) - ) - - -class ArithOpTensorize(RewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite( - self, op: Addf | Subf | Mulf | Divf, rewriter: PatternRewriter, / - ): - arithBinaryOpTensorize(op, rewriter) - @dataclass(frozen=True) class ApplyOpTensorize(RewritePattern): From 8a285b4173af2fffdad727811a89e4d958394180 Mon Sep 17 00:00:00 2001 From: n-io Date: Fri, 2 Aug 2024 13:48:17 +0200 Subject: [PATCH 6/7] small rework --- .../experimental/stencil_tensorize_z_dimension.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py index 99ea98ab8c..fe6a0d439f 100644 --- a/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py +++ b/xdsl/transforms/experimental/stencil_tensorize_z_dimension.py @@ -143,7 +143,7 @@ class ArithOpTensorize(RewritePattern): """ Tensorises arith binary ops. If both operands are tensor types, rebuilds the op with matching result type. - If one operand is scalar and an `arith.constant`, change it to produce a tensor value directly. + If one operand is scalar and an `arith.constant`, create a tensor constant directly. If one operand is scalar and not an `arith.constant`, create an empty tensor and fill it with the scalar value. """ @@ -182,17 +182,14 @@ def _rewrite_scalar_operand( ) -> SSAValue: """ Rewrites a scalar operand into a tensor. - If it is a constant, modify the constant op directly. + If it is a constant, create a corresponding tensor constant. If it is not a constant, create an empty tensor and `linalg.fill` it with the scalar value. 
""" if isinstance(scalar_op, OpResult) and isinstance(scalar_op.op, Constant): tens_const = Constant( DenseIntOrFPElementsAttr([dest_typ, ArrayAttr([scalar_op.op.value])]) ) - if len(scalar_op.uses) > 1: - rewriter.insert_op(tens_const, InsertPoint.before(scalar_op.op)) - else: - rewriter.replace_op(scalar_op.op, tens_const) + rewriter.insert_op(tens_const, InsertPoint.before(scalar_op.op)) return tens_const.result emptyop = EmptyOp((), dest_typ) fillop = FillOp((scalar_op,), (emptyop.tensor,), (dest_typ,)) From 98a14ba0b49b83ac3cdfe86c9d5eb4a89f264394 Mon Sep 17 00:00:00 2001 From: n-io Date: Fri, 2 Aug 2024 14:10:38 +0200 Subject: [PATCH 7/7] fix filecheck --- .../stencil-tensorize-z-dimension.mlir | 78 ++++++++++--------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir b/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir index bee532a828..3d717759ea 100644 --- a/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir +++ b/tests/filecheck/transforms/stencil-tensorize-z-dimension.mlir @@ -34,25 +34,26 @@ builtin.module { // CHECK-NEXT: %2 = stencil.external_load %b : memref<1024x512x512xf32> -> !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %3 = stencil.apply(%4 = %1 : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1022]x[0,510]xtensor<510xf32>>) { // CHECK-NEXT: %5 = arith.constant dense<1.666600e-01> : tensor<510xf32> -// CHECK-NEXT: %6 = stencil.access %4[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %7 = "tensor.extract_slice"(%6) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %8 = stencil.access %4[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %9 = "tensor.extract_slice"(%8) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %10 = stencil.access %4[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %12 = stencil.access %4[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %13 = "tensor.extract_slice"(%12) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %14 = stencil.access %4[0, 1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %15 = "tensor.extract_slice"(%14) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %16 = stencil.access %4[0, -1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %17 = "tensor.extract_slice"(%16) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %18 = arith.addf %17, %15 : tensor<510xf32> -// CHECK-NEXT: %19 = arith.addf %18, %13 : tensor<510xf32> -// CHECK-NEXT: %20 = arith.addf %19, %11 : tensor<510xf32> -// CHECK-NEXT: %21 = arith.addf %20, %9 : tensor<510xf32> -// CHECK-NEXT: %22 = arith.addf %21, %7 : 
tensor<510xf32> -// CHECK-NEXT: %23 = arith.mulf %22, %5 : tensor<510xf32> -// CHECK-NEXT: stencil.return %23 : tensor<510xf32> +// CHECK-NEXT: %6 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: %7 = stencil.access %4[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %9 = stencil.access %4[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %10 = "tensor.extract_slice"(%9) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %11 = stencil.access %4[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %12 = "tensor.extract_slice"(%11) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %13 = stencil.access %4[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %15 = stencil.access %4[0, 1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %16 = "tensor.extract_slice"(%15) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %17 = stencil.access %4[0, -1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %18 = "tensor.extract_slice"(%17) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %19 = arith.addf %18, %16 : tensor<510xf32> +// CHECK-NEXT: %20 = arith.addf %19, %14 : tensor<510xf32> +// CHECK-NEXT: %21 = arith.addf %20, %12 : tensor<510xf32> +// CHECK-NEXT: %22 = arith.addf %21, %10 : tensor<510xf32> +// CHECK-NEXT: %23 = arith.addf %22, %8 : tensor<510xf32> +// CHECK-NEXT: %24 = arith.mulf %23, %5 : tensor<510xf32> +// CHECK-NEXT: stencil.return %24 : tensor<510xf32> // CHECK-NEXT: } // CHECK-NEXT: stencil.store %3 to %2 (<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return @@ -85,25 +86,26 @@ builtin.module { // CHECK-NEXT: %0 = stencil.load %a : !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> -> !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: %1 = stencil.apply(%2 = %0 : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>>) -> (!stencil.temp<[0,1022]x[0,510]xtensor<510xf32>>) { // CHECK-NEXT: %3 = arith.constant dense<1.666600e-01> : tensor<510xf32> -// CHECK-NEXT: %4 = stencil.access %2[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %5 = "tensor.extract_slice"(%4) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %6 = stencil.access %2[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %7 = "tensor.extract_slice"(%6) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) 
-> tensor<510xf32> -// CHECK-NEXT: %8 = stencil.access %2[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %9 = "tensor.extract_slice"(%8) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %10 = stencil.access %2[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %11 = "tensor.extract_slice"(%10) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %12 = stencil.access %2[0, 1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %13 = "tensor.extract_slice"(%12) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %14 = stencil.access %2[0, -1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> -// CHECK-NEXT: %15 = "tensor.extract_slice"(%14) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> -// CHECK-NEXT: %16 = arith.addf %15, %13 : tensor<510xf32> -// CHECK-NEXT: %17 = arith.addf %16, %11 : tensor<510xf32> -// CHECK-NEXT: %18 = arith.addf %17, %9 : tensor<510xf32> -// CHECK-NEXT: %19 = arith.addf %18, %7 : tensor<510xf32> -// CHECK-NEXT: %20 = arith.addf %19, %5 : tensor<510xf32> -// CHECK-NEXT: %21 = arith.mulf %20, %3 : tensor<510xf32> -// CHECK-NEXT: stencil.return %21 : tensor<510xf32> +// CHECK-NEXT: %4 = arith.constant 1.666600e-01 : f32 +// CHECK-NEXT: %5 = stencil.access %2[1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %6 = "tensor.extract_slice"(%5) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %7 = stencil.access %2[-1, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %8 = "tensor.extract_slice"(%7) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %9 = stencil.access %2[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %10 = "tensor.extract_slice"(%9) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %11 = stencil.access %2[0, 0] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %12 = "tensor.extract_slice"(%11) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %13 = stencil.access %2[0, 1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %14 = "tensor.extract_slice"(%13) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %15 = stencil.access %2[0, -1] : !stencil.temp<[-1,1023]x[-1,511]xtensor<512xf32>> +// CHECK-NEXT: %16 = "tensor.extract_slice"(%15) <{"static_offsets" = array, "static_sizes" = array, "static_strides" = array, "operandSegmentSizes" = array}> : (tensor<512xf32>) -> tensor<510xf32> +// CHECK-NEXT: %17 = arith.addf %16, %14 : tensor<510xf32> +// CHECK-NEXT: %18 = arith.addf %17, %12 : 
tensor<510xf32> +// CHECK-NEXT: %19 = arith.addf %18, %10 : tensor<510xf32> +// CHECK-NEXT: %20 = arith.addf %19, %8 : tensor<510xf32> +// CHECK-NEXT: %21 = arith.addf %20, %6 : tensor<510xf32> +// CHECK-NEXT: %22 = arith.mulf %21, %3 : tensor<510xf32> +// CHECK-NEXT: stencil.return %22 : tensor<510xf32> // CHECK-NEXT: } // CHECK-NEXT: stencil.store %1 to %b (<[0, 0], [1022, 510]>) : !stencil.temp<[0,1022]x[0,510]xtensor<510xf32>> to !stencil.field<[-1,1023]x[-1,511]xtensor<512xf32>> // CHECK-NEXT: func.return
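
A minimal before/after sketch of the rewrite this patch series implements, assembled from the filecheck diffs above (the tensor<510xf32> shape and constant value come from the test; %acc stands for the accumulated sum and is illustrative). A scalar arith.constant operand of a tensorized binary op is now materialized directly as a dense tensor constant instead of being splatted at runtime:

    // before this series: the scalar constant is splatted via tensor.empty + linalg.fill
    %cst = arith.constant 1.666600e-01 : f32
    %e = tensor.empty() : tensor<510xf32>
    %f = linalg.fill ins(%cst : f32) outs(%e : tensor<510xf32>) -> tensor<510xf32>
    %out = arith.mulf %acc, %f : tensor<510xf32>

    // after this series: the splat is folded into a dense tensor constant
    %cst2 = arith.constant dense<1.666600e-01> : tensor<510xf32>
    %out2 = arith.mulf %acc, %cst2 : tensor<510xf32>

Non-constant scalar operands still take the tensor.empty + linalg.fill path in ArithOpTensorize._rewrite_scalar_operand, and ConstOpUpdateShape propagates any later shape changes back into the new tensor constants.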