From 5cd99b032820bb162f19412f6e5ab095ac39048e Mon Sep 17 00:00:00 2001 From: Nicola Lancellotti Date: Thu, 11 Nov 2021 17:27:40 +0000 Subject: [PATCH] Arm(R) Ethos(TM)-U NPU BinaryElementwise operators support (#9442) This commit adds support for the binary elementwise primitive operators for the Arm(R) Ethos(TM)-U NPU and includes a few minor rewording changes. --- .../relay/backend/contrib/ethosu/legalize.py | 227 +++++++++ .../backend/contrib/ethosu/op/__init__.py | 1 + .../contrib/ethosu/op/binary_elementwise.py | 215 +++++++++ .../backend/contrib/ethosu/op/convolution.py | 4 +- .../backend/contrib/ethosu/op/depthwise.py | 7 +- .../backend/contrib/ethosu/op/pooling.py | 6 +- .../backend/contrib/ethosu/te/__init__.py | 1 + .../contrib/ethosu/te/binary_elementwise.py | 184 ++++++++ .../contrib/ethosu/tir/binary_elementwise.py | 102 ++++ .../backend/contrib/ethosu/tir/passes.py | 2 + .../relay/backend/contrib/ethosu/tir/spec.py | 21 + .../contrib/ethosu/tir_to_cs_translator.py | 76 ++- .../tvm/relay/backend/contrib/ethosu/util.py | 15 + python/tvm/relay/op/contrib/ethosu.py | 352 +++++++++++++- .../op/contrib/ethosu/binary_elementwise.cc | 301 ++++++++++++ src/relay/op/contrib/ethosu/common.cc | 18 + src/relay/op/contrib/ethosu/common.h | 11 + src/relay/op/contrib/ethosu/pooling.cc | 2 +- tests/python/contrib/test_ethosu/infra.py | 53 +++ .../contrib/test_ethosu/test_codegen.py | 252 +++++++++- .../contrib/test_ethosu/test_legalize.py | 188 ++++++++ .../test_replace_binary_elementwise.py | 335 ++++++++++++++ .../test_ethosu/test_tir_to_cs_translator.py | 434 ++++++++++++++++++ .../test_ethosu/test_type_inference.py | 116 +++++ 24 files changed, 2902 insertions(+), 21 deletions(-) create mode 100644 python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py create mode 100644 python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py create mode 100644 src/relay/op/contrib/ethosu/binary_elementwise.cc create mode 100644 tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index c4b70c130d4ef..d0d04cebaefe4 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -413,6 +413,224 @@ def __call__(self, *args, **kwargs): pass +class BinaryElementwiseRewriter(DFPatternCallback): + """Convert ethosu binary elementwise composite functions to + ethosu_binary_elementwise operators""" + + def __init__( + self, + params_class: Type, + pattern: CallPattern, + ): + super().__init__(require_type=True) + self.params_class = params_class + self.pattern = pattern + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = self.params_class(post.op.body) + params.ifm.tensor = post.args[1] if params.reversed_operands else post.args[0] + params.ifm2.tensor = post.args[0] if params.reversed_operands else post.args[1] + channels_map = { + "NHWC": 3, + } + if str(params.ofm.layout) not in channels_map.keys(): + raise UnsupportedLayout(str(params.ofm.layout)) + + activation_map = {"clip": "CLIP"} + if params.activation: + activation = activation_map[params.activation.op.name] + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + else: + activation = "NONE" + clip_min = 0 + clip_max = 0 + + # We don't 
yet support activation functions that need to get legalized to LUTs. + lut = relay.const([], dtype="int8") + + return ethosu_ops.ethosu_binary_elementwise( + ifm=params.ifm.tensor, + ifm2=params.ifm2.tensor, + lut=lut, + operator_type=params.operator_type, + ifm_scale=float(params.ifm.q_params.scale_f32), + ifm_zero_point=int(params.ifm.q_params.zero_point), + ifm2_scale=float(params.ifm2.q_params.scale_f32), + ifm2_zero_point=int(params.ifm2.q_params.zero_point), + ofm_scale=float(params.ofm.q_params.scale_f32), + ofm_zero_point=int(params.ofm.q_params.zero_point), + ifm_channels=params.ifm.shape[3], + ifm2_channels=params.ifm2.shape[3], + reversed_operands=params.reversed_operands, + ofm_dtype=params.ofm.dtype, + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + ifm_layout=str(params.ifm.layout), + ifm2_layout=str(params.ifm2.layout), + ofm_layout=str(params.ofm.layout), + ) + + +class AddRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.AddParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.AddParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeAdd: + """This is the pass that wraps the AddRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(AddRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class SubRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.SubParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.SubParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeSub: + """This is the pass that wraps the SubRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(SubRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class MulRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MulParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MulParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeMul: + """This is the pass that wraps the MulRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(MulRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class MinRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MinParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MinParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeMin: + """This is the pass that wraps the MinRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(MinRewriter(), func) + mod.update_func(global_var, func) + return mod + + 
def __call__(self, *args, **kwargs): + pass + + +class MaxRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MaxParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.MaxParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeMax: + """This is the pass that wraps the MaxRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(MaxRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class ShlRewriter(BinaryElementwiseRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.ShlParams, + pattern=(wildcard().has_attr({"Composite": ethosu_patterns.ShlParams.composite_name}))( + wildcard(), wildcard() + ), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeShl: + """This is the pass that wraps the ShlRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(ShlRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + @ir.transform.module_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation @@ -423,11 +641,20 @@ class LegalizeEthosU: def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: + """This is the method that replaces the operations with hardware/codegen supported + operations. + """ mod = LegalizeSplit()(mod) mod = LegalizeConv2D()(mod) mod = LegalizeDepthwiseConv2D()(mod) mod = LegalizeMaxPooling()(mod) mod = LegalizeAvgPooling()(mod) + mod = LegalizeAdd()(mod) + mod = LegalizeSub()(mod) + mod = LegalizeMul()(mod) + mod = LegalizeMin()(mod) + mod = LegalizeMax()(mod) + mod = LegalizeShl()(mod) return mod def __call__(self, *args, **kwargs): diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py index c9aa59b5dd2b0..05d4053045891 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py @@ -19,3 +19,4 @@ from .convolution import ethosu_conv2d from .depthwise import ethosu_depthwise_conv2d from .pooling import ethosu_pooling +from .binary_elementwise import ethosu_binary_elementwise diff --git a/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py new file mode 100644 index 0000000000000..d4ae18b529740 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/binary_elementwise.py @@ -0,0 +1,215 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +"""Relay operators for binary elementwise operators for Arm(R) Ethos(TM)-U NPU""" +from typing import Optional +import tvm +from tvm.relay.op import _make +from tvm.topi.generic import schedule_injective +from tvm.relay.op.op import OpStrategy +from tvm.relay.op import strategy as _strategy + +from ..te import binary_elementwise_compute + + +def _extract_ethosu_binary_elementwise_params(attrs, args): + """Get the parameters necessary to construct a ethosu_binary_elementwise compute TE + from a ethosu_binary_elementwise Relay call.""" + ifm = args[0] + ifm2 = args[1] + lut = args[2] + operator_type = attrs.operator_type + ifm_scale = attrs.ifm_scale + ifm_zero_point = attrs.ifm_zero_point + ifm2_scale = attrs.ifm2_scale + ifm2_zero_point = attrs.ifm2_zero_point + ofm_scale = attrs.ofm_scale + ofm_zero_point = attrs.ofm_zero_point + ifm_channels = attrs.ifm_channels + ifm2_channels = attrs.ifm2_channels + reversed_operands = attrs.reversed_operands + activation = attrs.activation + clip_min = attrs.clip_min + clip_max = attrs.clip_max + ifm_layout = attrs.ifm_layout + ifm2_layout = attrs.ifm2_layout + ofm_layout = attrs.ofm_layout + ofm_dtype = attrs.ofm_dtype + + return ( + ifm, + ifm2, + lut, + operator_type, + ifm_scale, + ifm_zero_point, + ifm2_scale, + ifm2_zero_point, + ofm_scale, + ofm_zero_point, + ifm_channels, + ifm2_channels, + reversed_operands, + activation, + clip_min, + clip_max, + ifm_layout, + ifm2_layout, + ofm_layout, + ofm_dtype, + ) + + +@tvm.ir.register_op_attr("contrib.ethosu.binary_elementwise", "FTVMCompute") +def create_ethosu_binary_elementwise_compute(attrs, args, out_type): + """Create an ethosu_binary_elementwise compute op.""" + params = _extract_ethosu_binary_elementwise_params(attrs, args) + op = binary_elementwise_compute(*params) + return [op] + + +@tvm.ir.register_op_attr("contrib.ethosu.binary_elementwise", "FTVMStrategy") +def binary_elementwise_strategy_ethosu(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + create_ethosu_binary_elementwise_compute, + _strategy.wrap_topi_schedule(schedule_injective), + name="ethosu_binary_elementwise", + ) + return strategy + + +def ethosu_binary_elementwise( + ifm: tvm.relay.Expr, + ifm2: tvm.relay.Expr, + lut: tvm.relay.Expr, + operator_type: str, + ifm_scale: float, + ifm_zero_point: int, + ifm2_scale: float, + ifm2_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + ifm_channels: int, + ifm2_channels: int, + reversed_operands: bool, + ofm_dtype: str, + activation: Optional[str] = "NONE", + clip_min: Optional[int] = 0, + clip_max: Optional[int] = 0, + ifm_layout: Optional[str] = "NHWC", + ifm2_layout: Optional[str] = "NHWC", + ofm_layout: Optional[str] = "NHWC", +) -> tvm.relay.Call: + """This is a quantized binary elementwise operation as supported by + the NPU. It accepts either NHWC or NHCWB16 format + for the input data. + + Parameters + ---------- + ifm : tvm.relay.Expr + The Input Feature Map tensor (IFM). + ifm2 : tvm.relay.Expr + The Input Feature Map tensor 2 (IFM2). 
+ lut : tvm.relay.Expr + The look-up table of values to use if activation = "LUT". + operator_type: str + The type of the binary elementwise operator. + "ADD" + "SUB" + "MUL" + "MIN" + "MAX" + "SHR" + "SHL" + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ifm2_scale : float + The quantization scale for the Input Feature Map tensor 2. + ifm2_zero_point : int + The quantization zero point for the Input Feature Map tensor 2. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + ifm_channels : int + The number of the Input Feature Map channels. + ifm2_channels : int + The number of the Input Feature Map 2 channels. + reversed_operands : bool + True if IFM2 is the first operand and IFM is the second operand. + ofm_dtype: str + The Output Feature Map tensor type. + MUL, ADD, SUB {IFM}->{OFM}: + {uint8, int8 int32} -> {uint8, int8, int32}, any pairing + MAX, MIN: + IFM and OFM must be of the same type, one of: + {int8, uint8} + SHR {IFM}->{OFM}: + {int32}->{int8, uint8, int32}, any pairing" + SHL: + {int32}->{int32} only + activation : str, optional + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + Available activations for activation type: + {int8, uint8}: "NONE", "CLIP", "TANH", "SIGMOID", "LUT" + {int32}: "NONE" + clip_min : int, optional + The minimum clipping value if activation = "CLIP". + clip_max : int, optional + The maximum clipping value if activation = "CLIP". + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm2_layout : str, optional + The layout of the Input Feature Map tensor 2. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + out : tvm.relay.Call + A call to the ethosu_binary_elementwise op. + """ + return _make.ethosu_binary_elementwise( + ifm, + ifm2, + lut, + operator_type, + ifm_scale, + ifm_zero_point, + ifm2_scale, + ifm2_zero_point, + ofm_scale, + ofm_zero_point, + ifm_channels, + ifm2_channels, + reversed_operands, + activation, + clip_min, + clip_max, + ifm_layout, + ifm2_layout, + ofm_layout, + ofm_dtype, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py index 7fb054edb6b61..970e366e50401 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py @@ -112,8 +112,8 @@ def ethosu_conv2d( ifm_layout: str = "NHWC", ofm_layout: str = "NHWC", ) -> tvm.relay.Call: - """This is a quantized 2D convolution operation as supported by the - Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format + """This is a quantized 2D convolution operation as supported by + the NPU. It accepts either NHWC or NHCWB16 format for the input data and OHWI format for the kernel weights. 
Reference: https://developer.arm.com/documentation/102420/0200/ diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py index d1b49ef6e8988..d8f2e8b3106c8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument -"""Relay operator for depthwise convolution""" +"""Relay operator for depthwise convolution for Arm(R) Ethos(TM)-U NPU""" + from typing import Tuple import tvm @@ -112,8 +113,8 @@ def ethosu_depthwise_conv2d( ifm_layout: str = "NHWC", ofm_layout: str = "NHWC", ) -> tvm.relay.Call: - """This is a quantized 2D depthwise convolution operation as supported by the - Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format + """This is a quantized 2D depthwise convolution operation as supported by + the NPU. It accepts either NHWC or NHCWB16 format for the input data and OHWI format for the kernel weights. Reference: https://developer.arm.com/documentation/102420/0200/ diff --git a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py index f344f61f1dd1b..cc363738c37f6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=unused-argument -"""Relay operators for pooling""" +"""Relay operators for pooling for Arm(R) Ethos(TM)-U NPU""" from typing import Tuple import tvm @@ -107,8 +107,8 @@ def ethosu_pooling( ifm_layout: str = "NHWC", ofm_layout: str = "NHWC", ) -> tvm.relay.Call: - """This is a quantized 2D pooling operation as supported by the - Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format + """This is a quantized 2D pooling operation as supported by + the NPU. It accepts either NHWC or NHCWB16 format for the input data. Parameters diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py index e2eb28f8f9152..5c262362e4f42 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py @@ -19,3 +19,4 @@ from .convolution import * from .depthwise import * from .pooling import * +from .binary_elementwise import * diff --git a/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py new file mode 100644 index 0000000000000..84d4e1b41558f --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/binary_elementwise.py @@ -0,0 +1,184 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Tensor Expressions for binary_elementwise""" +import operator +from tvm import te +from .dma import dma_ofm_compute, dma_ifm_compute + + +def binary_elementwise_compute( + ifm: te.Tensor, + ifm2: te.Tensor, + lut: te.Tensor, + operator_type: str, + ifm_scale: float, + ifm_zero_point: int, + ifm2_scale: float, + ifm2_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + ifm_channels: int, + ifm2_channels: int, + reversed_operands: bool, + activation: str, + clip_min: int, + clip_max: int, + ifm_layout: str, + ifm2_layout: str, + ofm_layout: str, + ofm_dtype: str, +) -> te.Tensor: + """A compute operator representing the capabilities of binary_elementwise for the NPU. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map tensor (IFM). + ifm2 : te.Tensor + The Input Feature Map tensor 2 (IFM2). + lut : te.Tensor + The look-up table values to use if activation = "LUT". + operator_type: str + The type of the binary elementwise operator. + "ADD" + "SUB" + "MUL" + "MIN" + "MAX" + "SHR" + "SHL" + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ifm2_scale : float + The quantization scale for the Input Feature Map tensor 2. + ifm2_zero_point : int + The quantization zero point for the Input Feature Map tensor 1. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + ifm_channels : int + The number of the Input Feature Map channels. + ifm2_channels : int + The number of the Input Feature Map 2 channels. + reversed_operands : bool + True if IFM2 is the first operand and IFM is the second operand. + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + Available activations for activation type: + {int8, uint8}: "NONE", "CLIP", "TANH", "SIGMOID", "LUT" + {int32}: "NONE" + clip_min : int + The minimum clipping value if activation = "CLIP". + clip_max : int + The maximum clipping value if activation = "CLIP". + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ifm2_layout : str, optional + The layout of the Input Feature Map tensor 2. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_dtype: str + The Output Feature Map tensor type. + MUL, ADD, SUB {IFM}->{OFM}: + {uint8, int8 int32} -> {uint8, int8, int32}, any pairing + MAX, MIN: + IFM and OFM must be of the same type, one of: + {int8, uint8} + SHR {IFM}->{OFM}: + {int32}->{int8, uint8, int32}, any pairing" + SHL: + {int32}->{int32} only + + Returns + ------- + te.Tensor + The Output Feature Map tensor. 
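+
+    Example
+    -------
+    A minimal sketch of invoking this compute directly on TE
+    placeholders; the shapes, dtypes and quantization values below are
+    illustrative assumptions only, not requirements of the NPU:
+
+    .. code-block:: python
+
+        from tvm import te
+
+        ifm = te.placeholder((1, 4, 4, 8), dtype="int8", name="ifm")
+        ifm2 = te.placeholder((1, 4, 4, 8), dtype="int8", name="ifm2")
+        # The LUT is unused unless activation == "LUT"
+        lut = te.placeholder((256,), dtype="int8", name="lut")
+        ofm = binary_elementwise_compute(
+            ifm, ifm2, lut, "ADD",
+            ifm_scale=1.0, ifm_zero_point=0,
+            ifm2_scale=1.0, ifm2_zero_point=0,
+            ofm_scale=1.0, ofm_zero_point=0,
+            ifm_channels=8, ifm2_channels=8,
+            reversed_operands=False,
+            activation="NONE", clip_min=0, clip_max=0,
+            ifm_layout="NHWC", ifm2_layout="NHWC", ofm_layout="NHWC",
+            ofm_dtype="int8",
+        )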
+ """ + # Compute operation for the IFM DMA pipeline + dmaed_ifm = dma_ifm_compute( + ifm, ifm_layout, ifm_zero_point, ifm_scale, ifm_channels, (0, 0, 0, 0) + ) + dmaed_ifm2 = dma_ifm_compute( + ifm2, ifm2_layout, ifm2_zero_point, ifm2_scale, ifm2_channels, (0, 0, 0, 0) + ) + + # Binary elementwise compute operation + ofm_height = dmaed_ifm.shape[1] + ofm_width = dmaed_ifm.shape[2] + + binary_elementwise_attrs = { + "op": "ethosu_binary_elementwise", + "operator_type": operator_type, + "reversed_operands": reversed_operands, + "activation": activation, + "clip_min": clip_min, + "clip_max": clip_max, + } + + operators = { + "ADD": operator.add, + "SUB": operator.sub, + "MUL": operator.mul, + "MIN": te.min, + "MAX": te.max, + "SHR": operator.add, + "SHL": operator.add, + } + broadcast = [value == 1 for value in dmaed_ifm2.shape] + + if reversed_operands: + binary_elementwise = te.compute( + (1, ofm_height, ofm_width, ifm_channels), + lambda nn, hh, ww, cc: operators[operator_type]( + dmaed_ifm2( + 0 if broadcast[0] else nn, + 0 if broadcast[1] else hh, + 0 if broadcast[2] else ww, + 0 if broadcast[3] else cc, + ).astype(ifm.dtype), + dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype), + ).astype(ofm_dtype), + name="ethosu_binary_elementwise", + attrs=binary_elementwise_attrs, + ) + else: + binary_elementwise = te.compute( + (1, ofm_height, ofm_width, ifm_channels), + lambda nn, hh, ww, cc: operators[operator_type]( + dmaed_ifm(nn, hh, ww, cc).astype(ifm.dtype), + dmaed_ifm2( + 0 if broadcast[0] else nn, + 0 if broadcast[1] else hh, + 0 if broadcast[2] else ww, + 0 if broadcast[3] else cc, + ).astype(ifm.dtype), + ).astype(ofm_dtype), + name="ethosu_binary_elementwise", + attrs=binary_elementwise_attrs, + ) + + # Compute operation for the OFM DMA pipeline + return dma_ofm_compute(binary_elementwise, ofm_layout, ofm_zero_point, ofm_scale, ifm_channels) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py new file mode 100644 index 0000000000000..1ea24edccb604 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the binary_elementwise operators in TIR.""" +from typing import Dict, Tuple +import tvm +from .utils import get_outer_loops, get_op_attrs +from .dma import get_ifm_params, get_ofm_params +from .spec import SerialActivation, SerialBinaryElementwise + + +def ignore_cast(tir_load: tvm.tir.expr.Load) -> tvm.tir.Var: + """When the datatype of the ifm, ifm2 and ofm do not match, + casts are inserted in TE to handle the difference in these types. 
+ Since TIR is not directly run on the NPU we can simply ignore + these, and allow the NPU to handle the difference in datatypes + itself. + + Parameters + ---------- + tir_load : tvm.tir.expr.Load + + Returns + ------- + tvm.tir.Var + """ + return tir_load.value if isinstance(tir_load, tvm.tir.Cast) else tir_load + + +def get_binary_elementwise_params( + stmt: tvm.tir.AttrStmt, + producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], +) -> Tuple[SerialBinaryElementwise, tvm.tir.Var, tvm.tir.Var]: + """Get the parameters necessary to construct a call_extern for a binary_elementwise. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a binary elementwise loop nest. + producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + SerialBinaryElementwise + The parameters needed to construct a binary elementwise operator. + output_pointer : tvm.tir.Var + The output pointer of the binary elementwise operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the binary elementwise output pointer. + """ + attrs, body = get_op_attrs(stmt) + reversed_operands = attrs["reversed_operands"] + + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + op = ignore_cast(inner.value) + input_pointer = ignore_cast(op.a).buffer_var + input_pointer1 = ignore_cast(op.b).buffer_var + + if reversed_operands: + input_pointer, input_pointer1 = input_pointer1, input_pointer + output_pointer = inner.buffer_var + # Get feature map info + serial_ifm, _ = get_ifm_params(input_pointer, producers) + serial_ifm2, _ = get_ifm_params(input_pointer1, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + return ( + SerialBinaryElementwise( + ifm=serial_ifm, + ifm2=serial_ifm2, + ofm=serial_ofm, + operator_type=attrs["operator_type"], + reversed_operands=reversed_operands, + activation=serial_activation, + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 2f5d7abd260df..a5678d1cc2d19 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -23,6 +23,7 @@ from .convolution import get_conv2d_params from .depthwise import get_depthwise_conv2d_params from .pooling import get_pooling_params +from .binary_elementwise import get_binary_elementwise_params from .transform import get_copy_params from .utils import get_weights_pointer, get_scale_bias_pointer @@ -56,6 +57,7 @@ def ReplaceOperators(): "ethosu_copy": get_copy_params, "ethosu_depthwise_conv2d": get_depthwise_conv2d_params, "ethosu_pooling": get_pooling_params, + "ethosu_binary_elementwise": get_binary_elementwise_params, } pointer_to_producer = {} pointer_to_consumer = {} diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index ff019c7783db7..269238a157ef8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ 
b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -261,3 +261,24 @@ def __init__( self.padding = padding self.activation = activation self.upscale = upscale + + +class SerialBinaryElementwise(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.binary_elementwise tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ifm2: SerialFeatureMap, + ofm: SerialFeatureMap, + operator_type: str, + reversed_operands: bool, + activation: SerialActivation, + ): + self.ifm = ifm + self.ifm2 = ifm2 + self.ofm = ofm + self.operator_type = operator_type + self.reversed_operands = reversed_operands + self.activation = activation diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 861669588f72c..f82d7bb857a61 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -213,7 +213,10 @@ def replace_npu_fm_with_address(npu_fm): buffer = npu_fm.tiles.addresses[0].buffer_var assert buffer in buffer_addresses.keys() address, buffer_type = buffer_addresses[buffer] - npu_fm.tiles.addresses[0] = address + index = npu_fm.tiles.addresses[0].index * ( + np.iinfo(np.dtype(npu_fm.tiles.addresses[0])).bits // 8 + ) + npu_fm.tiles.addresses[0] = address + int(index) npu_fm.region = _REGION_MAP[buffer_type] return npu_fm @@ -304,6 +307,7 @@ def translate_ethosu_tir_call_extern(tir_call_extern): "ethosu_copy": translate_ethosu_copy, "ethosu_depthwise_conv2d": translate_ethosu_depthwise_conv2d, "ethosu_pooling": translate_ethosu_pooling, + "ethosu_binary_elementwise": translate_ethosu_binary_elementwise, } ext_call_type = tir_call_extern.args[0].value assert ext_call_type in supported_call_extern.keys(), f"{ext_call_type} is not yet supported" @@ -482,6 +486,7 @@ def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.N } layout = str(serial_feature_map.layout.value) data_type = str(serial_feature_map.data_type.value) + date_type_bytes = np.iinfo(np.dtype(data_type)).bits // 8 assert layout in layout_map.keys() assert data_type in datatype_map.keys() nfm = vapi.NpuFeatureMap() @@ -507,9 +512,9 @@ def _create_npu_feature_map(serial_feature_map: spec.SerialFeatureMap) -> vapi.N ) nfm.layout = layout_map[layout] nfm.strides = vapi.NpuShape3D( - int(serial_feature_map.stride_h), - int(serial_feature_map.stride_w), - int(serial_feature_map.stride_c), + int(serial_feature_map.stride_h.value) * date_type_bytes, + int(serial_feature_map.stride_w.value) * date_type_bytes, + int(serial_feature_map.stride_c.value) * date_type_bytes, ) return nfm @@ -677,3 +682,66 @@ def _create_npu_op_pooling(serial_pooling: spec.SerialPooling): npu_pooling_op.block_config = block_config return npu_pooling_op + + +def translate_ethosu_binary_elementwise( + tir_call_extern: tvm.tir.Call, +) -> vapi.NpuElementWiseOperation: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. + + Parameters + ---------- + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has agreed upon ordering + for TIR Compiler. See SerialBinaryElementwise in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. 
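+        Concretely, args[0] holds the operator name and the remaining
+        arguments follow the SerialBinaryElementwise field order: the
+        serialized ifm, ifm2 and ofm feature maps, then operator_type,
+        reversed_operands and the serialized activation.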
+ + Returns + ------- + ethosu.vela.api.NpuElementWiseOperation + The vela object containing the params of ethosu_binary_elementwise + """ + serial_object = spec.create_serial_object( + spec.SerialBinaryElementwise, tir_call_extern.args[1:] + ) + return _create_npu_op_binary_elementwise(serial_object) + + +def _create_npu_op_binary_elementwise(serial_binary_elementwise: spec.SerialBinaryElementwise): + operator_type = serial_binary_elementwise.operator_type + if operator_type == "ADD": + op = vapi.NpuElementWiseOp.ADD + elif operator_type == "SUB": + op = vapi.NpuElementWiseOp.SUB + elif operator_type == "MUL": + op = vapi.NpuElementWiseOp.MUL + elif operator_type == "MIN": + op = vapi.NpuElementWiseOp.MIN + elif operator_type == "MAX": + op = vapi.NpuElementWiseOp.MAX + elif operator_type == "SHR": + op = vapi.NpuElementWiseOp.SHR + elif operator_type == "SHL": + op = vapi.NpuElementWiseOp.SHL + + npu_binary_elementwise_op = vapi.NpuElementWiseOperation(op) + npu_binary_elementwise_op.ifm = _create_npu_feature_map(serial_binary_elementwise.ifm) + npu_binary_elementwise_op.ifm2 = _create_npu_feature_map(serial_binary_elementwise.ifm2) + npu_binary_elementwise_op.ofm = _create_npu_feature_map(serial_binary_elementwise.ofm) + npu_binary_elementwise_op.reversed_operands = serial_binary_elementwise.reversed_operands + + npu_binary_elementwise_op.activation = _create_npu_activation( + serial_binary_elementwise.activation + ) + if ( + npu_binary_elementwise_op.activation + and npu_binary_elementwise_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_binary_elementwise_op) + + target_accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_binary_elementwise_op, target_accel_config) + npu_binary_elementwise_op.block_config = block_config + + return npu_binary_elementwise_op diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index ee47e4abd42bd..8afb6eb9b9eeb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -75,6 +75,21 @@ class ClipArgs(Enum): A_MAX = 2 +class BinaryElementwiseArgs(Enum): + """This is a helper enums to access the correct index + of binary elementwise arguments + """ + + ifm = 0 + ifm2 = 1 + ifm_scale = 2 + ifm_zero_point = 3 + ifm2_scale = 4 + ifm2_zero_point = 5 + ofm_scale = 6 + ofm_zero_point = 7 + + def is_composite_func(func: relay.Function, name: str) -> bool: """ This method checks whether the call is to diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index a152235c702b0..25538cae9dbcd 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -40,6 +40,7 @@ from tvm.relay.backend.contrib.ethosu.util import QConv2DArgs # type: ignore from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs from tvm.relay.backend.contrib.ethosu.util import RequantArgs + from tvm.relay.backend.contrib.ethosu.util import BinaryElementwiseArgs from tvm.relay.backend.contrib.ethosu.util import get_dim_value except ImportError: vapi = None @@ -99,9 +100,8 @@ def check_strides(strides: List[int]) -> bool: return True -def check_valid_dtypes(tensor_params: List[TensorParams]) -> bool: +def check_valid_dtypes(tensor_params: List[TensorParams], supported_dtypes: List[type]) -> bool: """This function checks whether dtypes are supported by the NPU""" - supported_dtypes = (np.uint8, np.int8) for tep in 
tensor_params: # Check for dtypes if np.dtype(tep.dtype) not in supported_dtypes: @@ -248,7 +248,7 @@ def is_valid(self) -> bool: This function checks whether QnnConv2D has compatible attributes with the NPU """ tensor_params = [self.weights, self.ifm, self.ofm] - if not check_valid_dtypes(tensor_params): + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): return False if not check_weights(self.weights, self.dilation): return False @@ -287,7 +287,7 @@ def is_valid(self): Checks whether QnnDepthwiseConv2D + activation function has compatible attributes with HW """ tensor_params = [self.weights, self.ifm, self.ofm] - if not check_valid_dtypes(tensor_params): + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): return False if not check_weights(self.weights, self.dilation): return False @@ -373,7 +373,7 @@ def is_valid(self): This function checks whether MaxPool2D has compatible attributes with the NPU """ tensor_params = [self.ifm, self.ofm] - if not check_valid_dtypes(tensor_params): + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): return False if self.ifm.dtype != self.ofm.dtype: return False @@ -432,7 +432,7 @@ def is_valid(self): This function checks whether AvgPool2D has compatible attributes with the NPU """ tensor_params = [self.ifm, self.ofm] - if not check_valid_dtypes(tensor_params): + if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]): return False if self.ifm.dtype != self.ofm.dtype: return False @@ -458,6 +458,316 @@ def qnn_avgpool2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: return pattern +class BinaryElementwiseParams: + """ + This class will parse a call to a ethosu.binary_elementwise composite function + and extract the parameter information. 
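+
+    The ifm and ifm2 shapes must be broadcast-compatible on the H, W
+    and C axes; a sketch of the rule checked by can_broadcast below
+    (illustrative only, the helper name broadcastable is not part of
+    the codebase):
+
+    .. code-block:: python
+
+        def broadcastable(x_shape, y_shape):
+            # y broadcasts into x if each non-batch dimension of y
+            # equals the matching dimension of x, or equals 1
+            return all(y_shape[i] in (1, x_shape[i]) for i in range(1, 4))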
+ """ + + def __init__(self, func_body: Call, operator_type: str, has_quantization_parameters: bool): + clip = None + if str(func_body.op) == "clip": + clip = func_body + binary_op = clip.args[0] + else: + binary_op = func_body + + layout = "NHWC" + + if has_quantization_parameters: + self.ifm = TensorParams( + binary_op.args[BinaryElementwiseArgs.ifm.value], + layout, + binary_op.args[BinaryElementwiseArgs.ifm_scale.value], + binary_op.args[BinaryElementwiseArgs.ifm_zero_point.value], + ) + self.ifm2 = TensorParams( + binary_op.args[BinaryElementwiseArgs.ifm2.value], + layout, + binary_op.args[BinaryElementwiseArgs.ifm2_scale.value], + binary_op.args[BinaryElementwiseArgs.ifm2_zero_point.value], + ) + self.ofm = TensorParams( + binary_op, + layout, + binary_op.args[BinaryElementwiseArgs.ofm_scale.value], + binary_op.args[BinaryElementwiseArgs.ofm_zero_point.value], + ) + else: + self.ifm = TensorParams( + binary_op.args[BinaryElementwiseArgs.ifm.value], + layout, + ) + self.ifm2 = TensorParams( + binary_op.args[BinaryElementwiseArgs.ifm2.value], + layout, + ) + self.ofm = TensorParams( + binary_op, + layout, + ) + self.activation = clip + self.operator_type = operator_type + + def can_broadcast(x, y): + for i in range(1, 4): + if x.shape[i] == y.shape[i] or y.shape[i] == 1: + continue + return False + return True + + if can_broadcast(self.ifm, self.ifm2): + self.reversed_operands = False + self.valid_broadcast = True + elif can_broadcast(self.ifm2, self.ifm): + self.reversed_operands = True + self.ifm, self.ifm2 = self.ifm2, self.ifm + self.valid_broadcast = True + else: + self.valid_broadcast = False + + def is_valid(self): + """ + This function checks whether BinaryElementwise has compatible attributes with the NPU + """ + if np.dtype(self.ofm) == np.int32 and self.activation is not None: + return False + if len(self.ifm.shape) != 4 or len(self.ifm2.shape) != 4: + return False + if self.ifm.shape[0] != 1 or self.ifm2.shape[0] != 1: + return False + if not self.valid_broadcast: + return False + return True + + +class AddParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Add composite function + and extract the parameter information. + """ + + composite_name = "ethosu.add" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "ADD", True) + + def is_valid(self): + """ + This function checks whether Add has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32] + ): + return False + return True + + +def qnn_add_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for qnn.add with optional fused RELU activation. + """ + pattern = is_op("qnn.add")( + wildcard(), + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class SubParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Sub composite function + and extract the parameter information. 
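+
+    Note that subtraction is not commutative: if the operands were
+    swapped by BinaryElementwiseParams to satisfy the broadcasting
+    rules, the reversed_operands flag records this so that the NPU
+    still evaluates the operands in their original order.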
+ """ + + composite_name = "ethosu.sub" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "SUB", True) + + def is_valid(self): + """ + This function checks whether Sub has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32] + ): + return False + return True + + +def qnn_subtract_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for qnn.subtract with optional fused RELU activation. + """ + pattern = is_op("qnn.subtract")( + wildcard(), + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class MulParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Mul composite function + and extract the parameter information. + """ + + composite_name = "ethosu.mul" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "MUL", True) + + def is_valid(self): + """ + This function checks whether Mul has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8, np.int32] + ): + return False + return True + + +def qnn_mul_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for qnn.mul with optional fused RELU activation. + """ + pattern = is_op("qnn.mul")( + wildcard(), + wildcard(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + is_constant(), + ) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class MinParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Min composite function + and extract the parameter information. + """ + + composite_name = "ethosu.min" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "MIN", False) + + def is_valid(self): + """ + This function checks whether Min has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if self.ifm.dtype != self.ifm2.dtype: + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8] + ): + return False + return True + + +def minimum_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for minimum with optional fused RELU activation. + """ + pattern = is_op("minimum")(wildcard(), wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class MaxParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Max composite function + and extract the parameter information. 
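+
+    As with Min, the pattern is matched from the plain relay maximum
+    operator rather than a qnn operator, so no quantization parameters
+    are extracted (has_quantization_parameters is False).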
+ """ + + composite_name = "ethosu.max" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "MAX", False) + + def is_valid(self): + """ + This function checks whether Max has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if self.ifm.dtype != self.ifm2.dtype: + return False + if not check_valid_dtypes( + [self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.uint8, np.int8] + ): + return False + return True + + +def maximum_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for maximum with optional fused RELU activation. + """ + pattern = is_op("maximum")(wildcard(), wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class ShlParams(BinaryElementwiseParams): + """ + This class will parse a call to a ethosu.binary_elementwise Shl composite function + and extract the parameter information. + """ + + composite_name = "ethosu.shl" + + def __init__(self, func_body: Call): + BinaryElementwiseParams.__init__(self, func_body, "SHL", False) + + def is_valid(self): + """ + This function checks whether Shl has compatible attributes with the NPU + """ + if not super().is_valid(): + return False + if not check_valid_dtypes([self.ifm, self.ifm2, self.ofm], supported_dtypes=[np.int32]): + return False + return True + + +def shl_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for left_shift with optional fused RELU activation. + """ + pattern = is_op("left_shift")(wildcard(), wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + @register_pattern_table("ethosu") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -481,6 +791,36 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal qnn_avgpool2d_pattern(), lambda pat: AvgPool2DParams(pat).is_valid(), ), + ( + AddParams.composite_name, + qnn_add_pattern(), + lambda pat: AddParams(pat).is_valid(), + ), + ( + SubParams.composite_name, + qnn_subtract_pattern(), + lambda pat: SubParams(pat).is_valid(), + ), + ( + MulParams.composite_name, + qnn_mul_pattern(), + lambda pat: MulParams(pat).is_valid(), + ), + ( + MinParams.composite_name, + minimum_pattern(), + lambda pat: MinParams(pat).is_valid(), + ), + ( + MaxParams.composite_name, + maximum_pattern(), + lambda pat: MaxParams(pat).is_valid(), + ), + ( + ShlParams.composite_name, + shl_pattern(), + lambda pat: ShlParams(pat).is_valid(), + ), ] diff --git a/src/relay/op/contrib/ethosu/binary_elementwise.cc b/src/relay/op/contrib/ethosu/binary_elementwise.cc new file mode 100644 index 0000000000000..5b4900edc74bf --- /dev/null +++ b/src/relay/op/contrib/ethosu/binary_elementwise.cc @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/binary_elementwise.cc + * \brief Binary elementwise operators definitions for the Arm(R) Ethos(TM)-U NPU. + */ +#include + +#include "common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes used by the Ethos(TM)-U NPU binary elementwise operators */ +struct EthosuBinaryElementwiseAttrs : public tvm::AttrsNode { + String operator_type; + double ifm_scale; + int ifm_zero_point; + double ifm2_scale; + int ifm2_zero_point; + double ofm_scale; + int ofm_zero_point; + IndexExpr ifm_channels; + IndexExpr ifm2_channels; + bool reversed_operands; + String activation; + int clip_min; + int clip_max; + String ifm_layout; + String ifm2_layout; + String ofm_layout; + String ofm_dtype; + + TVM_DECLARE_ATTRS(EthosuBinaryElementwiseAttrs, "relay.attrs.EthosuBinaryElementwiseAttrs") { + TVM_ATTR_FIELD(operator_type) + .describe( + "The type of the binary elementwise operator." + "'ADD'" + "'SUB'" + "'MUL'" + "'MIN'" + "'MAX'" + "'SHR'" + "'SHL'"); + TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm_zero_point) + .describe("The quantization zero point for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm2_scale) + .describe("The quantization scale for the Input Feature Map tensor 2."); + TVM_ATTR_FIELD(ifm2_zero_point) + .describe("The quantization zero point for the Input Feature Map tensor 2."); + TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ifm_channels).describe("The number of the Input Feature Map channels."); + TVM_ATTR_FIELD(ifm2_channels).describe("The number of the Input Feature Map 2 channels."); + TVM_ATTR_FIELD(reversed_operands) + .describe("True if IFM2 is the first operand and IFM is the second operand.") + .set_default(false); + TVM_ATTR_FIELD(activation) + .describe( + "The activation function to use. " + "'NONE' - no activation function. " + "'CLIP' - clip the output between clip_min and clip_max. " + "'TANH' - tanh activation function. " + "'SIGMOID' - sigmoid activation function. " + "'LUT' - use a look-up table to perform the activation function." + "Available activations for activation type:" + "{int8, uint8}: 'NONE', 'CLIP', 'TANH', 'SIGMOID', 'LUT'" + "{int32}: 'NONE'") + .set_default("NONE"); + TVM_ATTR_FIELD(clip_min) + .describe("The minimum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(clip_max) + .describe("The maximum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(ifm_layout) + .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ifm2_layout) + .describe("The layout of the Input Feature Map tensor 2. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ofm_layout) + .describe("The layout of the Output Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ofm_dtype).describe( + "The Output Feature Map tensor type." 
+ "MUL, ADD, SUB {IFM}->{OFM}:" + " {uint8, int8 int32} -> {uint8, int8, int32}, any pairing" + "MAX, MIN:" + " IFM and OFM must be of the same type, one of:" + " {int8, uint8}" + "SHR {IFM}->{OFM}:" + " {int32}->{int8, uint8, int32}, any pairing" + "SHL:" + " {int32}->{int32} only"); + } +}; + +TVM_REGISTER_NODE_TYPE(EthosuBinaryElementwiseAttrs); + +bool EthosuBinaryElementwiseRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + const int ifm_index = 0; + const int ifm2_index = 1; + const int result_index = 3; + ICHECK_EQ(types.size(), result_index + 1); + + const auto* ifm = types[ifm_index].as(); + const auto* ifm2 = types[ifm2_index].as(); + if (ifm == nullptr) return false; + if (ifm2 == nullptr) return false; + + const auto* param = attrs.as(); + ICHECK(param != nullptr) << "EthosuBinaryElementwiseAttrs cannot be nullptr."; + + String operator_type = param->operator_type; + auto ifm_dtype = ifm->dtype; + auto ifm2_dtype = ifm2->dtype; + DataType ofm_dtype; + + if (param->ofm_dtype == "int8") { + ofm_dtype = DataType::Int(8); + } else if (param->ofm_dtype == "uint8") { + ofm_dtype = DataType::UInt(8); + } else if (param->ofm_dtype == "int32") { + ofm_dtype = DataType::Int(32); + } + + if (ifm_dtype != ifm2_dtype) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << "type for ifm2 be the same of ifm but was " << ifm2_dtype + << " instead of " << ifm_dtype); + return false; + } + + if (operator_type == "ADD" || operator_type == "SUB" || operator_type == "MUL") { + if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8) && + ifm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " << operator_type + << " type(uint8) or type(int8) or type(int32) for ifm but was " << ifm_dtype); + return false; + } + if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && + ofm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " << operator_type + << " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype); + return false; + } + } else if (operator_type == "MIN" || operator_type == "MAX") { + if (ifm_dtype != DataType::UInt(8) && ifm_dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " << operator_type + << " type(uint8) or type(int8) for ifm but was " << ifm_dtype); + return false; + } + if (ifm_dtype != ofm_dtype) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << operator_type + << " type for ofm be the same of ifm but was " << ofm_dtype + << " instead of " << ifm_dtype); + return false; + } + } else if (operator_type == "SHR") { + if (ifm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << operator_type << " type(int32) for ifm but was " + << ifm_dtype); + return false; + } + if (ofm_dtype != DataType::UInt(8) && ofm_dtype != DataType::Int(8) && + ofm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: 
expected ethosu_binary_elementwise " << operator_type + << " type(uint8) or type(int8) or type(int32) for ofm but was " << ofm_dtype); + return false; + } + } else if (operator_type == "SHL") { + if (ifm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << operator_type << " type(int32) for ifm but was " + << ifm_dtype); + + return false; + } + if (ofm_dtype != DataType::Int(32)) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise " + << operator_type << " type(int32) for ofm but was " + << ofm_dtype); + return false; + } + } else { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected ethosu_binary_elementwise 'ADD' or 'SUB' or 'MUL' or " + << "'MIN' or 'MAX' or 'SHR' or 'SHL' for operator_type but was " << param->operator_type); + return false; + } + + // Assign ofm type + auto ofm_shape = EthosuInferBinaryElementwiseOutputShape(ifm->shape, param->ifm_layout, + param->ofm_layout, param->ifm_channels); + reporter->Assign(types[result_index], TensorType(ofm_shape, ofm_dtype)); + return true; +} + +Expr MakeEthosuBinaryElementwise(Expr ifm, Expr ifm2, Expr lut, String operator_type, + double ifm_scale, int ifm_zero_point, double ifm2_scale, + int ifm2_zero_point, double ofm_scale, int ofm_zero_point, + IndexExpr ifm_channels, IndexExpr ifm2_channels, + bool reversed_operands, String activation, int clip_min, + int clip_max, String ifm_layout, String ifm2_layout, + String ofm_layout, String ofm_dtype) { + auto attrs = make_object(); + + attrs->operator_type = std::move(operator_type); + attrs->ifm_scale = ifm_scale; + attrs->ifm_zero_point = ifm_zero_point; + attrs->ifm2_scale = ifm2_scale; + attrs->ifm2_zero_point = ifm2_zero_point; + attrs->ofm_scale = ofm_scale; + attrs->ofm_zero_point = ofm_zero_point; + attrs->ifm_channels = std::move(ifm_channels); + attrs->ifm2_channels = std::move(ifm2_channels); + attrs->reversed_operands = reversed_operands; + attrs->activation = std::move(activation); + attrs->clip_min = clip_min; + attrs->clip_max = clip_max; + attrs->ifm_layout = std::move(ifm_layout); + attrs->ifm2_layout = std::move(ifm2_layout); + attrs->ofm_layout = std::move(ofm_layout); + attrs->ofm_dtype = std::move(ofm_dtype); + + static const Op& op = Op::Get("contrib.ethosu.binary_elementwise"); + return Call(op, {ifm, ifm2, lut}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.ethosu_binary_elementwise") + .set_body_typed(MakeEthosuBinaryElementwise); + +RELAY_REGISTER_OP("contrib.ethosu.binary_elementwise") + .describe(R"code(Arm(R) Ethos(TM)-U NPU quantized binary elementwise operator. + +This Relay operator corresponds to the hardware-implemented quantized +binary elementwise operation found on Ethos(TM)-U NPU. It accepts either NHWC +or NHCWB16 format for the inputs data (input feature maps, or IFMs). 
+ +Reference: https://developer.arm.com/documentation/102420/0200/ + +- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **ifm2**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **ofm**: (1, ofm_height, ofm_width, ifm_channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type() + .set_num_inputs(3) + .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") + .add_argument("ifm2", "Tensor", "The Input Feature Map tensor 2 (IFM2).") + .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'") + .set_support_level(11) + .add_type_rel("EthosuBinaryElementwise", EthosuBinaryElementwiseRel); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/contrib/ethosu/common.cc b/src/relay/op/contrib/ethosu/common.cc index bdda81bc7708b..bdaa9da526186 100644 --- a/src/relay/op/contrib/ethosu/common.cc +++ b/src/relay/op/contrib/ethosu/common.cc @@ -32,6 +32,24 @@ namespace op { namespace contrib { namespace ethosu { +Array EthosuInferBinaryElementwiseOutputShape(Array ifm_shape, + String ifm_layout, String ofm_layout, + IndexExpr ofm_channels) { + // In the case of NHCWB16, convert the ifm shape to NHW (C not required for this function) + if (ifm_layout == "NHCWB16") { + ifm_shape = {ifm_shape[0], ifm_shape[1], ifm_shape[3]}; + } + Array oshape({ifm_shape[0], ifm_shape[1], ifm_shape[2], ofm_channels}); + + // If the ofm is NHCWB16, convert the layout + if (ofm_layout == "NHCWB16") { + int channel_bricks = 1 + (oshape[3].as()->value - 1) / 16; + oshape = {oshape[0], oshape[1], channel_bricks, oshape[2], 16}; + } + + return oshape; +} + Array EthosuInferKernelOutput(Array ifm_shape, String ifm_layout, String ofm_layout, Array kernel_shape, IndexExpr ofm_channels, Array dilation, diff --git a/src/relay/op/contrib/ethosu/common.h b/src/relay/op/contrib/ethosu/common.h index b5377e6e8bdf7..574fb91181ef6 100644 --- a/src/relay/op/contrib/ethosu/common.h +++ b/src/relay/op/contrib/ethosu/common.h @@ -33,6 +33,17 @@ namespace op { namespace contrib { namespace ethosu { +/*! \brief Infer the output tensor shape for binary elementwise operators. + * \param ifm_shape The shape of Input Feature Map. + * \param ifm_layout The layout of the IFM (NHWC or NHCWB16). + * \param ofm_layout The layout of the OFM (NHWC or NHCWB16). + * \param ofm_channels The number of Output Feature Map channels. + * \return The shape of the output tensor. + */ +Array EthosuInferBinaryElementwiseOutputShape(Array ifm_shape, + String ifm_layout, String ofm_layout, + IndexExpr ofm_channels); + /*! \brief Infer the output tensor shape for convolution and pooling operators. * \param ifm_shape The shape of Input Feature Map. * \param ifm_layout The layout of the IFM (NHWC or NHCWB16). diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc index 86f14f37a8d8a..bcf54fbd4a2d6 100644 --- a/src/relay/op/contrib/ethosu/pooling.cc +++ b/src/relay/op/contrib/ethosu/pooling.cc @@ -19,7 +19,7 @@ /*! * \file src/relay/op/contrib/ethosu/pooling.cc - * \brief Pooling operators definitions for the Arm(R) Ethos(TM)-U NPU convolution ops. + * \brief Pooling operators definitions for the Arm(R) Ethos(TM)-U NPU. 
*/ #include diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 58862c5f5faa3..17d3fad9cb30a 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -509,3 +509,56 @@ def make_ethosu_pooling( ofm_layout=ofm_layout, ) return pooling + + +def get_binary_elementwise_args(call, include_buffers=False): + args = call.args + binary_elementwise_args = [] + + for i, arg in enumerate(args): + if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + binary_elementwise_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: + binary_elementwise_args.append(arg.index) + else: + binary_elementwise_args.append(arg) + + return binary_elementwise_args + + +def make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + ofm_dtype, + reversed_operands=False, + activation="NONE", + ifm_layout="NHWC", + ifm2_layout="NHWC", + ofm_layout="NHWC", +): + ethosu_binary_elementwise = ethosu_ops.ethosu_binary_elementwise( + ifm=ifm, + ifm2=ifm2, + lut=relay.const([], dtype="int8"), + operator_type=operator_type, + ifm_scale=1, + ifm_zero_point=0, + ifm2_scale=1, + ifm2_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + ifm_channels=ifm_channels, + ifm2_channels=ifm2_channels, + reversed_operands=reversed_operands, + activation=activation, + ofm_dtype=ofm_dtype, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + ifm_layout=ifm_layout, + ifm2_layout=ifm2_layout, + ofm_layout=ofm_layout, + ) + return ethosu_binary_elementwise diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 478a3c2bd5219..a5686c81beb8a 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -276,8 +276,6 @@ def test_ethosu_pooling( dtype = "int8" def create_tflite_graph(): - tf.config.run_functions_eagerly(True) - class Model(tf.Module): @tf.function def tf_function(self, x): @@ -343,5 +341,255 @@ def representative_dataset(): infra.verify_source(compiled_models, accel_type) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"]) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape", + [ + ([1, 2, 3, 4], [1, 2, 3, 4]), + ([1, 2, 3, 4], [1, 1, 1, 1]), + ([1, 1, 1, 1], [1, 2, 3, 4]), + ], +) +@pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) +def test_ethosu_binary_elementwise( + accel_type, + operator_type, + ifm_shape, + ifm2_shape, + activation_function, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, lhs, rhs): + if operator_type == "ADD": + op = tf.math.add(lhs, rhs) + elif operator_type == "SUB": + op = tf.math.subtract(lhs, rhs) + elif operator_type == "MUL": + op = tf.math.multiply(lhs, rhs) + elif operator_type == "MIN": + op = tf.math.minimum(lhs, rhs) + elif operator_type == "MAX": + op = tf.math.maximum(lhs, rhs) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + data2 = 
np.random.rand(*tuple(ifm2_shape)) * 2 + yield [data.astype(np.float32), data2.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"ifm": ifm_shape, "ifm2": ifm2_shape}, + dtype_dict={"ifm": dtype, "ifm2": dtype}, + ) + mod = partition_for_ethosu(mod, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + output_tolerance=1 if operator_type == "MAX" else 0, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape", + [ + ([1, 2, 3, 4], [1, 2, 3, 4]), + ([1, 2, 3, 4], [1, 1, 3, 1]), + ([1, 1, 3, 1], [1, 2, 3, 4]), + ], +) +def test_ethosu_left_shift_binary_elemwise( + accel_type, + ifm_shape, + ifm2_shape, +): + dtype = "int32" + + def create_model(): + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype) + c1 = relay.left_shift(ifm, ifm2) + f = relay.Function([ifm, ifm2], c1) + mod = tvm.IRModule() + mod["main"] = f + return mod + + relay_mod = create_model() + mod = partition_for_ethosu(relay_mod) + + # Generate reference data + in_min, in_max = util.get_range_for_dtype_str(dtype) + input_data = { + "ifm": np.random.randint(in_min, high=in_max, size=ifm_shape, dtype=dtype), + "ifm2": np.random.randint(0, high=32, size=ifm2_shape, dtype=dtype), + } + output_data = generate_ref_data(relay_mod, input_data) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. 
single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, reversed_operands, ofm_dtype", + [ + ([1, 2, 3, 4], [1, 2, 3, 4], False, "int8"), + ([1, 2, 3, 1], [1, 1, 3, 1], False, "int32"), + ([1, 1, 3, 1], [1, 2, 3, 1], True, "int32"), + ], +) +def test_ethosu_right_shift_binary_elemwise( + ifm_shape, ifm2_shape, reversed_operands, accel_type, ofm_dtype +): + dtype = "int32" + + def create_model(): + ifm_count = int(np.prod(ifm_shape)) + ifm2_count = int(np.prod(ifm2_shape)) + + # Create a "partitioned" Relay function + ifms = relay.var("ifms", shape=[ifm_count + ifm2_count], dtype=dtype) + split = relay.split(ifms, [ifm_count]) + ifm = relay.reshape(split[0], newshape=ifm_shape) + ifm2 = relay.reshape(split[1], newshape=ifm2_shape) + shr_op = infra.make_ethosu_binary_elementwise( + ifm, ifm2, ifm_shape[3], ifm2_shape[3], "SHR", ofm_dtype, reversed_operands + ) + + glb_ethosu = relay.GlobalVar("tvmgen_default_ethosu_main_0") + func = ( + relay.Function([ifms], shr_op) + .with_attr("Inline", 1) + .with_attr("Compiler", "ethosu") + .with_attr("global_symbol", "tvmgen_default_ethosu_main_0") + .with_attr("Primitive", 1) + ) + mod = tvm.IRModule() + mod[glb_ethosu] = func + mod = relay.transform.InferType()(mod) + + # Main + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype) + call = relay.Call( + glb_ethosu, + [ + relay.concatenate( + data=( + relay.reshape(ifm, newshape=ifm_count), + relay.reshape(ifm2, newshape=ifm2_count), + ), + axis=0, + ) + ], + ) + mod["main"] = relay.Function([ifm, ifm2], call) + mod = relay.transform.InferType()(mod) + return mod + + mod = create_model() + + # Generate reference data + in_min, in_max = util.get_range_for_dtype_str(dtype) + in_min, in_max = 18, 19 + lhs = np.random.randint(in_min, high=in_max, size=ifm_shape, dtype=dtype) + rhs = np.random.randint(1, high=2, size=ifm2_shape, dtype=dtype) + input_data = { + "ifm": lhs, + "ifm2": rhs, + } + + if reversed_operands: + lhs = np.broadcast_to(lhs, ifm2_shape) + lhs, rhs = rhs, lhs + else: + rhs = np.broadcast_to(rhs, ifm_shape) + + def rounding_right_shift(lhs, rhs): + r = 1 << (rhs - 1) + return (lhs + r) >> rhs + + output_data = np.array( + [rounding_right_shift(x[0], x[1]) for x in zip(lhs.flat, rhs.flat)] + ).astype(ofm_dtype) + + compiled_model = infra.build_source(mod, input_data, [output_data], accel_type) + + imported_modules = compiled_model[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_model, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index fc03a98beb6be..2a84a23930e4d 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ 
b/tests/python/contrib/test_ethosu/test_legalize.py @@ -558,5 +558,193 @@ def verify(ext_func): verify(mod["tvmgen_default_ethosu_main_0"]) +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"]) +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, reversed_operands", + [ + ([1, 2, 3, 4], [1, 2, 3, 4], False), + ([1, 2, 3, 4], [1, 1, 3, 1], False), + ([1, 1, 3, 1], [1, 2, 3, 4], True), + ], +) +@pytest.mark.parametrize("activation_function", ["NONE", "RELU"]) +def test_tflite_binary_elemwise_legalize( + operator_type, + ifm_shape, + ifm2_shape, + reversed_operands, + activation_function, +): + dtype = "int8" + + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, x, y): + if operator_type == "ADD": + op = tf.math.add(x, y) + elif operator_type == "SUB": + op = tf.math.subtract(x, y) + elif operator_type == "MUL": + op = tf.math.multiply(x, y) + elif operator_type == "MIN": + op = tf.math.minimum(x, y) + elif operator_type == "MAX": + op = tf.math.maximum(x, y) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32), tf.TensorSpec(ifm2_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + data2 = np.random.rand(*tuple(ifm2_shape)) * 2 + yield [data.astype(np.float32), data2.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + out_shape = ifm2_shape if reversed_operands else ifm_shape + shapes = [ifm_shape, ifm2_shape] + ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1) + op = ext_func.body + assert list(op.args[0].checked_type.shape) == shapes[ifm_index] + assert list(op.args[1].checked_type.shape) == shapes[ifm2_index] + assert op.args[0].checked_type.dtype == dtype + assert list(op.checked_type.shape) == out_shape + assert op.checked_type.dtype == dtype + assert op.attrs.operator_type == operator_type + assert op.attrs.reversed_operands == reversed_operands + if activation_function == "RELU": + assert str(op.attrs.activation) == "CLIP" + + if operator_type == "ADD": + rewriter = legalize.AddRewriter() + pattern_table = [ + ( + ethosu.AddParams.composite_name, + ethosu.qnn_add_pattern(), + lambda pat: ethosu.AddParams(pat).is_valid(), + ), + ] + elif operator_type == "SUB": + rewriter = legalize.SubRewriter() + pattern_table = [ + ( + ethosu.SubParams.composite_name, + ethosu.qnn_subtract_pattern(), + lambda pat: ethosu.SubParams(pat).is_valid(), + ), + ] + elif operator_type == "MUL": + rewriter = legalize.MulRewriter() + pattern_table = [ + ( + ethosu.MulParams.composite_name, + ethosu.qnn_mul_pattern(), + lambda pat: ethosu.MulParams(pat).is_valid(), + ), + ] + elif operator_type == "MIN": + rewriter = legalize.MinRewriter() + pattern_table = [ + ( + ethosu.MinParams.composite_name, + ethosu.minimum_pattern(), + lambda pat: ethosu.MinParams(pat).is_valid(), + ), + ] + elif operator_type == "MAX": + rewriter = legalize.MaxRewriter() + pattern_table = [ + ( + 
ethosu.MaxParams.composite_name, + ethosu.maximum_pattern(), + lambda pat: ethosu.MaxParams(pat).is_valid(), + ), + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"x": ifm_shape, "y": ifm2_shape}, + dtype_dict={"x": dtype, "y": dtype}, + ) + mod = partition_ethosu_by_table(mod, pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, reversed_operands", + [ + ([1, 2, 3, 4], [1, 2, 3, 4], False), + ([1, 2, 3, 4], [1, 1, 3, 1], False), + ([1, 1, 3, 1], [1, 2, 3, 4], True), + ], +) +def test_ethosu_left_shift_binary_elemwise_legalize(ifm_shape, ifm2_shape, reversed_operands): + dtype = "int32" + operator_type = "SHL" + + def create_graph(): + input1 = relay.var("x1", shape=ifm_shape, dtype=dtype) + input2 = relay.var("x2", shape=ifm2_shape, dtype=dtype) + c1 = relay.left_shift(input1, input2) + f = relay.Function([input1, input2], c1) + mod = tvm.IRModule() + mod["main"] = f + return mod + + def verify(ext_func): + out_shape = ifm2_shape if reversed_operands else ifm_shape + shapes = [ifm_shape, ifm2_shape] + ifm_index, ifm2_index = (1, 0) if reversed_operands else (0, 1) + op = ext_func.body + assert list(op.args[0].checked_type.shape) == shapes[ifm_index] + assert list(op.args[1].checked_type.shape) == shapes[ifm2_index] + assert op.args[0].checked_type.dtype == dtype + assert list(op.checked_type.shape) == out_shape + assert op.checked_type.dtype == dtype + assert op.attrs.operator_type == operator_type + assert op.attrs.reversed_operands == reversed_operands + assert str(op.attrs.activation) == "NONE" + + rewriter = legalize.ShlRewriter() + pattern_table = [ + ( + ethosu.ShlParams.composite_name, + ethosu.shl_pattern(), + lambda pat: ethosu.ShlParams(pat).is_valid(), + ), + ] + + mod = create_graph() + mod = partition_ethosu_by_table(mod, pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py new file mode 100644 index 0000000000000..6dcd9da395cc4 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_binary_elementwise.py @@ -0,0 +1,335 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
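+
+"""Replacement tests for ethosu_binary_elementwise: lower the Relay operator
+with lower_to_tir and check the arguments of the resulting
+"ethosu_binary_elementwise" call_extern against a hand-constructed
+spec.SerialBinaryElementwise."""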
+import pytest + +pytest.importorskip("ethosu.vela") + +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir import spec +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from .infra import make_ethosu_binary_elementwise, get_binary_elementwise_args + + +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, ifm_channels, ifm2_channels, ifm_layout, ofm_layout", + [ + ((1, 5, 9, 3), (1, 5, 9, 3), 3, 3, "NHWC", "NHWC"), + ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHCWB16"), + ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHWC"), + ((1, 8, 9, 40), (1, 8, 9, 40), 40, 40, "NHWC", "NHCWB16"), + # Broadcast + ((1, 5, 9, 3), (1, 1, 9, 1), 3, 1, "NHWC", "NHWC"), + ((1, 8, 9, 40), (1, 1, 1, 1), 40, 1, "NHWC", "NHCWB16"), + ], +) +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX"]) +@pytest.mark.parametrize("activation", ["NONE", "CLIP"]) +def test_binary_elementwise_single( + ifm_shape, + ifm2_shape, + ifm_channels, + ifm2_channels, + ifm_layout, + ofm_layout, + operator_type, + activation, +): + dtype = "int8" + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype) + + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + dtype, + False, + activation, + ifm_layout, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(binary_elementwise), binary_elementwise) + func = run_opt_pass(func, relay.transform.InferType()) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_binary_elementwise_args(stmt)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1 + ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1 + + ifm2_stride_c = 1 + ifm2_stride_w = ifm2_shape[3] if ifm2_shape[2] != 1 else 1 + ifm2_stride_h = ifm2_shape[2] * ifm2_shape[3] if ifm2_shape[1] != 1 else 1 + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[2] + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + + ifm2_stride_w = 16 + ifm2_stride_c = 16 * ifm2_shape[3] + ifm2_stride_h = 16 * ifm2_shape[2] * ifm2_shape[3] + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[3] + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = ifm_channels if ofm_width > 1 else 1 + ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1) + + serial_binary_elementwise = spec.SerialBinaryElementwise( + ifm=spec.SerialFeatureMap( + data_type=dtype, + height=ifm_shape[1], + width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + channels=ifm_channels, + tile_height_0=ifm_shape[1], + tile_height_1=0, + tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm_stride_h, + stride_w=ifm_stride_w, + stride_c=ifm_stride_c, + ), + ifm2=spec.SerialFeatureMap( + data_type=dtype, + height=ifm2_shape[1], + width=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3], + channels=ifm2_channels, + tile_height_0=ifm2_shape[1], 
+ tile_height_1=0, + tile_width_0=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm2_stride_h, + stride_w=ifm2_stride_w, + stride_c=ifm2_stride_c, + ), + ofm=spec.SerialFeatureMap( + data_type=dtype, + height=ofm_height, + width=ofm_width, + channels=ifm_channels, + tile_height_0=ofm_height, + tile_height_1=0, + tile_width_0=ofm_width, + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ofm_layout, + stride_h=ofm_stride_h, + stride_w=ofm_stride_w, + stride_c=ofm_stride_c, + ), + operator_type=operator_type, + reversed_operands=False, + activation=spec.SerialActivation( + op=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + ), + ) + + assert data[0] == ["ethosu_binary_elementwise"] + list(serial_binary_elementwise) + + +@pytest.mark.parametrize( + "ifm_shape, ifm2_shape, ifm_channels, ifm2_channels, ifm_layout, ofm_layout", + [ + ((1, 5, 9, 3), (1, 5, 9, 3), 3, 3, "NHWC", "NHWC"), + ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHCWB16"), + ((1, 8, 3, 9, 16), (1, 8, 3, 9, 16), 40, 40, "NHCWB16", "NHWC"), + ((1, 8, 9, 40), (1, 8, 9, 40), 40, 40, "NHWC", "NHCWB16"), + # Broadcast + ((1, 5, 9, 3), (1, 1, 9, 1), 3, 1, "NHWC", "NHWC"), + ((1, 8, 9, 40), (1, 1, 1, 1), 40, 1, "NHWC", "NHCWB16"), + ], +) +@pytest.mark.parametrize("operator_type", ["SHR", "SHL"]) +def test_shift_binary_elementwise_single( + ifm_shape, + ifm2_shape, + ifm_channels, + ifm2_channels, + ifm_layout, + ofm_layout, + operator_type, +): + dtype = "int32" + activation = "NONE" # Only NONE is available if the activation type is int32 + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm2_shape, dtype=dtype) + + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + dtype, + False, + "NONE", + ifm_layout, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(binary_elementwise), binary_elementwise) + func = run_opt_pass(func, relay.transform.InferType()) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_binary_elementwise_args(stmt)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] if ifm_shape[2] != 1 else 1 + ifm_stride_h = ifm_shape[2] * ifm_shape[3] if ifm_shape[1] != 1 else 1 + + ifm2_stride_c = 1 + ifm2_stride_w = ifm2_shape[3] if ifm2_shape[2] != 1 else 1 + ifm2_stride_h = ifm2_shape[2] * ifm2_shape[3] if ifm2_shape[1] != 1 else 1 + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[2] + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + + ifm2_stride_w = 16 + ifm2_stride_c = 16 * ifm2_shape[3] + ifm2_stride_h = 16 * ifm2_shape[2] * ifm2_shape[3] + + ofm_height = ifm_shape[1] + ofm_width = ifm_shape[3] + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = ifm_channels if ofm_width > 1 else 1 + ofm_stride_h = ifm_channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((ifm_channels - 1) // 16 + 1) + + serial_binary_elementwise = spec.SerialBinaryElementwise( + ifm=spec.SerialFeatureMap( + 
data_type=dtype, + height=ifm_shape[1], + width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + channels=ifm_channels, + tile_height_0=ifm_shape[1], + tile_height_1=0, + tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm_stride_h, + stride_w=ifm_stride_w, + stride_c=ifm_stride_c, + ), + ifm2=spec.SerialFeatureMap( + data_type=dtype, + height=ifm2_shape[1], + width=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3], + channels=ifm2_channels, + tile_height_0=ifm2_shape[1], + tile_height_1=0, + tile_width_0=ifm2_shape[2] if ifm_layout == "NHWC" else ifm2_shape[3], + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ifm_layout, + stride_h=ifm2_stride_h, + stride_w=ifm2_stride_w, + stride_c=ifm2_stride_c, + ), + ofm=spec.SerialFeatureMap( + data_type=dtype, + height=ofm_height, + width=ofm_width, + channels=ifm_channels, + tile_height_0=ofm_height, + tile_height_1=0, + tile_width_0=ofm_width, + tile_address_0=0, + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=1.0, + zero_point=0, + layout=ofm_layout, + stride_h=ofm_stride_h, + stride_w=ofm_stride_w, + stride_c=ofm_stride_c, + ), + operator_type=operator_type, + reversed_operands=False, + activation=spec.SerialActivation( + op=activation, + clip_min=0, + clip_max=0, + ), + ) + + assert data[0] == ["ethosu_binary_elementwise"] + list(serial_binary_elementwise) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index f4b83a4577cc6..ab1bad226ae6d 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -913,5 +913,439 @@ def populate_ethosu_pooling_calls(stmt): assert npu_op.ifm_upscale == vapi.NpuResamplingMode.NONE +# fmt: off +"""A ethosu_binary_elementwise ADD tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseAdd: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer( + placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1 + ) + ethosu_write_2 = T.match_buffer( + ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1 + ) + # body + T.evaluate(T.call_extern( "ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "ADD", 0, "CLIP", 10, 100, dtype="int8")) + + __tvm_meta__ = None +# fmt: on + +# fmt: off +"""A ethosu_binary_elementwise SUB tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseSub: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) + 
ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SUB", 0, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + +# fmt: off +"""A ethosu_binary_elementwise MUL tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMul: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MUL", 0, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise MIN tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMin: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "MIN", 0, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise Max tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMax: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 9, 3, 5, 0, 9, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, 
T.float32(1.0), 0, "NHWC", 27, 3, 1, "MAX", 0, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise SHR tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseShr: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHR", 0, "NONE", 0, 0, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise SHL tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseShl: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [270], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 9, 3], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", placeholder_2.data, 135), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int32", 5, 9, 3, 5, 0, 9, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "SHL", 0, "CLIP", 10, 100, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX", "SHR", "SHL"]) +def test_translate_ethosu_binary_elementwise(operator_type): + if operator_type == "SHR" or operator_type == "SHL": + data_type = vapi.NpuDataType.INT32 + data_type_bytes = 4 + else: + data_type = vapi.NpuDataType.INT8 + data_type_bytes = 1 + + def extract_ethosu_binary_elementwise_call_extern(mod): + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_binary_elementwise_calls = list() + + def populate_ethosu_binary_elementwise_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_binary_elementwise" + ): + ethosu_binary_elementwise_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_binary_elementwise_calls) + return ethosu_binary_elementwise_calls[0] + + if operator_type == "ADD": + binary_elementwise = SingleEthosuBinaryElementwiseAdd + elif operator_type == "SUB": + binary_elementwise = SingleEthosuBinaryElementwiseSub + elif operator_type == "MUL": + binary_elementwise = SingleEthosuBinaryElementwiseMul + elif operator_type == "MIN": + binary_elementwise = SingleEthosuBinaryElementwiseMin + elif operator_type == "MAX": + binary_elementwise = SingleEthosuBinaryElementwiseMax + 
elif operator_type == "SHR": + binary_elementwise = SingleEthosuBinaryElementwiseShr + elif operator_type == "SHL": + binary_elementwise = SingleEthosuBinaryElementwiseShl + binary_elementwise_call = extract_ethosu_binary_elementwise_call_extern(binary_elementwise) + npu_op = tir_to_cs_translator.translate_ethosu_binary_elementwise(binary_elementwise_call) + + # Compare IFM + assert npu_op.ifm.data_type == data_type + assert npu_op.ifm.shape == vapi.NpuShape3D(5, 9, 3) + assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0 + assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1 + assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0 + assert npu_op.ifm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm.strides == vapi.NpuShape3D( + 27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes + ) + # Compare IFM2 + assert npu_op.ifm2.data_type == data_type + assert npu_op.ifm2.shape == vapi.NpuShape3D(5, 9, 3) + assert npu_op.ifm2.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0 + assert npu_op.ifm2.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1 + assert npu_op.ifm2.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0 + assert npu_op.ifm2.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm2.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm2.strides == vapi.NpuShape3D( + 27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes + ) + # Compare OFM + assert npu_op.ofm.data_type == data_type + assert npu_op.ofm.shape == vapi.NpuShape3D(5, 9, 3) + assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0 + assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1 + assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0 + assert npu_op.ofm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ofm.layout == vapi.NpuLayout.NHWC + assert npu_op.ofm.strides == vapi.NpuShape3D( + 27 * data_type_bytes, 3 * data_type_bytes, 1 * data_type_bytes + ) + # Compare op type + if operator_type == "ADD": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.ADD + elif operator_type == "SUB": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SUB + elif operator_type == "MUL": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MUL + elif operator_type == "MIN": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MIN + elif operator_type == "MAX": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MAX + elif operator_type == "SHR": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHR + elif operator_type == "SHL": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHL + # Compare reversed_operands + assert npu_op.reversed_operands == False + # Compare activation + if operator_type == "SHR": + assert npu_op.activation is None + else: + assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + assert npu_op.activation.min == 10 + assert npu_op.activation.max == 100 + + +# fmt: off +"""A ethosu_binary_elementwise ADD with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseAddBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, 
align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "ADD", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + +# fmt: off +"""A ethosu_binary_elementwise SUB with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseSubBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SUB", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + +# fmt: off +"""A ethosu_binary_elementwise MUL with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMulBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MUL", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise MIN with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMinBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, 
T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MIN", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise MAX with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseMaxBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 2, 3, 4, 2, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int8", 1, 3, 1, 1, 0, 3, T.load("int8", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int8", 2, 3, 4, 2, 0, 3, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "MAX", 1, "CLIP", 10, 100, dtype="int8")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise SHR with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseShrBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHR", 1, "NONE", 0, 0, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A ethosu_binary_elementwise SHL with broadcasting tir testcase for the translator""" +@tvm.script.ir_module +class SingleEthosuBinaryElementwiseShlBroadcasting: + @T.prim_func + def main(placeholder: T.handle, ethosu_write: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [27], dtype="int32", elem_offset=0, align=128, offset_factor=1) + ethosu_write_2 = T.match_buffer(ethosu_write, [1, 2, 3, 4], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + T.evaluate(T.call_extern("ethosu_binary_elementwise", "int32", 2, 3, 4, 2, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "int32", 1, 3, 1, 1, 0, 3, T.load("int32", placeholder_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 1, 1, 1, "int32", 2, 3, 4, 2, 0, 3, T.load("int32", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 12, 4, 1, "SHL", 1, "CLIP", 10, 100, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize("operator_type", ["ADD", "SUB", "MUL", "MIN", "MAX", "SHR", "SHL"]) +def 
test_translate_ethosu_binary_elementwise_broadcasting(operator_type): + if operator_type == "SHR" or operator_type == "SHL": + data_type = vapi.NpuDataType.INT32 + data_type_bytes = 4 + else: + data_type = vapi.NpuDataType.INT8 + data_type_bytes = 1 + + def extract_ethosu_binary_elementwise_broadcasting_call_extern(mod): + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_binary_elementwise_calls = list() + + def populate_ethosu_binary_elementwise_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_binary_elementwise" + ): + ethosu_binary_elementwise_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_binary_elementwise_calls) + return ethosu_binary_elementwise_calls[0] + + if operator_type == "ADD": + binary_elementwise = SingleEthosuBinaryElementwiseAddBroadcasting + elif operator_type == "SUB": + binary_elementwise = SingleEthosuBinaryElementwiseSubBroadcasting + elif operator_type == "MUL": + binary_elementwise = SingleEthosuBinaryElementwiseMulBroadcasting + elif operator_type == "MIN": + binary_elementwise = SingleEthosuBinaryElementwiseMinBroadcasting + elif operator_type == "MAX": + binary_elementwise = SingleEthosuBinaryElementwiseMaxBroadcasting + elif operator_type == "SHR": + binary_elementwise = SingleEthosuBinaryElementwiseShrBroadcasting + elif operator_type == "SHL": + binary_elementwise = SingleEthosuBinaryElementwiseShlBroadcasting + binary_elementwise_call = extract_ethosu_binary_elementwise_broadcasting_call_extern( + binary_elementwise + ) + npu_op = tir_to_cs_translator.translate_ethosu_binary_elementwise(binary_elementwise_call) + + # Compare IFM + assert npu_op.ifm.data_type == data_type + assert npu_op.ifm.shape == vapi.NpuShape3D(2, 3, 4) + assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_0 + assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_1 + assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).width_0 + assert npu_op.ifm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm.strides == vapi.NpuShape3D( + 12 * data_type_bytes, 4 * data_type_bytes, 1 * data_type_bytes + ) + # Compare IFM2 + assert npu_op.ifm2.data_type == data_type + assert npu_op.ifm2.shape == vapi.NpuShape3D(1, 3, 1) + assert npu_op.ifm2.tiles.height_0 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).height_0 + assert npu_op.ifm2.tiles.height_1 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).height_1 + assert npu_op.ifm2.tiles.width_0 == vapi.NpuTileBox(1, 0, 3, [0, 0, 0, 0]).width_0 + assert npu_op.ifm2.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ifm2.layout == vapi.NpuLayout.NHWC + assert npu_op.ifm2.strides == vapi.NpuShape3D( + 1 * data_type_bytes, 1 * data_type_bytes, 1 * data_type_bytes + ) + # Compare OFM + assert npu_op.ofm.data_type == data_type + assert npu_op.ofm.shape == vapi.NpuShape3D(2, 3, 4) + assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_0 + assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).height_1 + assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(2, 0, 3, [0, 0, 0, 0]).width_0 + assert npu_op.ofm.quantization == vapi.NpuQuantization(1.0, 0) + assert npu_op.ofm.layout == vapi.NpuLayout.NHWC + assert npu_op.ofm.strides == vapi.NpuShape3D( + 12 * data_type_bytes, 4 * data_type_bytes, 
1 * data_type_bytes + ) + # Compare op type + if operator_type == "ADD": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.ADD + elif operator_type == "SUB": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SUB + elif operator_type == "MUL": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MUL + elif operator_type == "MIN": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MIN + elif operator_type == "MAX": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.MAX + elif operator_type == "SHR": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHR + elif operator_type == "SHL": + assert npu_op.sub_op_type == vapi.NpuElementWiseOp.SHL + # Compare reversed_operands + assert npu_op.reversed_operands == True + # Compare activation + + if operator_type == "SHR": + assert npu_op.activation is None + else: + assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + assert npu_op.activation.min == 10 + assert npu_op.activation.max == 100 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py index ecbe31b3cbd35..e068439fcee58 100644 --- a/tests/python/contrib/test_ethosu/test_type_inference.py +++ b/tests/python/contrib/test_ethosu/test_type_inference.py @@ -24,6 +24,7 @@ from .infra import make_ethosu_conv2d from .infra import make_ethosu_depthwise_conv2d from .infra import make_ethosu_pooling +from .infra import make_ethosu_binary_elementwise @pytest.mark.parametrize( @@ -226,5 +227,120 @@ def test_ethosu_pooling_invalid_dtype(): run_opt_pass(func, relay.transform.InferType()) +@pytest.mark.parametrize( + "ifm_shape, ifm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")] +) +@pytest.mark.parametrize( + "ofm_shape, ofm_layout", [((1, 4, 5, 33), "NHWC"), ((1, 4, 3, 5, 16), "NHCWB16")] +) +def test_ethosu_binary_elementwise_type_inference( + ifm_shape, + ifm_layout, + ofm_shape, + ofm_layout, +): + dtype = "int8" + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype) + operator_type = "ADD" + ifm_channels, ifm2_channels = 33, 33 + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + operator_type, + dtype, + ifm_layout=ifm_layout, + ifm2_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + func = relay.Function([ifm, ifm2], binary_elementwise) + func = run_opt_pass(func, relay.transform.InferType()) + assert tuple(func.body.checked_type.shape) == ofm_shape + assert func.body.checked_type.dtype == dtype + + +def test_ethosu_binary_elementwise_invalid_operator_type(): + invalid_operator_type = "A" + ifm_shape = [1, 4, 5, 33] + dtype = "int8" + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype) + ifm_channels, ifm2_channels = 33, 33 + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, + ifm_channels, + ifm2_channels, + invalid_operator_type, + dtype, + ) + func = relay.Function([ifm, ifm2], binary_elementwise) + with pytest.raises(TVMError): + run_opt_pass(func, relay.transform.InferType()) + + +def test_ethosu_binary_elementwise_invalid_data_types(): + dtype = "int8" + dtype2 = "int32" + operator_type = "ADD" + ifm_shape = [1, 4, 5, 33] + ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype) + ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=dtype2) + ifm_channels, ifm2_channels = 33, 33 + binary_elementwise = make_ethosu_binary_elementwise( + ifm, + ifm2, 
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        dtype,
+    )
+    func = relay.Function([ifm, ifm2], binary_elementwise)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
+@pytest.mark.parametrize("operator_type", ["MIN", "MAX"])
+def test_ethosu_binary_elementwise_min_max_invalid_data_type(operator_type):
+    invalid_dtype = "int32"
+    ifm_shape = [1, 4, 5, 33]
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=invalid_dtype)
+    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=invalid_dtype)
+    ifm_channels, ifm2_channels = 33, 33
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        invalid_dtype,
+    )
+    func = relay.Function([ifm, ifm2], binary_elementwise)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
+@pytest.mark.parametrize("invalid_dtype", ["int8", "uint8"])
+@pytest.mark.parametrize("operator_type", ["SHL", "SHR"])
+def test_ethosu_binary_elementwise_shift_invalid_data_type(invalid_dtype, operator_type):
+    ifm_shape = [1, 4, 5, 33]
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=invalid_dtype)
+    ifm2 = relay.var("ifm2", shape=ifm_shape, dtype=invalid_dtype)
+    ifm_channels, ifm2_channels = 33, 33
+    binary_elementwise = make_ethosu_binary_elementwise(
+        ifm,
+        ifm2,
+        ifm_channels,
+        ifm2_channels,
+        operator_type,
+        invalid_dtype,
+    )
+    func = relay.Function([ifm, ifm2], binary_elementwise)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
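
As a quick usage sketch (assuming a TVM build with the Ethos(TM)-U backend enabled and the ethosu.vela package installed; the shapes and quantization parameters below are illustrative only), the new operator can be constructed and type-checked along the same lines as the make_ethosu_binary_elementwise test helper above:

from tvm import relay
from tvm.relay.testing import run_opt_pass
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

# Quantized ADD with a scalar ifm2 broadcast over the ifm, as in the tests.
ifm = relay.var("ifm", shape=(1, 4, 4, 8), dtype="int8")
ifm2 = relay.var("ifm2", shape=(1, 1, 1, 1), dtype="int8")
add = ethosu_ops.ethosu_binary_elementwise(
    ifm=ifm,
    ifm2=ifm2,
    lut=relay.const([], dtype="int8"),  # no LUT activation
    operator_type="ADD",
    ifm_scale=1.0,
    ifm_zero_point=0,
    ifm2_scale=1.0,
    ifm2_zero_point=0,
    ofm_scale=1.0,
    ofm_zero_point=0,
    ifm_channels=8,
    ifm2_channels=1,
    reversed_operands=False,
    ofm_dtype="int8",
    activation="NONE",
    clip_min=0,
    clip_max=0,
    ifm_layout="NHWC",
    ifm2_layout="NHWC",
    ofm_layout="NHWC",
)
func = run_opt_pass(relay.Function([ifm, ifm2], add), relay.transform.InferType())
# The type relation derives the OFM shape from the IFM shape and ifm_channels.
assert tuple(func.body.checked_type.shape) == (1, 4, 4, 8)
assert func.body.checked_type.dtype == "int8"

With ofm_layout="NHCWB16" the type relation instead yields (1, height, 1 + (channels - 1) // 16, width, 16); for example, 33 channels map to 3 channel bricks, matching the (1, 4, 3, 5, 16) shapes used in the type-inference tests.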
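The SHR reference used by test_ethosu_right_shift_binary_elemwise implements round-to-nearest by adding half of the shift divisor before shifting. On plain Python integers:

def rounding_right_shift(lhs: int, rhs: int) -> int:
    # Add half of 2**rhs so that the truncating shift rounds to nearest.
    r = 1 << (rhs - 1)
    return (lhs + r) >> rhs

assert rounding_right_shift(18, 1) == 9   # 18 / 2 = 9 exactly
assert rounding_right_shift(19, 1) == 10  # 19 / 2 = 9.5 rounds up to 10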