Skip to content

Commit

Permalink
[CMSIS-NN] Initial operator support for Mul (#9163)
Browse files Browse the repository at this point in the history
This is largely as it says on the tin, it adds Mul support to CMSIS-NN
  • Loading branch information
Mousius authored Oct 1, 2021
1 parent 574a47c commit 6d120c0
Show file tree
Hide file tree
Showing 6 changed files with 350 additions and 111 deletions.
21 changes: 21 additions & 0 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,27 @@ def check_quantized_softmax(extract):
and dequantize_call.args[0].checked_type.dtype == "int8"
)

def mul_pattern():
"""Matcher for QNN multiplication"""
return is_op("qnn.mul")(
wildcard(),
wildcard(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
)

def check_quantized_mul(extract):
"""Check if multiply is supported by CMSIS-NN."""
return (
extract.args[0].checked_type.dtype == "int8"
and extract.args[1].checked_type.dtype == "int8"
)

return [
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
("cmsisnn.quantized_mul", mul_pattern(), check_quantized_mul),
]
89 changes: 74 additions & 15 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,37 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

class RelayToTIR : public MixedModeVisitor {
class RelayToTIRVisitor : public MixedModeVisitor {
public:
explicit RelayToTIR(String func_name) : func_name_(func_name) {}
explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}

tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }

private:
void emit_softmax_tir(const Expr& expr) {
template <typename T>
const T ArgumentToConstantValue(const Expr& arg) {
const ConstantNode* constant_node = arg.as<ConstantNode>();
return static_cast<const T*>(constant_node->data->data)[0];
}

void CreatePrimFuncForExtern(Array<tir::Var> func_signature,
tvm::Array<PrimExpr> call_extern_args) {
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", func_name_);
dict_attrs.Set("tir.noalias", Bool(true));

tir::Stmt body = tir::Evaluate(
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));

primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
}

void EmitSoftMax(const Expr& expr) {
auto* quantize_call = expr.as<CallNode>();
auto* softmax_call = quantize_call->args[0].as<CallNode>();
auto* dequant_call = softmax_call->args[0].as<CallNode>();
auto* scale_const = dequant_call->args[1].as<ConstantNode>();
const float quant_scale = static_cast<const float*>(scale_const->data->data)[0];
const float quant_scale = ArgumentToConstantValue<float>(dequant_call->args[1]);

// assuming layout as NHWC
auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
Expand Down Expand Up @@ -79,15 +99,51 @@ class RelayToTIR : public MixedModeVisitor {
IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size),
IntImm(DataType::Int(32), mult), IntImm(DataType::Int(32), shift),
IntImm(DataType::Int(32), diff_min), out_var};
tir::Stmt body =
tir::Evaluate(tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), args));

Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", func_name_);
dict_attrs.Set("tir.noalias", Bool(true));
CreatePrimFuncForExtern(func_signature, args);
}

primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
void EmitMul(const Expr& expr) {
auto* mul_call = expr.as<CallNode>();

const float input_0_scale = ArgumentToConstantValue<float>(mul_call->args[2]);
const int32_t input_0_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[3]);
const float input_1_scale = ArgumentToConstantValue<float>(mul_call->args[4]);
const int32_t input_1_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[5]);
const float output_scale = ArgumentToConstantValue<float>(mul_call->args[6]);
const int32_t output_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[7]);

double quantized_multiplier = static_cast<double>(input_0_scale) *
static_cast<double>(input_1_scale) /
static_cast<double>(output_scale);
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
int32_t output_multiplier = std::get<0>(mult_shift_pair);
int32_t output_shift = std::get<1>(mult_shift_pair);

PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();

tir::Var input_0("input_0", DataType::Handle(8));
tir::Var input_1("input_1", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));

Array<tir::Var> func_signature{input_0, input_1, output};

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_mul_s8"),
input_0,
input_1,
IntImm(DataType::Int(32), -input_0_zero_point),
IntImm(DataType::Int(32), -input_1_zero_point),
output,
IntImm(DataType::Int(32), output_zero_point),
IntImm(DataType::Int(32), output_multiplier),
IntImm(DataType::Int(32), output_shift),
IntImm(DataType::Int(32), std::numeric_limits<int8_t>::min()),
IntImm(DataType::Int(32), std::numeric_limits<int8_t>::max()),
tensor_size,
};

CreatePrimFuncForExtern(func_signature, args);
}

void VisitExpr_(const CallNode* call) final {
Expand All @@ -98,7 +154,10 @@ class RelayToTIR : public MixedModeVisitor {

auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") {
emit_softmax_tir(func->body);
EmitSoftMax(func->body);
}
if (comp_name.defined() && comp_name == "cmsisnn.quantized_mul") {
EmitMul(func->body);
}
}

Expand All @@ -119,12 +178,12 @@ IRModule GenerateTIR(IRModule mod) {
}

// Prepare PrimFunc from Relay Function
auto relay_to_tir = RelayToTIR(func_name);
auto relay_to_tir = RelayToTIRVisitor(func_name);
relay_to_tir.VisitExpr(func->body);

// Build the TIR IRModule from the generated PrimFunc
Map<GlobalVar, BaseFunc> var_func_map;
var_func_map.Set(GlobalVar(func_name), relay_to_tir.primfunc_);
var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc());
return IRModule(var_func_map);
}

Expand Down
154 changes: 154 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_mul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""CMSIS-NN integration tests: mul"""

import sys

import numpy as np
import pytest

from tvm import relay
from tvm.relay.op.contrib import cmsisnn

from utils import skip_if_no_reference_system, make_module, count_num_calls, get_range_for_dtype_str
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
generate_ref_data,
compile_and_run,
)


def make_model(
shape,
input_0_dtype,
input_1_dtype,
input_0_scale,
input_0_zero_point,
input_1_scale,
input_1_zero_point,
out_scale=1.0 / 256,
out_zero_point=-128,
):
"""Create a Relay Function / network model"""

return relay.qnn.op.mul(
relay.var("input_0", shape=shape, dtype=input_0_dtype),
relay.var("input_1", shape=shape, dtype=input_1_dtype),
relay.const(input_0_scale, "float32"),
relay.const(input_0_zero_point, "int32"),
relay.const(input_1_scale, "float32"),
relay.const(input_1_zero_point, "int32"),
relay.const(out_scale, "float32"),
relay.const(out_zero_point, "int32"),
)


@skip_if_no_reference_system
@pytest.mark.parametrize(
[
"input_0_scale",
"input_0_zero_point",
"input_1_scale",
"input_1_zero_point",
"output_tolerance",
],
[[0.256, 33, 0.256, 33, 0], [0.0128, -64, 0.0128, -64, 1], [0.0128, -64, 0.256, 33, 0]],
)
def test_mul_int8(
input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point, output_tolerance
):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER

dtype = "int8"
shape = [1, 16, 16, 3]
model = make_model(
shape, dtype, dtype, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
)
orig_mod = make_module(model)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsisnn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"

# validate the output
in_min, in_max = get_range_for_dtype_str(dtype)
inputs = {
"input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
"input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
}
output_list = generate_ref_data(orig_mod["main"], inputs)
compile_and_run(
AOTTestModel(
module=cmsisnn_mod,
inputs=inputs,
outputs=output_list,
output_tolerance=output_tolerance,
),
test_runner,
interface_api,
use_unpacked_api,
)


@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
def test_invalid_parameters(
input_dtype,
):
input_scale = 0.256
input_zero_point = 33
model = make_model(
[1, 16, 16, 3],
input_dtype,
input_dtype,
input_scale,
input_zero_point,
input_scale,
input_zero_point,
)

orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert not any(attrs), "No function should have an external attribute."


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
40 changes: 6 additions & 34 deletions tests/python/contrib/test_cmsisnn/test_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,16 @@

"""CMSIS-NN: testing with networks"""

import platform
import sys
import os
import pathlib
import tvm

import numpy as np
import pytest

from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib import cmsisnn
import numpy as np
import pytest
import itertools

from utils import skip_if_no_reference_system, get_range_for_dtype_str
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
Expand All @@ -37,30 +35,6 @@
)


def get_range_for_dtype_str(dtype):
"""
Produce the min,max for a give data type.
Parameters
----------
dtype : str
a type string (e.g., int8)
Returns
-------
type_info.min : int
the minimum of the range
type_info.max : int
the maximum of the range
"""

try:
type_info = np.iinfo(dtype)
except ValueError:
type_info = np.finfo(dtype)
return type_info.min, type_info.max


def convert_to_relay(
tflite_model_buf,
input_data,
Expand Down Expand Up @@ -99,9 +73,7 @@ def convert_to_list(x):
return mod, params


@pytest.mark.skipif(
platform.machine() == "i686", reason="Reference system unavailable in i386 container"
)
@skip_if_no_reference_system
def test_cnn_small():
# download the model
base_url = "https://github.com/ARM-software/ML-zoo/raw/master/models/keyword_spotting/cnn_small/tflite_int8"
Expand Down
Loading

0 comments on commit 6d120c0

Please sign in to comment.