Skip to content

Commit

Permalink
[mlir][tosa][linalg] Apply direct tosa -> linalg Conv2D lowering
Browse files Browse the repository at this point in the history
TOSA defines the filter channel ordering for 2D convolution operation
`tosa.conv2d` as `[OC, KH, KW, IC]`. The LinAlg dialect supports `[F, H,
W, C]` and `[H, W, C, F]` orderings via the `linalg.conv_2d_nhwc_fhwc`
and `linalg.conv_2d_nhwc_hwcf` operations respectively. Where `F == OC`,
`KH == H`, `KW == W` and `C == IC`.

Currently `tosa.conv2d` is lowered to `linalg.conv2d_nhwc_hwcf` meaning
we need to insert a transposition operation to permute the filter
channels before they can be passed as weights to the linalg op, that is
`[F, H, W, C]` -> `[H, W, C, F]`. An analogous transformation needs to
be applied to the quantized operation that lowers to
`linalg.conv_2d_nhwc_hwcf_q`.

This commit updates the TOSA->LinAlg lowering so that `tosa.conv2d` is
lowered to `linalg.conv2d_nhwc_fhwc` removing the need for the
introduction of a transposition operation and making the mapping 1-1. It
also adds a `linalg.conv_2d_nhwc_fhwc_q` quantized operation to the
LinAlg dialect so the same direct 1-1 mapping can be applied to the
quantized variant.

This commit does not add any new lit tests but repurposes the current
TosaToLinalgNamed tests by removing the checks for transpositions and
updating the targeted LinAlg operations from `linalg.conv2d_nhwc_hwcf`
to linalg.conv2d_nhwc_fhwc`.

Signed-off-by: Jack Frankland <jack.frankland@arm.com>
  • Loading branch information
FranklandJack committed Oct 5, 2023
1 parent 96dd50e commit 9e5e7d9
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 34 deletions.
137 changes: 137 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2575,6 +2575,143 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: KZp
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nhwc_fhwc_q
cpp_class_name: Conv2DNhwcFhwcQOp
doc: |-
Performs 2-D convolution with zero point offsets.
Layout:
* Input: NHWC.
* Kernel: FHWC.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. This includes the zero
point offsets common to quantized operations.
implements:
- LinalgConvolutionOpInterface
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9)>
- !LinalgOperandDefConfig
name: K
kind: input_tensor
type_var: T2
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s10,
s3, s7, s9)>
- !LinalgOperandDefConfig
name: IZp
kind: scalar
type_var: I32
- !LinalgOperandDefConfig
name: KZp
kind: scalar
type_var: I32
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: U
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] -> (s0,
s1, s5, s10)>
- !LinalgOperandDefConfig
name: strides
kind: index_attr
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
(s2, s6)>
default_indices:
- 1
- 1
- !LinalgOperandDefConfig
name: dilations
kind: index_attr
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10] ->
(s4, s8)>
default_indices:
- 1
- 1
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10] -> (d0, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8, d6)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10] -> (d3, d4, d5, d6)>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10] -> ()>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10] -> ()>
- affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8,
s9, s10] -> (d0, d1, d2, d3)>
iterator_types:
- parallel
- parallel
- parallel
- parallel
- reduction
- reduction
- reduction
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: add
operands:
- !ScalarExpression
scalar_arg: O
- !ScalarExpression
scalar_fn:
kind: binary
fn_name: mul
operands:
- !ScalarExpression
scalar_fn:
kind: binary
fn_name: sub
operands:
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
scalar_arg: I
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
scalar_arg: IZp
- !ScalarExpression
scalar_fn:
kind: binary
fn_name: sub
operands:
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
scalar_arg: K
- !ScalarExpression
scalar_fn:
kind: type
fn_name: cast_signed
type_var: U
operands:
- !ScalarExpression
scalar_arg: KZp
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: conv_2d_nchw_fchw
cpp_class_name: Conv2DNchwFchwOp
Expand Down
43 changes: 23 additions & 20 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,25 +248,28 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
pad.resize(pad.size() + 2, 0);
input = applyPad(loc, input, pad, zeroAttr, rewriter);

// Transpose the kernel to match dimension ordering of the linalg
// convolution operation.
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
SmallVector<int64_t> weightPerm;
for (int i = 1; i < resultTy.getRank(); i++)
weightPerm.push_back(i);
weightPerm.push_back(0);

SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
// For Conv3D transpose the kernel to match dimension ordering of the linalg
// convolution operation. Conv2D has a 1-1 mapping in linalg so better to
// map directly and then transpose later if desired.
if (5 == inputTy.getRank()) {
// TODO(suderman): See if this can be efficiently folded - check whether
// the input is used anywhere else, if not fold the constant.
SmallVector<int64_t> weightPerm;
for (int i = 1; i < resultTy.getRank(); i++)
weightPerm.push_back(i);
weightPerm.push_back(0);

SmallVector<int64_t> newWeightShape;
for (auto dim : weightPerm)
newWeightShape.push_back(weightShape[dim]);
auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
Value weightPermValue =
rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
Type newWeightTy =
RankedTensorType::get(newWeightShape, weightTy.getElementType());
weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
weightPermValue);
}

auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
Expand Down Expand Up @@ -977,7 +980,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<
// clang-format off
ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcFhwcOp, linalg::Conv2DNhwcFhwcQOp>,
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
DepthwiseConvConverter,
MatMulConverter,
Expand Down
30 changes: 30 additions & 0 deletions mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,36 @@ def conv_2d_nhwc_hwcf_q(
) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp))


@linalg_structured_op
def conv_2d_nhwc_fhwc_q(
I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
K=TensorDef(T2, S.F, S.KH, S.KW, S.C),
IZp=ScalarDef(I32),
KZp=ScalarDef(I32),
O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True),
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]),
dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]),
):
"""Performs 2-D convolution with zero point offsets.
Layout:
* Input: NHWC.
* Kernel: FHWC.
Numeric casting is performed on the operands to the inner multiply, promoting
them to the same data type as the accumulator/output. This includes the zero
point offsets common to quantized operations.
"""
implements(ConvolutionOpInterface)
domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c)
O[D.n, D.oh, D.ow, D.f] += (
TypeFn.cast_signed(
U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]
)
- TypeFn.cast_signed(U, IZp)
) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp))


@linalg_structured_op
def conv_2d_nchw_fchw(
I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
Expand Down
20 changes: 6 additions & 14 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,11 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)

// CHECK-LABEL: @conv2d_i8
func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
// CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
// CHECK: arith.extsi
// CHECK: arith.addi
Expand All @@ -385,13 +383,11 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi

// CHECK-LABEL: @conv2d_f32
func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
// CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
// CHECK: %[[M_IN:.+]] = tensor.empty()
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty()
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
// CHECK: arith.addf
// CHECK: linalg.yield
Expand All @@ -408,13 +404,11 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
// CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
// CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
// CHECK: %[[ADD:.+]] = arith.addf
// CHECK: linalg.yield %[[ADD]] : f32
Expand Down Expand Up @@ -468,13 +462,11 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
// CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index

// Running convolution
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
// CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
// CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
// CHECK: %[[CST:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
// CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
// CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
// CHECK: %[[ADD:.+]] = arith.addf
// CHECK: linalg.yield %[[ADD]] : f32
Expand All @@ -489,7 +481,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: tensor.yield %[[C0]]
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK: linalg.conv_2d_nhwc_fhwc
%0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
return
}
Expand All @@ -501,7 +493,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
// CHECK: %[[C22:.+]] = arith.constant -22
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: tensor.yield %[[C22]]
// CHECK: linalg.conv_2d_nhwc_hwcf_q
// CHECK: linalg.conv_2d_nhwc_fhwc_q
%0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
return
}
Expand Down

0 comments on commit 9e5e7d9

Please sign in to comment.