diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml index 44bcbbab2bbe9d..cd64b813c11e53 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -2575,6 +2575,143 @@ structured_op: !LinalgStructuredOpConfig - !ScalarExpression scalar_arg: KZp --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: conv_2d_nhwc_fhwc_q + cpp_class_name: Conv2DNhwcFhwcQOp + doc: |- + Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NHWC. + * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + implements: + - LinalgConvolutionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, + s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)> + - !LinalgOperandDefConfig + name: K + kind: input_tensor + type_var: T2 + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10, + s3, s7, s9)> + - !LinalgOperandDefConfig + name: IZp + kind: scalar + type_var: I32 + - !LinalgOperandDefConfig + name: KZp + kind: scalar + type_var: I32 + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0, + s1, s5, s10)> + - !LinalgOperandDefConfig + name: strides + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s2, s6)> + default_indices: + - 1 + - 1 + - !LinalgOperandDefConfig + name: dilations + kind: index_attr + index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> + (s4, s8)> + default_indices: + - 1 + - 1 + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)> + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10] -> (d3, d4, d5, d6)> + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10] -> ()> + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10] -> ()> + - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, + s9, s10] -> (d0, d1, d2, d3)> + iterator_types: + - parallel + - parallel + - parallel + - parallel + - reduction + - reduction + - reduction + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + fn_name: add + operands: + - !ScalarExpression + scalar_arg: O + - !ScalarExpression + scalar_fn: + kind: binary + fn_name: mul + operands: + - !ScalarExpression + scalar_fn: + kind: binary + fn_name: sub + operands: + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: I + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: IZp + - !ScalarExpression + scalar_fn: + kind: binary + fn_name: sub + operands: + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: K + - !ScalarExpression + scalar_fn: + kind: type + fn_name: cast_signed + type_var: U + operands: + - !ScalarExpression + scalar_arg: KZp +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nchw_fchw cpp_class_name: Conv2DNchwFchwOp diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 62ec44bf9c1e1e..4214bb57563285 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -248,25 +248,28 @@ class ConvConverter : public OpConversionPattern { pad.resize(pad.size() + 2, 0); input = applyPad(loc, input, pad, zeroAttr, rewriter); - // Transpose the kernel to match dimension ordering of the linalg - // convolution operation. - // TODO(suderman): See if this can be efficiently folded - check whether - // the input is used anywhere else, if not fold the constant. - SmallVector weightPerm; - for (int i = 1; i < resultTy.getRank(); i++) - weightPerm.push_back(i); - weightPerm.push_back(0); - - SmallVector newWeightShape; - for (auto dim : weightPerm) - newWeightShape.push_back(weightShape[dim]); - auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm); - Value weightPermValue = - rewriter.create(loc, weightPermAttr); - Type newWeightTy = - RankedTensorType::get(newWeightShape, weightTy.getElementType()); - weight = rewriter.create(loc, newWeightTy, weight, - weightPermValue); + // For Conv3D transpose the kernel to match dimension ordering of the linalg + // convolution operation. Conv2D has a 1-1 mapping in linalg so better to + // map directly and then transpose later if desired. + if (5 == inputTy.getRank()) { + // TODO(suderman): See if this can be efficiently folded - check whether + // the input is used anywhere else, if not fold the constant. + SmallVector weightPerm; + for (int i = 1; i < resultTy.getRank(); i++) + weightPerm.push_back(i); + weightPerm.push_back(0); + + SmallVector newWeightShape; + for (auto dim : weightPerm) + newWeightShape.push_back(weightShape[dim]); + auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm); + Value weightPermValue = + rewriter.create(loc, weightPermAttr); + Type newWeightTy = + RankedTensorType::get(newWeightShape, weightTy.getElementType()); + weight = rewriter.create(loc, newWeightTy, weight, + weightPermValue); + } auto resultZeroAttr = rewriter.getZeroAttr(resultETy); Value emptyTensor = rewriter.create( @@ -977,7 +980,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns( RewritePatternSet *patterns) { patterns->add< // clang-format off - ConvConverter, + ConvConverter, ConvConverter, DepthwiseConvConverter, MatMulConverter, diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 6eae3d916c9288..a8f8f8e0fbd68b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -693,6 +693,36 @@ def conv_2d_nhwc_hwcf_q( ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) +@linalg_structured_op +def conv_2d_nhwc_fhwc_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.F, S.KH, S.KW, S.C), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NHWC. + * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp)) + + @linalg_structured_op def conv_2d_nchw_fchw( I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index bf970c84832e9e..b601bfb28a4f28 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -363,13 +363,11 @@ func.func @avg_pool_dyn(%arg0: tensor) -> (tensor) // CHECK-LABEL: @conv2d_i8 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () { - // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> - // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] // CHECK: %[[M_IN:.+]] = tensor.empty() // CHECK: %[[CST:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill // CHECK: %[[B_IN:.+]] = tensor.empty() - // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> + // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32> // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>) // CHECK: arith.extsi // CHECK: arith.addi @@ -385,13 +383,11 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi // CHECK-LABEL: @conv2d_f32 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () { - // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> - // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] // CHECK: %[[M_IN:.+]] = tensor.empty() // CHECK: %[[CST:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill // CHECK: %[[B_IN:.+]] = tensor.empty() - // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>) + // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>) // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>) // CHECK: arith.addf // CHECK: linalg.yield @@ -408,13 +404,11 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27 func.func @conv2d_dyn(%input: tensor, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () { // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] - // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> - // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]] // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]]) // CHECK: %[[CST:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]]) - // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor) + // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor) // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor) outs(%[[B_IN]] : tensor) // CHECK: %[[ADD:.+]] = arith.addf // CHECK: linalg.yield %[[ADD]] : f32 @@ -468,13 +462,11 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index // Running convolution - // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]> - // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]] // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) // CHECK: %[[CST:.+]] = arith.constant 0 // CHECK: %[[FILL:.+]] = linalg.fill // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) - // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>) + // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>) // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>) // CHECK: %[[ADD:.+]] = arith.addf // CHECK: linalg.yield %[[ADD]] : f32 @@ -489,7 +481,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28 // CHECK: %[[C0:.+]] = arith.constant 0 // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C0]] - // CHECK: linalg.conv_2d_nhwc_hwcf + // CHECK: linalg.conv_2d_nhwc_fhwc %0 = tosa.conv2d %input, %weights, %bias {pad = array, stride = array, dilation = array} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32> return } @@ -501,7 +493,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x // CHECK: %[[C22:.+]] = arith.constant -22 // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0] // CHECK: tensor.yield %[[C22]] - // CHECK: linalg.conv_2d_nhwc_hwcf_q + // CHECK: linalg.conv_2d_nhwc_fhwc_q %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array, pad = array, quantization_info = #tosa.conv_quant, stride = array} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32> return }