diff --git a/python/tvm/relay/op/contrib/cmsisnn.py b/python/tvm/relay/op/contrib/cmsisnn.py new file mode 100644 index 000000000000..daf1e098d7f1 --- /dev/null +++ b/python/tvm/relay/op/contrib/cmsisnn.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Arm(R) CMSIS-NN supported operators for Cortex-M.""" +import tvm.ir +from tvm.relay import transform +from tvm.relay.build_module import bind_params_by_name + +from ...dataflow_pattern import is_constant, is_op, wildcard +from .register import register_pattern_table + + +def partition_for_cmsisnn(mod, params=None, **opts): + """Partition the graph greedily offloading supported + operators on Cortex-M using CMSIS-NN + + Parameters + ---------- + mod : Module + The module to run passes on. + params : Optional[Dict[str, NDArray]] + Constant input parameters. + + Returns + ------- + ret : Module + annotated and partitioned module. + """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + seq = tvm.transform.Sequential( + [ + transform.InferType(), + transform.MergeComposite(pattern_table()), + transform.AnnotateTarget("cmsisnn"), + transform.MergeCompilerRegions(), + transform.PartitionGraph(), + ] + ) + + return seq(mod) + + +@register_pattern_table("cmsisnn") +def pattern_table(): + """Get the cmsisnn compiler pattern table.""" + + def softmax_pattern(): + pattern = is_op("qnn.dequantize")(wildcard(), is_constant(), is_constant()) + pattern = is_op("nn.softmax")(pattern) + pattern = is_op("qnn.quantize")(pattern, is_constant(), is_constant()) + return pattern + + def check_quantized_softmax(extract): + """Check if softmax is supported by CMSIS-NN.""" + + # check for dtypes of quantize and dequantize + return ( + extract.attrs.out_dtype == "int8" + and extract.args[0].args[0].args[0].checked_type.dtype == "int8" + ) + + return [ + ("cmsisnn.qnn_softmax", softmax_pattern(), check_quantized_softmax), + ] diff --git a/tests/python/contrib/test_cmsisnn/test_softmax.py b/tests/python/contrib/test_cmsisnn/test_softmax.py new file mode 100644 index 000000000000..afbc302af66f --- /dev/null +++ b/tests/python/contrib/test_cmsisnn/test_softmax.py @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CMSIS-NN integration tests: softmax""" + +import pytest +import sys + +import tvm +from tvm import relay +from tvm.relay.op.contrib import cmsisnn +import numpy as np + + +def count_num_calls(mod): + class CallCounter(relay.ExprVisitor): + def __init__(self): + super().__init__() + self.count = 0 + + def visit_call(self, call): + if isinstance(call.op, tvm.ir.Op): + self.count += 1 + + super().visit_call(call) + + counter = CallCounter() + for var in mod.get_global_vars(): + counter.visit(mod[var.name_hint]) + return counter.count + + +def make_module(func): + func = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(func) + return relay.transform.InferType()(mod) + + +def make_model(shape, zero_point, scale, in_dtype, out_dtype): + a = relay.var("a", shape=shape, dtype=in_dtype) + dequantize = relay.qnn.op.dequantize( + a, + input_scale=relay.const(scale, "float32"), + input_zero_point=relay.const(zero_point, "int32"), + ) + softmax = relay.nn.softmax(dequantize) + model = relay.qnn.op.quantize( + softmax, + output_scale=relay.const(scale, "float32"), + output_zero_point=relay.const(zero_point, "int32"), + out_dtype=out_dtype, + ) + return model + + +def test_softmax_int8(): + model = make_model([1, 16, 16, 3], 64, 0.02, "int8", "int8") + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert any(attrs), "At least one function with external attributes was expected." + + compilers = [ + key == "Compiler" and value == "cmsisnn" for attr in attrs for key, value in attr.items() + ] + assert any(compilers), "Module does not contain function for cmsisnn target." + + assert count_num_calls(orig_mod) == count_num_calls( + cmsisnn_mod + ), "Number of calls changed during partitioning" + + +@pytest.mark.parametrize("in_dtype,out_dtype", [["uint8", "int8"], ["int8", "uint8"]]) +def test_softmax_not_int8(in_dtype, out_dtype): + model = make_model([1, 16, 16, 3], 64, 0.02, in_dtype, out_dtype) + orig_mod = make_module(model) + cmsisnn_mod = cmsisnn.partition_for_cmsisnn(orig_mod) + + attrs = [ + cmsisnn_mod[var.name_hint].attrs + for var in cmsisnn_mod.get_global_vars() + if cmsisnn_mod[var.name_hint].attrs + ] + assert not any(attrs), "No function should have an external attribute." + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))