From 838f1f0c30397a0400cc35d79e4e85c8ba54755d Mon Sep 17 00:00:00 2001
From: Stella Laurenzo <stellaraccident@gmail.com>
Date: Mon, 16 Oct 2023 17:59:39 -0700
Subject: [PATCH] Revert "[mlir][tosa][linalg] Apply direct tosa -> linalg
 Conv2D lowering (#68304)"

This reverts commit e29a253c9ebaded53a823def985364392c4ba4ec.

Breaking TFLite mobilenet test. Needs triage.

---
 .../Linalg/IR/LinalgNamedStructuredOps.yaml   | 137 ------------------
 .../TosaToLinalg/TosaToLinalgNamed.cpp        |  43 +++---
 .../linalg/opdsl/ops/core_named_ops.py        |  30 ----
 .../TosaToLinalg/tosa-to-linalg-named.mlir    |  20 ++-
 4 files changed, 34 insertions(+), 196 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index cd64b813c11e53..44bcbbab2bbe9d 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -2575,143 +2575,6 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: KZp
 --- !LinalgOpConfig
-metadata: !LinalgOpMetadata
-  name: conv_2d_nhwc_fhwc_q
-  cpp_class_name: Conv2DNhwcFhwcQOp
-  doc: |-
-    Performs 2-D convolution with zero point offsets.
-
-    Layout:
-      * Input: NHWC.
-      * Kernel: FHWC.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output. This includes the zero
-    point offsets common to quantized operations.
-  implements:
-  - LinalgConvolutionOpInterface
-structured_op: !LinalgStructuredOpConfig
-  args:
-  - !LinalgOperandDefConfig
-    name: I
-    kind: input_tensor
-    type_var: T1
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
-      s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
-  - !LinalgOperandDefConfig
-    name: K
-    kind: input_tensor
-    type_var: T2
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
-      s3, s7, s9)>
-  - !LinalgOperandDefConfig
-    name: IZp
-    kind: scalar
-    type_var: I32
-  - !LinalgOperandDefConfig
-    name: KZp
-    kind: scalar
-    type_var: I32
-  - !LinalgOperandDefConfig
-    name: O
-    kind: output_tensor
-    type_var: U
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
-      s1, s5, s10)>
-  - !LinalgOperandDefConfig
-    name: strides
-    kind: index_attr
-    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
-      (s2, s6)>
-    default_indices:
-    - 1
-    - 1
-  - !LinalgOperandDefConfig
-    name: dilations
-    kind: index_attr
-    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
-      (s4, s8)>
-    default_indices:
-    - 1
-    - 1
-  indexing_maps: !LinalgIndexingMapsConfig
-    static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d3, d4, d5, d6)>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> ()>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> ()>
-    - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
-      s9, s10] -> (d0, d1, d2, d3)>
-  iterator_types:
-  - parallel
-  - parallel
-  - parallel
-  - parallel
-  - reduction
-  - reduction
-  - reduction
-  assignments:
-  - !ScalarAssign
-    arg: O
-    value: !ScalarExpression
-      scalar_fn:
-        kind: binary
-        fn_name: add
-        operands:
-        - !ScalarExpression
-          scalar_arg: O
-        - !ScalarExpression
-          scalar_fn:
-            kind: binary
-            fn_name: mul
-            operands:
-            - !ScalarExpression
-              scalar_fn:
-                kind: binary
-                fn_name: sub
-                operands:
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: I
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: IZp
-            - !ScalarExpression
-              scalar_fn:
-                kind: binary
-                fn_name: sub
-                operands:
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: K
-                - !ScalarExpression
-                  scalar_fn:
-                    kind: type
-                    fn_name: cast_signed
-                    type_var: U
-                    operands:
-                    - !ScalarExpression
-                      scalar_arg: KZp
---- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_nchw_fchw
   cpp_class_name: Conv2DNchwFchwOp
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 4214bb57563285..62ec44bf9c1e1e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -248,28 +248,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-    // For Conv3D transpose the kernel to match dimension ordering of the linalg
-    // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
-    // map directly and then transpose later if desired.
-    if (5 == inputTy.getRank()) {
-      // TODO(suderman): See if this can be efficiently folded - check whether
-      // the input is used anywhere else, if not fold the constant.
-      SmallVector<int64_t> weightPerm;
-      for (int i = 1; i < resultTy.getRank(); i++)
-        weightPerm.push_back(i);
-      weightPerm.push_back(0);
-
-      SmallVector<int64_t> newWeightShape;
-      for (auto dim : weightPerm)
-        newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
-      Value weightPermValue =
-          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
-      Type newWeightTy =
-          RankedTensorType::get(newWeightShape, weightTy.getElementType());
-      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                  weightPermValue);
-    }
+    // Transpose the kernel to match dimension ordering of the linalg
+    // convolution operation.
+    // TODO(suderman): See if this can be efficiently folded - check whether
+    // the input is used anywhere else, if not fold the constant.
+    SmallVector<int64_t> weightPerm;
+    for (int i = 1; i < resultTy.getRank(); i++)
+      weightPerm.push_back(i);
+    weightPerm.push_back(0);
+
+    SmallVector<int64_t> newWeightShape;
+    for (auto dim : weightPerm)
+      newWeightShape.push_back(weightShape[dim]);
+    auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+    Value weightPermValue =
+        rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+    Type newWeightTy =
+        RankedTensorType::get(newWeightShape, weightTy.getElementType());
+    weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                weightPermValue);
 
     auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
@@ -980,7 +977,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
     RewritePatternSet *patterns) {
   patterns->add<
       // clang-format off
-      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
+      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index a8f8f8e0fbd68b..6eae3d916c9288 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -693,36 +693,6 @@ def conv_2d_nhwc_hwcf_q(
     ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))
 
 
-@linalg_structured_op
-def conv_2d_nhwc_fhwc_q(
-    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
-    K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
-    IZp=ScalarDef(I32),
-    KZp=ScalarDef(I32),
-    O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
-    strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
-    dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
-):
-    """Performs 2-D convolution with zero point offsets.
-
-    Layout:
-      * Input: NHWC.
-      * Kernel: FHWC.
-
-    Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output. This includes the zero
-    point offsets common to quantized operations.
-    """
-    implements(ConvolutionOpInterface)
-    domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
-    O[D.n, D.oh, D.ow, D.f] += (
-        TypeFn.cast_signed(
-            U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
-        )
-        - TypeFn.cast_signed(U, IZp)
-    ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))
-
-
 @linalg_structured_op
 def conv_2d_nchw_fchw(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index b601bfb28a4f28..bf970c84832e9e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -363,11 +363,13 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 
 // CHECK-LABEL: @conv2d_i8
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
   // CHECK: arith.extsi
   // CHECK: arith.addi
@@ -383,11 +385,13 @@
 
 // CHECK-LABEL: @conv2d_f32
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty()
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
   // CHECK: arith.addf
   // CHECK: linalg.yield
@@ -404,11 +408,13 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
   // CHECK: %[[ADD:.+]] = arith.addf
   // CHECK: linalg.yield %[[ADD]] : f32
@@ -462,11 +468,13 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
+  // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
   // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
   // CHECK: %[[CST:.+]] = arith.constant 0
   // CHECK: %[[FILL:.+]] = linalg.fill
   // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
+  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
   // CHECK: %[[ADD:.+]] = arith.addf
   // CHECK: linalg.yield %[[ADD]] : f32
@@ -481,7 +489,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: tensor.yield %[[C0]]
-  // CHECK: linalg.conv_2d_nhwc_fhwc
+  // CHECK: linalg.conv_2d_nhwc_hwcf
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -493,7 +501,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
   // CHECK: %[[C22:.+]] = arith.constant -22
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK: tensor.yield %[[C22]]
-  // CHECK: linalg.conv_2d_nhwc_fhwc_q
+  // CHECK: linalg.conv_2d_nhwc_hwcf_q
   %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
   return
 }