[CMSIS-NN] Initial operator support for Mul
This is largely as it says on the tin: it adds Mul support to CMSIS-NN.
Mousius committed Sep 30, 2021
1 parent 229eca4 commit cf54890
Showing 6 changed files with 350 additions and 111 deletions.
21 changes: 21 additions & 0 deletions python/tvm/relay/op/contrib/cmsisnn.py
@@ -79,6 +79,27 @@ def check_quantized_softmax(extract):
and dequantize_call.args[0].checked_type.dtype == "int8"
)

def mul_pattern():
"""Matcher for QNN multiplication"""
return is_op("qnn.mul")(
wildcard(),
wildcard(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
)

def check_quantized_mul(extract):
"""Check if multiply is supported by CMSIS-NN."""
return (
extract.args[0].checked_type.dtype == "int8"
and extract.args[1].checked_type.dtype == "int8"
)

return [
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
("cmsisnn.quantized_mul", mul_pattern(), check_quantized_mul),
]
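
The table returned above is what partition_for_cmsisnn walks when carving out subgraphs for offload. A minimal sketch that exercises the new matcher end to end (illustrative shapes and quantization parameters, following the qnn.mul argument order used by the tests below):

import tvm
from tvm import relay
from tvm.relay.op.contrib import cmsisnn

# Build an int8 qnn.mul that the "cmsisnn.quantized_mul" pattern should match.
x = relay.var("input_0", shape=(1, 16, 16, 3), dtype="int8")
y = relay.var("input_1", shape=(1, 16, 16, 3), dtype="int8")
mul = relay.qnn.op.mul(
    x,
    y,
    relay.const(0.0128, "float32"),  # input_0_scale
    relay.const(-64, "int32"),  # input_0_zero_point
    relay.const(0.0128, "float32"),  # input_1_scale
    relay.const(-64, "int32"),  # input_1_zero_point
    relay.const(1.0 / 256, "float32"),  # output_scale
    relay.const(-128, "int32"),  # output_zero_point
)
mod = tvm.IRModule.from_expr(relay.Function([x, y], mul))
partitioned = cmsisnn.partition_for_cmsisnn(mod)
print(partitioned)  # expect a function tagged Compiler="cmsisnn"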
89 changes: 74 additions & 15 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
@@ -32,17 +32,37 @@ namespace relay {
namespace contrib {
namespace cmsisnn {

class RelayToTIR : public MixedModeVisitor {
class RelayToTIRVisitor : public MixedModeVisitor {
public:
explicit RelayToTIR(String func_name) : func_name_(func_name) {}
explicit RelayToTIRVisitor(String func_name) : func_name_(func_name) {}

tir::PrimFunc GetReplacementPrimFunc() { return primfunc_; }

private:
void emit_softmax_tir(const Expr& expr) {
template <typename T>
const T ArgumentToConstantValue(const Expr& arg) {
const ConstantNode* constant_node = arg.as<ConstantNode>();
return static_cast<const T*>(constant_node->data->data)[0];
}

void CreatePrimFuncForExtern(Array<tir::Var> func_signature,
tvm::Array<PrimExpr> call_extern_args) {
Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", func_name_);
dict_attrs.Set("tir.noalias", Bool(true));

tir::Stmt body = tir::Evaluate(
tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), call_extern_args));

primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
}
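
This helper factors out what EmitSoftMax and EmitMul now share: a PrimFunc whose entire body is a single call_extern into the CMSIS-NN C library. A rough Python analogue of the same construction (hypothetical helper name, assuming only the global_symbol and tir.noalias attributes need to be set):

from tvm import tir

def make_extern_primfunc(global_name, extern_name, params, call_args):
    # The whole body is one evaluated call into the external C symbol.
    body = tir.Evaluate(tir.call_extern("int8", extern_name, *call_args))
    func = tir.PrimFunc(params, body)
    func = func.with_attr("global_symbol", global_name)
    return func.with_attr("tir.noalias", True)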

void EmitSoftMax(const Expr& expr) {
auto* quantize_call = expr.as<CallNode>();
auto* softmax_call = quantize_call->args[0].as<CallNode>();
auto* dequant_call = softmax_call->args[0].as<CallNode>();
auto* scale_const = dequant_call->args[1].as<ConstantNode>();
const float quant_scale = static_cast<const float*>(scale_const->data->data)[0];
const float quant_scale = ArgumentToConstantValue<float>(dequant_call->args[1]);

// assuming layout as NHWC
auto shape = quantize_call->type_as<TensorTypeNode>()->shape;
@@ -79,15 +99,51 @@ class RelayToTIR : public MixedModeVisitor {
IntImm(DataType::Int(32), num_rows), IntImm(DataType::Int(32), row_size),
IntImm(DataType::Int(32), mult), IntImm(DataType::Int(32), shift),
IntImm(DataType::Int(32), diff_min), out_var};
tir::Stmt body =
tir::Evaluate(tvm::tir::Call(DataType::Int(8), tir::builtin::call_extern(), args));

Map<String, ObjectRef> dict_attrs;
dict_attrs.Set("global_symbol", func_name_);
dict_attrs.Set("tir.noalias", Bool(true));
CreatePrimFuncForExtern(func_signature, args);
}

primfunc_ = tir::PrimFunc(func_signature, body, VoidType(), Map<tir::Var, tir::Buffer>(),
DictAttrs(dict_attrs));
void EmitMul(const Expr& expr) {
auto* mul_call = expr.as<CallNode>();

const float input_0_scale = ArgumentToConstantValue<float>(mul_call->args[2]);
const int32_t input_0_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[3]);
const float input_1_scale = ArgumentToConstantValue<float>(mul_call->args[4]);
const int32_t input_1_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[5]);
const float output_scale = ArgumentToConstantValue<float>(mul_call->args[6]);
const int32_t output_zero_point = ArgumentToConstantValue<int32_t>(mul_call->args[7]);

double quantized_multiplier = static_cast<double>(input_0_scale) *
static_cast<double>(input_1_scale) /
static_cast<double>(output_scale);
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(quantized_multiplier);
int32_t output_multiplier = std::get<0>(mult_shift_pair);
int32_t output_shift = std::get<1>(mult_shift_pair);

PrimExpr tensor_size = mul_call->type_as<TensorTypeNode>()->Size();

tir::Var input_0("input_0", DataType::Handle(8));
tir::Var input_1("input_1", DataType::Handle(8));
tir::Var output("output", DataType::Handle(8));

Array<tir::Var> func_signature{input_0, input_1, output};

tvm::Array<PrimExpr> args = {
tir::StringImm("arm_elementwise_mul_s8"),
input_0,
input_1,
IntImm(DataType::Int(32), -input_0_zero_point),
IntImm(DataType::Int(32), -input_1_zero_point),
output,
IntImm(DataType::Int(32), output_zero_point),
IntImm(DataType::Int(32), output_multiplier),
IntImm(DataType::Int(32), output_shift),
IntImm(DataType::Int(32), std::numeric_limits<int8_t>::min()),
IntImm(DataType::Int(32), std::numeric_limits<int8_t>::max()),
tensor_size,
};

CreatePrimFuncForExtern(func_signature, args);
}
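
The argument list above follows arm_elementwise_mul_s8, which takes the negated input zero points as offsets plus a fixed-point multiplier/shift pair derived from the scale ratio. A NumPy sketch of the arithmetic this sets up (assuming round-to-nearest requantization, not necessarily the library's exact rounding behaviour):

import math

import numpy as np

def fixed_point_multiplier_shift(real_multiplier):
    # Analogue of GetFixedPointMultiplierShift: a Q0.31 significand and a
    # power-of-two shift with real_multiplier ~= fixed * 2**(shift - 31).
    if real_multiplier == 0.0:
        return 0, 0
    significand, shift = math.frexp(real_multiplier)
    fixed = int(round(significand * (1 << 31)))
    if fixed == (1 << 31):  # rounding can push the significand up to 1.0
        fixed //= 2
        shift += 1
    return fixed, shift

def elementwise_mul_s8(x, y, zp_x, zp_y, zp_out, scale_x, scale_y, scale_out):
    mult, shift = fixed_point_multiplier_shift(scale_x * scale_y / scale_out)
    # The kernel adds the offsets it receives, i.e. the negated zero points.
    acc = (x.astype(np.int64) - zp_x) * (y.astype(np.int64) - zp_y)
    requant = np.rint(acc * mult * 2.0 ** (shift - 31)).astype(np.int64)
    return np.clip(requant + zp_out, -128, 127).astype(np.int8)

With both input scales at 0.0128 and zero points at -64, this corresponds to the second parametrized case in test_mul.py below.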

void VisitExpr_(const CallNode* call) final {
@@ -98,7 +154,10 @@ class RelayToTIR : public MixedModeVisitor {

auto comp_name = func->GetAttr<String>(attr::kComposite);
if (comp_name.defined() && comp_name == "cmsisnn.quantized_softmax") {
emit_softmax_tir(func->body);
EmitSoftMax(func->body);
}
if (comp_name.defined() && comp_name == "cmsisnn.quantized_mul") {
EmitMul(func->body);
}
}

@@ -119,12 +178,12 @@ IRModule GenerateTIR(IRModule mod) {
}

// Prepare PrimFunc from Relay Function
auto relay_to_tir = RelayToTIR(func_name);
auto relay_to_tir = RelayToTIRVisitor(func_name);
relay_to_tir.VisitExpr(func->body);

// Build the TIR IRModule from the generated PrimFunc
Map<GlobalVar, BaseFunc> var_func_map;
var_func_map.Set(GlobalVar(func_name), relay_to_tir.primfunc_);
var_func_map.Set(GlobalVar(func_name), relay_to_tir.GetReplacementPrimFunc());
return IRModule(var_func_map);
}

154 changes: 154 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_mul.py
@@ -0,0 +1,154 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""CMSIS-NN integration tests: mul"""

import sys

import numpy as np
import pytest

from tvm import relay
from tvm.relay.op.contrib import cmsisnn

from utils import skip_if_no_reference_system, make_module, count_num_calls, get_range_for_dtype_str
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
generate_ref_data,
compile_and_run,
)


def make_model(
shape,
input_0_dtype,
input_1_dtype,
input_0_scale,
input_0_zero_point,
input_1_scale,
input_1_zero_point,
out_scale=1.0 / 256,
out_zero_point=-128,
):
"""Create a Relay Function / network model"""

return relay.qnn.op.mul(
relay.var("input_0", shape=shape, dtype=input_0_dtype),
relay.var("input_1", shape=shape, dtype=input_1_dtype),
relay.const(input_0_scale, "float32"),
relay.const(input_0_zero_point, "int32"),
relay.const(input_1_scale, "float32"),
relay.const(input_1_zero_point, "int32"),
relay.const(out_scale, "float32"),
relay.const(out_zero_point, "int32"),
)


@skip_if_no_reference_system
@pytest.mark.parametrize(
[
"input_0_scale",
"input_0_zero_point",
"input_1_scale",
"input_1_zero_point",
"output_tolerance",
],
[[0.256, 33, 0.256, 33, 0], [0.0128, -64, 0.0128, -64, 1], [0.0128, -64, 0.256, 33, 0]],
)
def test_mul_int8(
input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point, output_tolerance
):
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_CORSTONE300_RUNNER

dtype = "int8"
shape = [1, 16, 16, 3]
model = make_model(
shape, dtype, dtype, input_0_scale, input_0_zero_point, input_1_scale, input_1_zero_point
)
orig_mod = make_module(model)

cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

# validate pattern matching
attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert any(attrs), "At least one function with external attributes was expected."

compilers = [
key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items()
]
assert any(compilers), "Module does not contain function for cmsisnn target."

assert count_num_calls(orig_mod) == count_num_calls(
cmsisnn_mod
), "Number of calls changed during partitioning"

# validate the output
in_min, in_max = get_range_for_dtype_str(dtype)
inputs = {
"input_0": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
"input_1": np.random.randint(in_min, high=in_max, size=shape, dtype=dtype),
}
output_list = generate_ref_data(orig_mod["main"], inputs)
compile_and_run(
AOTTestModel(
module=cmsisnn_mod,
inputs=inputs,
outputs=output_list,
output_tolerance=output_tolerance,
),
test_runner,
interface_api,
use_unpacked_api,
)


@pytest.mark.parametrize(["input_dtype"], [["uint8"], ["int16"]])
def test_invalid_parameters(
input_dtype,
):
input_scale = 0.256
input_zero_point = 33
model = make_model(
[1, 16, 16, 3],
input_dtype,
input_dtype,
input_scale,
input_zero_point,
input_scale,
input_zero_point,
)

orig_mod = make_module(model)
cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod)

attrs = [
cmsisnn_mod[var.name_hint].attrs
for var in cmsisnn_mod.get_global_vars()
if cmsisnn_mod[var.name_hint].attrs
]
assert not any(attrs), "No function should have an external attribute."


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
40 changes: 6 additions & 34 deletions tests/python/contrib/test_cmsisnn/test_networks.py
@@ -17,18 +17,16 @@

"""CMSIS-NN: testing with networks"""

import platform
import sys
import os
import pathlib
import tvm

import numpy as np
import pytest

from tvm import relay
from tvm.contrib.download import download_testdata
from tvm.relay.op.contrib import cmsisnn
import numpy as np
import pytest
import itertools

from utils import skip_if_no_reference_system, get_range_for_dtype_str
from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_CORSTONE300_RUNNER,
@@ -37,30 +35,6 @@
)


def get_range_for_dtype_str(dtype):
"""
Produce the min,max for a give data type.
Parameters
----------
dtype : str
a type string (e.g., int8)
Returns
-------
type_info.min : int
the minimum of the range
type_info.max : int
the maximum of the range
"""

try:
type_info = np.iinfo(dtype)
except ValueError:
type_info = np.finfo(dtype)
return type_info.min, type_info.max


def convert_to_relay(
tflite_model_buf,
input_data,
@@ -99,9 +73,7 @@ def convert_to_list(x):
return mod, params


@pytest.mark.skipif(
platform.machine() == "i686", reason="Reference system unavailable in i386 container"
)
@skip_if_no_reference_system
def test_cnn_small():
# download the model
base_url = "https://github.com/ARM-software/ML-zoo/raw/master/models/keyword_spotting/cnn_small/tflite_int8"
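
The inline platform check is replaced by the shared skip_if_no_reference_system marker imported from utils, the same one used by test_mul.py above. Its definition is outside this diff; given the check it replaces, a plausible reconstruction:

import platform

import pytest

skip_if_no_reference_system = pytest.mark.skipif(
    platform.machine() == "i686",
    reason="Reference system unavailable in i386 container",
)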