[CMSIS-NN] Initial operator support for Add (apache#9167)
This patch adds initial support for the `Add` operator to CMSIS-NN. `Add` turned out to be similar enough to the `Mul` operator that the two share quite a bit of code - exciting times.
Mousius authored and ylc committed Jan 7, 2022
1 parent 7e18368 commit 99fdd3d
Showing 3 changed files with 134 additions and 42 deletions.
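
For context, the Relay expression that the new pattern offloads looks like the following. This is a minimal sketch, not code from the patch; the shape and quantization parameters are illustrative values borrowed from the tests below.

```python
# Minimal sketch (assumed values): a quantized int8 add expressed in Relay QNN,
# in the argument order that the new binary_op_pattern below matches.
import tvm
from tvm import relay

shape = (1, 16, 16, 3)  # NHWC, as in the tests
input_0 = relay.var("input_0", shape=shape, dtype="int8")
input_1 = relay.var("input_1", shape=shape, dtype="int8")
add = relay.qnn.op.add(
    input_0,
    input_1,
    relay.const(0.256, "float32"),  # input_0 scale
    relay.const(33, "int32"),       # input_0 zero point
    relay.const(0.256, "float32"),  # input_1 scale
    relay.const(33, "int32"),       # input_1 zero point
    relay.const(0.256, "float32"),  # output scale
    relay.const(33, "int32"),       # output zero point
)
mod = tvm.IRModule.from_expr(relay.Function([input_0, input_1], add))
```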
19 changes: 14 additions & 5 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -79,9 +79,9 @@ def check_quantized_softmax(extract):
and dequantize_call.args[0].checked_type.dtype == "int8"
)

def mul_pattern():
"""Matcher for QNN multiplication"""
return is_op("qnn.mul")(
def binary_op_pattern(op):
"""Matches QNN binary operation"""
return is_op(f"qnn.{op}")(
wildcard(),
wildcard(),
is_constant(),
@@ -92,7 +92,7 @@ def mul_pattern():
is_constant(),
)

def check_quantized_mul(extract):
def check_quantized_binary_op(extract):
"""Check if multiply is supported by CMSIS-NN."""
return (
extract.args[0].checked_type.dtype == "int8"
@@ -101,5 +101,14 @@ def check_quantized_mul(extract):

return [
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
("cmsisnn.quantized_mul", mul_pattern(), check_quantized_mul),
(
"cmsisnn.quantized_mul",
binary_op_pattern("mul"),
check_quantized_binary_op,
),
(
"cmsisnn.quantized_add",
binary_op_pattern("add"),
check_quantized_binary_op,
),
]
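
The `(name, pattern, check)` tuples returned above are consumed during graph partitioning. A hedged sketch of that flow is below; the pass sequence follows standard TVM BYOC conventions and the helper name is an assumption, not part of this diff.

```python
# Sketch (assumption): carving matched QNN ops out into "cmsisnn" regions.
import tvm
from tvm.relay import transform

def partition(mod, pattern_table):
    seq = tvm.transform.Sequential(
        [
            transform.InferType(),
            transform.MergeComposite(pattern_table),  # wrap pattern matches in composite functions
            transform.AnnotateTarget("cmsisnn"),      # mark composites for the external codegen
            transform.PartitionGraph(),               # split annotated regions into their own functions
            transform.InferType(),
        ]
    )
    return seq(mod)
```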
130 changes: 102 additions & 28 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -26,6 +26,7 @@
#include <tvm/tir/stmt_functor.h>

#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"

namespace tvm {
namespace relay {
@@ -39,11 +40,7 @@ class RelayToTIRVisitor : public MixedModeVisitor {
tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }

private:
template <typename T>
const T ArgumentToConstantValue(const Expr& arg) {
const ConstantNode* constant_node = arg.as<ConstantNode>();
return static_cast<const T*>(constant_node->data->data)[0];
}
inline IntImm ToArg(int32_t value) { return IntImm(DataType::Int(32), value); }

void CreatePrimFuncForExtern(Array<tir::Var> func_signature,
tvm::Array<PrimExpr> call_extern_args) {
@@ -62,7 +59,7 @@ class RelayToTIRVisitor : public MixedModeVisitor {
auto* quantize_call = expr.as<CallNode>();
auto* softmax_call = quantize_call->args[0].as<CallNode>();
auto* dequant_call = softmax_call->args[0].as<CallNode>();
const float quant_scale = ArgumentToConstantValue<float>(dequant_call->args[1]);
const float quant_scale = GetScalarFromConstant<float>(dequant_call->args[1]);

// assuming layout as NHWC
auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
@@ -95,23 +92,28 @@ class RelayToTIRVisitor : public MixedModeVisitor {
Array<tir::Var> func_signature{in_var, out_var};

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_softmax_s8"), in_var,
IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size),
IntImm(DataType::Int(32), mult), IntImm(DataType::Int(32), shift),
IntImm(DataType::Int(32), diff_min), out_var};
tir::StringImm("arm_softmax_s8"),
in_var,
ToArg(num_rows),
ToArg(row_size),
ToArg(mult),
ToArg(shift),
ToArg(diff_min),
out_var,
};

CreatePrimFuncForExtern(func_signature, args);
}

void EmitMul(const Expr& expr) {
auto* mul_call = expr.as<CallNode>();

const float input_0_scale = ArgumentToConstantValue<float>(mul_call->args[2]);
const int32_t input_0_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[3]);
const float input_1_scale = ArgumentToConstantValue<float>(mul_call->args[4]);
const int32_t input_1_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[5]);
const float output_scale = ArgumentToConstantValue<float>(mul_call->args[6]);
const int32_t output_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[7]);
const float input_0_scale = GetScalarFromConstant<float>(mul_call->args[2]);
const int32_t input_0_zero_point = GetScalarFromConstant<int32_t>(mul_call->args[3]);
const float input_1_scale = GetScalarFromConstant<float>(mul_call->args[4]);
const int32_t input_1_zero_point = GetScalarFromConstant<int32_t>(mul_call->args[5]);
const float output_scale = GetScalarFromConstant<float>(mul_call->args[6]);
const int32_t output_zero_point = GetScalarFromConstant<int32_t>(mul_call->args[7]);

double quantized_multiplier = static_cast<double>(input_0_scale) *
static_cast<double>(input_1_scale) /
@@ -132,14 +134,81 @@ class RelayToTIRVisitor : public MixedModeVisitor {
tir::StringImm("arm_elementwise_mul_s8"),
input_0,
input_1,
IntImm(DataType::Int(32), -input_0_zero_point),
IntImm(DataType::Int(32), -input_1_zero_point),
ToArg(-input_0_zero_point),
ToArg(-input_1_zero_point),
output,
IntImm(DataType::Int(32), output_zero_point),
IntImm(DataType::Int(32), output_multiplier),
IntImm(DataType::Int(32), output_shift),
IntImm(DataType::Int(32), std::numeric_limits<int8_t>::min()),
IntImm(DataType::Int(32), std::numeric_limits<int8_t>::max()),
ToArg(output_zero_point),
ToArg(output_multiplier),
ToArg(output_shift),
ToArg(std::numeric_limits<int8_t>::min()),
ToArg(std::numeric_limits<int8_t>::max()),
tensor_size,
};

CreatePrimFuncForExtern(func_signature, args);
}

void EmitAdd(const Expr& expr) {
auto* add_call = expr.as<CallNode>();

const float input_0_scale = GetScalarFromConstant<float>(add_call->args[2]);
const int32_t input_0_zero_point = GetScalarFromConstant<int32_t>(add_call->args[3]);
const float input_1_scale = GetScalarFromConstant<float>(add_call->args[4]);
const int32_t input_1_zero_point = GetScalarFromConstant<int32_t>(add_call->args[5]);
const float output_scale = GetScalarFromConstant<float>(add_call->args[6]);
const int32_t output_zero_point = GetScalarFromConstant<int32_t>(add_call->args[7]);

const int32_t left_shift = 20;
const int32_t input_0_offset = -input_0_zero_point;
const int32_t input_1_offset = -input_1_zero_point;

const float max_input_scale = std::max(input_0_scale, input_1_scale);
const double twice_max_input_scale = 2 * static_cast<double>(max_input_scale);
const double scaled_input_0_scale = static_cast<double>(input_0_scale) / twice_max_input_scale;
const double scaled_input_1_scale = static_cast<double>(input_1_scale) / twice_max_input_scale;
const double scaled_output_scale =
twice_max_input_scale / ((1 << left_shift) * static_cast<double>(output_scale));

auto input_0_mult_shift_pair =
tvm::relay::qnn::GetFixedPointMultiplierShift(scaled_input_0_scale);
int32_t input_0_multiplier = std::get<0>(input_0_mult_shift_pair);
int32_t input_0_shift = std::get<1>(input_0_mult_shift_pair);

auto input_1_mult_shift_pair =
tvm::relay::qnn::GetFixedPointMultiplierShift(scaled_input_1_scale);
int32_t input_1_multiplier = std::get<0>(input_1_mult_shift_pair);
int32_t input_1_shift = std::get<1>(input_1_mult_shift_pair);

auto output_mult_shift_pair =
tvm::relay::qnn::GetFixedPointMultiplierShift(scaled_output_scale);
int32_t output_multiplier = std::get<0>(output_mult_shift_pair);
int32_t output_shift = std::get<1>(output_mult_shift_pair);

PrimExpr tensor_size = add_call->type_as<TensorTypeNode>()->Size();

tir::Var input_0("input_0", DataType::Handle(8));
tir::Var input_1("input_1", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));

Array<tir::Var> func_signature{input_0, input_1, output};

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_add_s8"),
input_0,
input_1,
ToArg(input_0_offset),
ToArg(input_0_multiplier),
ToArg(input_0_shift),
ToArg(input_1_offset),
ToArg(input_1_multiplier),
ToArg(input_1_shift),
ToArg(left_shift),
output,
ToArg(output_zero_point),
ToArg(output_multiplier),
ToArg(output_shift),
ToArg(std::numeric_limits<int8_t>::min()),
ToArg(std::numeric_limits<int8_t>::max()),
tensor_size,
};

@@ -153,11 +222,16 @@ class RelayToTIRVisitor : public MixedModeVisitor {
}

auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") {
EmitSoftMax(func->body);
}
if (comp_name.defined() && comp_name == "cmsisnn.quantized_mul") {
EmitMul(func->body);
if (comp_name.defined()) {
if (comp_name == "cmsisnn.quantized_softmax") {
EmitSoftMax(func->body);
}
if (comp_name == "cmsisnn.quantized_mul") {
EmitMul(func->body);
}
if (comp_name == "cmsisnn.quantized_add") {
EmitAdd(func->body);
}
}
}

@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

"""CMSIS-NN integration tests: mul"""
"""CMSIS-NN integration tests: binary ops"""

import sys

@@ -35,6 +35,7 @@


def make_model(
op,
shape,
input_0_dtype,
input_1_dtype,
@@ -47,7 +48,7 @@ def make_model(
):
"""Create a Relay Function / network model"""

return relay.qnn.op.mul(
return op(
relay.var("input_0", shape=shape, dtype=input_0_dtype),
relay.var("input_1", shape=shape, dtype=input_1_dtype),
relay.const(input_0_scale, "float32"),
@@ -60,27 +61,32 @@


@skip_if_no_reference_system
@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize(
[
"input_0_scale",
"input_0_zero_point",
"input_1_scale",
"input_1_zero_point",
"output_tolerance",
],
[[0.256, 33, 0.256, 33, 0], [0.0128, -64, 0.0128, -64, 1], [0.0128, -64, 0.256, 33, 0]],
[[0.256, 33, 0.256, 33], [0.0128, -64, 0.0128, -64], [0.0128, -64, 0.256, 33]],
)
def test_mul_int8(
input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point, output_tolerance
):
def test_op_int8(op, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER

dtype = "int8"
shape = [1, 16, 16, 3]
model = make_model(
shape, dtype, dtype, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
op,
shape,
dtype,
dtype,
input_0_scale,
input_0_zero_point,
input_1_scale,
input_1_zero_point,
)
orig_mod = make_module(model)

@@ -115,21 +121,24 @@ def test_mul_int8(
module=cmsisnn_mod,
inputs=inputs,
outputs=output_list,
output_tolerance=output_tolerance,
output_tolerance=1,
),
test_runner,
interface_api,
use_unpacked_api,
)


@pytest.mark.parametrize("op", [relay.qnn.op.mul, relay.qnn.op.add])
@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
def test_invalid_parameters(
op,
input_dtype,
):
input_scale = 0.256
input_zero_point = 33
model = make_model(
op,
[1, 16, 16, 3],
input_dtype,
input_dtype,