Skip to content

Commit

Permalink
[CMSIS-NN] Fallback to native schedules for unsupported partiti ned f…
Browse files Browse the repository at this point in the history
…unctions (apache#10603)
  • Loading branch information
ashutosh-arm authored Mar 18, 2022
1 parent 7233c29 commit 3ceae5f
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 3 deletions.
24 changes: 21 additions & 3 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,14 @@ class RelayToTIRVisitor : public MixedModeMutator {
buffer_creator.GetBufferMap(), args);
}

// Removes kCompiler attribute from the partitioned functions that are not supported by this
// RelayToTIR
Call CallToFuncWithoutCompilerAttr(GlobalVar new_global_var, Call call, Function func) {
Function new_func = WithoutAttr(std::move(func), attr::kCompiler);
ir_module_->Update(new_global_var, new_func);
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);
}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const CallNode* call = post.as<CallNode>()) {
auto* func = call->op.as<FunctionNode>();
Expand All @@ -657,11 +665,21 @@ class RelayToTIRVisitor : public MixedModeMutator {
auto codegen_name = func->GetAttr<String>(attr::kCompiler);
if (codegen_name.defined() && codegen_name == "cmsis-nn") {
const CallNode* inner_call = func->body.as<CallNode>();
auto global_func_name = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
GlobalVar new_global_var(global_func_name.value());

if (!inner_call) {
return CallToFuncWithoutCompilerAttr(new_global_var, GetRef<Call>(call),
GetRef<Function>(func));
}

const FunctionNode* composite_func = inner_call->op.as<FunctionNode>();
auto comp_name = composite_func->GetAttr<String>(attr::kComposite);
auto func_name = func->GetAttr<String>(::tvm::attr::kGlobalSymbol);
if (!composite_func) {
return CallToFuncWithoutCompilerAttr(new_global_var, GetRef<Call>(call),
GetRef<Function>(func));
}

GlobalVar new_global_var(func_name.value());
auto comp_name = composite_func->GetAttr<String>(attr::kComposite);
new_global_var->checked_type_ = composite_func->checked_type();

if (comp_name == "cmsis-nn.qnn_softmax") {
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor
[&](const Array<tvm::tir::Var>&) {
if (dtype == DataType::Int(16)) {
return make_const(dtype, static_cast<const int16_t*>(data)[0]);
} else if (dtype == DataType::Int(8)) {
return make_const(dtype, static_cast<const int8_t*>(data)[0]);
} else if (dtype == DataType::Int(32)) {
return make_const(dtype, static_cast<const int32_t*>(data)[0]);
} else if (dtype == DataType::Int(64)) {
Expand Down
83 changes: 83 additions & 0 deletions tests/python/contrib/test_cmsisnn/test_invalid_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""CMSIS-NN integration tests: Tests invalid graphs"""
import itertools
import numpy as np
import pytest
import tvm
from tvm import relay


from tests.python.relay.aot.aot_test_utils import (
AOTTestModel,
AOT_USMP_CORSTONE300_RUNNER,
generate_ref_data,
compile_and_run,
)
from utils import (
skip_if_no_reference_system,
get_range_for_dtype_str,
)


@skip_if_no_reference_system
@tvm.testing.requires_cmsisnn
def test_empty_function():
ORIGINAL_MODEL = """
#[version = "0.0.5"]
def @main(%data : Tensor[(16, 29), int8]) -> Tensor[(16, 29), int8] {
add(%data, %data)
}
"""
CMSISNN_MODEL = """
#[version = "0.0.5"]
def @tvmgen_default_cmsis_nn_main_1(%i1: Tensor[(16, 29), int8], Inline=1, Compiler="cmsis-nn", global_symbol="tvmgen_default_cmsis_nn_main_1", Primitive=1) -> Tensor[(16, 29), int8] {
add(%i1, %i1)
}
def @main(%data : Tensor[(16, 29), int8]) -> Tensor[(16, 29), int8] {
%1 = @tvmgen_default_cmsis_nn_main_1(%data) /* ty=Tensor[(16, 29), int8] */;
%1
}
"""
orig_mod = tvm.parser.fromtext(ORIGINAL_MODEL)
cmsisnn_mod = tvm.parser.fromtext(CMSISNN_MODEL)
params = {}

# validate the output
interface_api = "c"
use_unpacked_api = True
test_runner = AOT_USMP_CORSTONE300_RUNNER
dtype = "int8"
in_min, in_max = get_range_for_dtype_str(dtype)
rng = np.random.default_rng(12345)
inputs = {"data": rng.integers(in_min, high=in_max, size=(16, 29), dtype=dtype)}
outputs = generate_ref_data(orig_mod["main"], inputs, params)
compile_and_run(
AOTTestModel(
module=cmsisnn_mod,
inputs=inputs,
outputs=outputs,
params=params,
output_tolerance=0,
),
test_runner,
interface_api,
use_unpacked_api,
verbose=1,
test_dir="./test",
)

0 comments on commit 3ceae5f

Please sign in to comment.