diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index b970aec62c6f..c4b70c130d4e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -16,7 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel, no-value-for-parameter """A set of passes to legalize some of operations for the NPU""" -from typing import List +from typing import List, Type + import numpy as np # type: ignore import tvm # type: ignore @@ -26,6 +27,7 @@ from tvm.relay.dataflow_pattern import wildcard from tvm.relay.dataflow_pattern import is_op from tvm.relay.dataflow_pattern import rewrite +from tvm.relay.dataflow_pattern import CallPattern from tvm.relay.backend.contrib.ethosu import op as ethosu_ops # type: ignore from tvm.relay.backend.contrib.ethosu.errors import UnsupportedLayout # type: ignore from tvm.relay.backend.contrib.ethosu import vela_api @@ -121,7 +123,7 @@ def __call__(self, *args, **kwargs): pass -class EthosUConv2DRewriter(DFPatternCallback): +class Conv2DRewriter(DFPatternCallback): """Convert conv2d related composite functions into ethosu_conv2d operators""" def __init__(self): @@ -193,14 +195,14 @@ def callback( @ir.transform.module_pass(opt_level=1) -class LegalizeEthosUConv2D: - """This is the pass that wraps the EthosUConv2DRewriter""" +class LegalizeConv2D: + """This is the pass that wraps the Conv2DRewriter""" def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: for global_var, func in mod.functions.items(): - func = rewrite(EthosUConv2DRewriter(), func) + func = rewrite(Conv2DRewriter(), func) mod.update_func(global_var, func) return mod @@ -208,7 +210,7 @@ def __call__(self, *args, **kwargs): pass -class EthosuDepthwiseConv2DRewriter(DFPatternCallback): +class DepthwiseConv2DRewriter(DFPatternCallback): """Convert ethosu.qnn_depthwise_conv2d composite functions to ethosu_depthwise_conv2d operators""" @@ -286,14 +288,124 @@ def callback( @ir.transform.module_pass(opt_level=1) -class LegalizeEthosUDepthwiseConv2D: - """This is the pass that wraps the EthosUDepthwiseConv2DRewriter""" +class LegalizeDepthwiseConv2D: + """This is the pass that wraps the DepthwiseConv2DRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(DepthwiseConv2DRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class PoolingRewriter(DFPatternCallback): + """Convert ethosu.avgpool2d and ethosu.maxpool2d composite functions to + ethosu_pooling operators""" + + def __init__( + self, + params_class: Type, + pattern: CallPattern, + ): + super().__init__(require_type=True) + self.params_class = params_class + self.pattern = pattern + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map + ) -> tvm.relay.Expr: + params = self.params_class(post.op.body) + params.ifm.tensor = post.args[0] + channels_map = { + "NHWC": 3, + } + if str(params.ofm.layout) not in channels_map.keys(): + raise UnsupportedLayout(str(params.ofm.layout)) + + activation_map = {"clip": "CLIP"} + if params.activation: + activation = activation_map[params.activation.op.name] + clip_min = int(params.activation.attrs.a_min) + clip_max = int(params.activation.attrs.a_max) + else: + 
activation = "NONE" + clip_min = 0 + clip_max = 0 + + # Activations requiring LUT is currently not supported, so setting it to an empty list + lut = relay.const([], dtype="int8") + + return ethosu_ops.ethosu_pooling( + ifm=post.args[0], + lut=lut, + pooling_type=params.pooling_type, + ifm_scale=params.ifm.q_params.scale_f32, + ifm_zero_point=params.ifm.q_params.zero_point, + ofm_scale=params.ofm.q_params.scale_f32, + ofm_zero_point=params.ofm.q_params.zero_point, + pool_shape=params.pool_shape, + ofm_channels=params.ofm.shape[channels_map[str(params.ofm.layout)]], + strides=params.strides, + padding=params.padding, + activation=activation, + clip_min=clip_min, + clip_max=clip_max, + upscale="NONE", + ifm_layout=str(params.ifm.layout), + ofm_layout=str(params.ofm.layout), + ) + + +class MaxPoolingRewriter(PoolingRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.MaxPool2DParams, + pattern=( + wildcard().has_attr({"Composite": ethosu_patterns.MaxPool2DParams.composite_name}) + )(wildcard()), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeMaxPooling: + """This is the pass that wraps the MaxPoolingRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(MaxPoolingRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + +class AvgPoolingRewriter(PoolingRewriter): + def __init__(self): + super().__init__( + params_class=ethosu_patterns.AvgPool2DParams, + pattern=( + wildcard().has_attr({"Composite": ethosu_patterns.AvgPool2DParams.composite_name}) + )(wildcard()), + ) + + +@ir.transform.module_pass(opt_level=1) +class LegalizeAvgPooling: + """This is the pass that wraps the AvgPoolingRewriter""" def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: for global_var, func in mod.functions.items(): - func = rewrite(EthosuDepthwiseConv2DRewriter(), func) + func = rewrite(AvgPoolingRewriter(), func) mod.update_func(global_var, func) return mod @@ -312,8 +424,10 @@ def transform_module( self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext ) -> tvm.ir.IRModule: mod = LegalizeSplit()(mod) - mod = LegalizeEthosUConv2D()(mod) - mod = LegalizeEthosUDepthwiseConv2D()(mod) + mod = LegalizeConv2D()(mod) + mod = LegalizeDepthwiseConv2D()(mod) + mod = LegalizeMaxPooling()(mod) + mod = LegalizeAvgPooling()(mod) return mod def __call__(self, *args, **kwargs): diff --git a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py index 1063db6a04c5..c9aa59b5dd2b 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/__init__.py @@ -18,3 +18,4 @@ from .convolution import ethosu_conv2d from .depthwise import ethosu_depthwise_conv2d +from .pooling import ethosu_pooling diff --git a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py index b159830ceaa9..7fb054edb6b6 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/convolution.py @@ -113,7 +113,7 @@ def ethosu_conv2d( ofm_layout: str = "NHWC", ) -> tvm.relay.Call: """This is a quantized 2D convolution operation as supported by the - the NPU. It accepts either NHWC or NHCWB16 format + Ethos(TM)-U NPU. 
It accepts either NHWC or NHCWB16 format for the input data and OHWI format for the kernel weights. Reference: https://developer.arm.com/documentation/102420/0200/ @@ -132,7 +132,7 @@ def ethosu_conv2d( scale_bias : tvm.relay.Expr The packed per-channel weight scale and bias tensor. lut : tvm.relay.Expr - The look-up table values to use if activation = "LUT". + The look-up table of values to use if activation = "LUT". ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int @@ -146,7 +146,7 @@ def ethosu_conv2d( kernel_shape : tuple of int The 2 dimensional kernel shape as (kernel_height, kernel_width). ofm_channels : int - The number of OFM channels. + The number of the Output Feature Map channels. strides : tuple of int, optional The 2 dimensional strides as (stride_height, stride_width). padding : tuple of int, optional diff --git a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py index abcddf90b97c..d1b49ef6e898 100644 --- a/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/op/depthwise.py @@ -112,8 +112,8 @@ def ethosu_depthwise_conv2d( ifm_layout: str = "NHWC", ofm_layout: str = "NHWC", ) -> tvm.relay.Call: - """This is a quantized 2D depthwise convolution operation as supported - by the NPU. It accepts either NHWC or NHCWB16 format + """This is a quantized 2D depthwise convolution operation as supported by the + Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format for the input data and OHWI format for the kernel weights. Reference: https://developer.arm.com/documentation/102420/0200/ @@ -132,7 +132,7 @@ def ethosu_depthwise_conv2d( scale_bias : tvm.relay.Expr The packed per-channel weight scale and bias tensor. lut : tvm.relay.Expr - The look-up table values to use if activation = "LUT" + The look-up table of values to use if activation = "LUT" ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int @@ -146,7 +146,7 @@ def ethosu_depthwise_conv2d( kernel_shape : tuple of int The 2 dimensional kernel shape as (kernel_height, kernel_width). ofm_channels : int - The number of OFM channels. + The number of the Output Feature Map channels. strides : tuple of int, optional The 2 dimensional strides as (stride_height, stride_width). padding : tuple of int, optional diff --git a/python/tvm/relay/backend/contrib/ethosu/op/pooling.py b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py new file mode 100644 index 000000000000..f344f61f1dd1 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/op/pooling.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=unused-argument +"""Relay operators for pooling""" +from typing import Tuple + +import tvm +from tvm.relay.op import _make +from tvm.topi.generic import schedule_injective +from tvm.relay.op.op import OpStrategy +from tvm.relay.op import strategy as _strategy + +from ..te import pooling_compute + + +def _extract_ethosu_pooling_params(attrs, args): + """Get the parameters necessary to construct an ethosu_pooling compute TE + from an ethosu_pooling Relay call.""" + ifm = args[0] + lut = args[1] + pooling_type = attrs.pooling_type + ifm_scale = attrs.ifm_scale + ifm_zero_point = attrs.ifm_zero_point + ofm_scale = attrs.ofm_scale + ofm_zero_point = attrs.ofm_zero_point + pool_shape = attrs.pool_shape + ofm_channels = attrs.ofm_channels + strides = attrs.strides + padding = attrs.padding + activation = attrs.activation + clip_min = attrs.clip_min + clip_max = attrs.clip_max + upscale = attrs.upscale + ifm_layout = attrs.ifm_layout + ofm_layout = attrs.ofm_layout + + return ( + ifm, + lut, + pooling_type, + ifm_scale, + ifm_zero_point, + ofm_scale, + ofm_zero_point, + pool_shape, + ofm_channels, + strides, + padding, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) + + +@tvm.ir.register_op_attr("contrib.ethosu.pooling", "FTVMCompute") +def create_ethosu_pooling_compute(attrs, args, out_type): + """Create an ethosu_pooling compute op.""" + params = _extract_ethosu_pooling_params(attrs, args) + op = pooling_compute(*params) + return [op] + + +@tvm.ir.register_op_attr("contrib.ethosu.pooling", "FTVMStrategy") +def pooling_strategy_ethosu(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + create_ethosu_pooling_compute, + _strategy.wrap_topi_schedule(schedule_injective), + name="ethosu_pooling", + ) + return strategy + + +def ethosu_pooling( + ifm: tvm.relay.Expr, + lut: tvm.relay.Expr, + pooling_type: str, + ifm_scale: float, + ifm_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + pool_shape: Tuple[int, int], + ofm_channels: int, + strides: Tuple[int, int] = (1, 1), + padding: Tuple[int, int, int, int] = (0, 0, 0, 0), + activation: str = "NONE", + clip_min: int = 0, + clip_max: int = 0, + upscale: str = "NONE", + ifm_layout: str = "NHWC", + ofm_layout: str = "NHWC", +) -> tvm.relay.Call: + """This is a quantized 2D pooling operation as supported by the + Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format + for the input data. + + Parameters + ---------- + ifm : tvm.relay.Expr + The Input Feature Map tensor (IFM). + lut : tvm.relay.Expr + The look-up table of values to use if activation = "LUT". + pooling_type: str + The type of the pooling. "AVG" - average pool, "MAX" - max pool. + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + pool_shape : tuple of int + The 2 dimensional pool shape as (pool_shape_height, pool_shape_width). + ofm_channels : int + The number of the Output Feature Map channels. + strides : tuple of int, optional + The 2 dimensional strides as (stride_height, stride_width). + padding : tuple of int, optional + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + activation : str, optional + The activation function to use. + "NONE" - no activation function.
+ "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int, optional + The minimum clipping value if activation = "CLIP". + clip_max : int, optional + The maximum clipping value if activation = "CLIP". + upscale: str, optional + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str, optional + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str, optional + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + out : tvm.relay.Call + A call to the ethosu_pooling op. + """ + return _make.ethosu_pooling( + ifm, + lut, + pooling_type, + ifm_scale, + ifm_zero_point, + ofm_scale, + ofm_zero_point, + pool_shape, + ofm_channels, + strides, + padding, + activation, + clip_min, + clip_max, + upscale, + ifm_layout, + ofm_layout, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py index 5dcdd4dcf602..e2eb28f8f915 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/__init__.py @@ -18,3 +18,4 @@ from .convolution import * from .depthwise import * +from .pooling import * diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 26f7ea979219..1a7f96ace8eb 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -53,7 +53,7 @@ def conv2d_compute( scale_bias : te.Tensor The packed per-channel weight scale and bias tensor. lut : te.Tensor - The look-up table values to use if activation = "LUT". + The look-up table of values to use if activation = "LUT". ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int diff --git a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py index 35ae7f9a700a..6c139c958fa1 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/depthwise.py @@ -53,7 +53,7 @@ def depthwise_conv2d_compute( scale_bias : te.Tensor The packed per-channel weight scale and bias tensor. lut : te.Tensor - The look-up table values to use if activation = "LUT". + The look-up table of values to use if activation = "LUT". ifm_scale : float The quantization scale for the Input Feature Map tensor. ifm_zero_point : int diff --git a/python/tvm/relay/backend/contrib/ethosu/te/pooling.py b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py new file mode 100644 index 000000000000..2f090f289da2 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/te/pooling.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-argument +"""Tensor Expressions for poolings""" +from typing import Tuple + +from tvm import te +from .dma import dma_ofm_compute, dma_ifm_compute + + +def pooling_compute( + ifm: te.Tensor, + lut: te.Tensor, + pooling_type: str, + ifm_scale: float, + ifm_zero_point: int, + ofm_scale: float, + ofm_zero_point: int, + pool_shape: Tuple[int, int], + ofm_channels: int, + strides: Tuple[int, int], + padding: Tuple[int, int, int, int], + activation: str, + clip_min: int, + clip_max: int, + upscale: str, + ifm_layout: str, + ofm_layout: str, +) -> te.Tensor: + """A compute operator representing the capabilities of pooling for the NPU. + + Parameters + ---------- + ifm : te.Tensor + The Input Feature Map tensor (IFM). + lut : te.Tensor + The look-up table of values to use if activation = "LUT". + pooling_type: str + The type of the pooling. "AVG" - average pool, "MAX" - max pool. + ifm_scale : float + The quantization scale for the Input Feature Map tensor. + ifm_zero_point : int + The quantization zero point for the Input Feature Map tensor. + ofm_scale : float + The quantization scale for the Output Feature Map tensor. + ofm_zero_point : int + The quantization zero point for the Output Feature Map tensor. + pool_shape : Tuple[int, int] + The 2 dimensional pool shape as (pool_shape_height, pool_shape_width). + ofm_channels : int + The number of the Output Feature Map channels + strides : Tuple[int, int] + The 2 dimensional strides as (stride_height, stride_width). + padding : Tuple[int, int, int, int] + The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right). + activation : str + The activation function to use. + "NONE" - no activation function. + "CLIP" - clip the output between clip_min and clip_max. + "TANH" - tanh activation function. + "SIGMOID" - sigmoid activation function. + "LUT" - use a look-up table to perform the activation function. + clip_min : int + The minimum clipping value if activation = "CLIP". + clip_max : int + The maximum clipping value if activation = "CLIP". + upscale : str + The 2x2 upscaling mode to apply to the Input Feature Map tensor. + "NONE" - no upscaling. + "NEAREST" - upscale using nearest neighbour. + "ZEROS" - upscale using zeros. + ifm_layout : str + The layout of the Input Feature Map tensor. Can be "NHWC" or "NHCWB16". + ofm_layout : str + The layout of the Output Feature Map tensor. Can be "NHWC" or "NHCWB16". + + Returns + ------- + te.Tensor + The OFM tensor. 
+ """ + stride_h, stride_w = strides + pool_shape_h, pool_shape_w = pool_shape + + # Compute operation for the IFM DMA pipeline + dmaed_ifm = dma_ifm_compute(ifm, ifm_layout, ifm_zero_point, ifm_scale, ofm_channels, padding) + + # Pooling compute operation + ofm_height = (dmaed_ifm.shape[1] - pool_shape_h) // stride_h + 1 + ofm_width = (dmaed_ifm.shape[2] - pool_shape_w) // stride_w + 1 + rh = te.reduce_axis((0, pool_shape_h), name="ry") + rw = te.reduce_axis((0, pool_shape_w), name="rx") + + pooling_attrs = { + "op": "ethosu_pooling", + "pooling_type": pooling_type, + "stride_h": stride_h, + "stride_w": stride_w, + "activation": activation, + "clip_min": clip_min, + "clip_max": clip_max, + "upscale": upscale, + } + + pooling = te.compute( + (1, ofm_height, ofm_width, ofm_channels), + lambda nn, hh, ww, cc: te.max( + dmaed_ifm(nn, hh * stride_h + rh, ww * stride_w + rw, cc).astype(ifm.dtype), + axis=[rh, rw], + ), + name="ethosu_pooling", + attrs=pooling_attrs, + ) + + # Compute operation for the OFM DMA pipeline + return dma_ofm_compute(pooling, ofm_layout, ofm_zero_point, ofm_scale, ofm_channels) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index bc95a9a3bab7..b68a5ad14a6f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -29,7 +29,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target. The resulting TIR module will contain a single function - that consists of a sequence of tir.extern_calls to NPU + that consists of a sequence of tir.call_extern to NPU operations. Parameters diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index 761c8aad7bb1..2f5d7abd260d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -22,6 +22,7 @@ from tvm.relay.backend.contrib.ethosu import vela_api from .convolution import get_conv2d_params from .depthwise import get_depthwise_conv2d_params +from .pooling import get_pooling_params from .transform import get_copy_params from .utils import get_weights_pointer, get_scale_bias_pointer @@ -54,6 +55,7 @@ def ReplaceOperators(): "ethosu_conv2d": get_conv2d_params, "ethosu_copy": get_copy_params, "ethosu_depthwise_conv2d": get_depthwise_conv2d_params, + "ethosu_pooling": get_pooling_params, } pointer_to_producer = {} pointer_to_consumer = {} diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py new file mode 100644 index 000000000000..30f9bb3d981e --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the pooling operators in TIR.""" +from typing import Dict, Tuple +import tvm +from .utils import get_outer_loops, get_op_attrs +from .dma import get_ifm_params, get_ofm_params +from .spec import SerialKernel, SerialActivation, SerialPooling + + +def get_pooling_params( + stmt: tvm.tir.AttrStmt, + producers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], + consumers: Dict[tvm.tir.Var, tvm.tir.AttrStmt], +) -> Tuple[SerialPooling, tvm.tir.Var, tvm.tir.Var]: + """Get the parameters necessary to construct a call_extern for a pooling operation. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a pooling loop nest. + producers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : Dict[tvm.tir.Var, tvm.tir.AttrStmt] + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + SerialPooling + The parameters needed to construct a 2D pooling. + output_pointer : tvm.tir.Var + The output pointer of the pooling operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the pooling output pointer. + """ + attrs, body = get_op_attrs(stmt) + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + rh = inner + rw = rh.body + compute = rw.body.value.b + input_pointer = compute.buffer_var + output_pointer = rw.body.buffer_var + # Get feature map info + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get kernel info + serial_kernel = SerialKernel( + width=int(rw.extent), + height=int(rh.extent), + stride_w=int(attrs["stride_w"]), + stride_h=int(attrs["stride_h"]), + dilation_w=1, + dilation_h=1, + ) + + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + return ( + SerialPooling( + ifm=serial_ifm, + ofm=serial_ofm, + pooling_type=attrs["pooling_type"], + pool_shape=serial_kernel, + padding=serial_padding, + activation=serial_activation, + upscale="NONE", + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index bcae01a10214..861669588f72 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -85,10 +85,10 @@ def translate(tir_module, params): """ buffer_info = extract_buffer_info(tir_module, params) - extern_calls = extract_extern_calls(tir_module) + call_extern_list = extract_call_extern_list(tir_module) _npu_ops = list() - for extern_call in extern_calls: - _npu_ops.append(translate_ethosu_tir_extern_call(extern_call)) + for call_extern in call_extern_list: + _npu_ops.append(translate_ethosu_tir_call_extern(call_extern)) _npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops) target_accel_config = vela_api.get_accelerator_config() cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_config) @@ -97,7 +97,7 @@ def translate(tir_module, params): return payload.hex(), hex_value, scratch_size -def extract_extern_calls(mod): +def
extract_call_extern_list(mod): """This function will obtain all extern calls from a TIR module Parameters @@ -115,14 +115,14 @@ def extract_extern_calls(mod): assert len(mod.functions.items()) == 1 primfunc = mod.functions.items()[0][1] - extern_calls = list() + call_extern_list = list() - def populate_extern_calls(stmt): + def populate_call_extern_list(stmt): if isinstance(stmt, tvm.tir.Call) and stmt.op.name == "tir.call_extern": - extern_calls.append(stmt) + call_extern_list.append(stmt) - stmt_functor.post_order_visit(primfunc.body, populate_extern_calls) - return extern_calls + stmt_functor.post_order_visit(primfunc.body, populate_call_extern_list) + return call_extern_list def extract_buffer_info( @@ -295,18 +295,19 @@ def classify_io(buffer): return npu_ops, constant_tensor, scratch_size -def translate_ethosu_tir_extern_call(tir_extern_call): +def translate_ethosu_tir_call_extern(tir_call_extern): """This is a dispatcher function to dispatch correct translation call depending on the extern call's first argument""" - supported_extern_calls = { + supported_call_extern = { "ethosu_conv2d": translate_ethosu_conv2d, "ethosu_copy": translate_ethosu_copy, "ethosu_depthwise_conv2d": translate_ethosu_depthwise_conv2d, + "ethosu_pooling": translate_ethosu_pooling, } - ext_call_type = tir_extern_call.args[0].value - assert ext_call_type in supported_extern_calls.keys(), f"{ext_call_type} is not yet supported" - npu_op = supported_extern_calls[ext_call_type](tir_extern_call) + ext_call_type = tir_call_extern.args[0].value + assert ext_call_type in supported_call_extern.keys(), f"{ext_call_type} is not yet supported" + npu_op = supported_call_extern[ext_call_type](tir_call_extern) # Some conversions return additional outputs # if they are needed, the caller should use the function directly if isinstance(npu_op, tuple): @@ -314,20 +315,21 @@ def translate_ethosu_tir_extern_call(tir_extern_call): return npu_op -def translate_ethosu_copy(tir_extern_call): - """This function will translate a tir ethosu_copy extern_call - as produced by Relay to TIR compilation. +def translate_ethosu_copy(tir_call_extern: tvm.tir.Call) -> vapi.NpuDmaOperation: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. + Parameters ---------- - tir_extern_call : tvm.tir.Call + tir_call_extern : tvm.tir.Call Returns ------- ethosu.vela.api.NpuDmaOperation The vela object containing the params of ethosu_copy """ - # We skip the first element as it is the extern_call function name - serial_object = spec.create_serial_object(spec.SerialCopy, tir_extern_call.args[1:]) + # We skip the first element as it is the call_extern function name + serial_object = spec.create_serial_object(spec.SerialCopy, tir_call_extern.args[1:]) return _create_npu_dma_op(serial_object) @@ -360,7 +362,7 @@ def translate_ethosu_conv2d(tir_call_extern: tvm.tir.Call) -> Tuple[vapi.NpuConv Parameters ---------- tir_call_extern : tvm.tir.Call - This should be a TIR call_extern that has a agreed upon ordering + This should be a TIR call_extern that has agreed upon ordering for TIR Compiler. See Serial2DConvolution in tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. 
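For orientation, the renamed helpers in this file are driven end to end by `translate()`: it collects the `tir.call_extern` nodes, dispatches each to its translator, assigns addresses, and emits the command stream. A rough usage sketch, assuming `func` is a Relay function that has already been partitioned for the NPU (only the names introduced in this patch are used):

```python
from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir

# Lower the partitioned Relay function to NPU TIR, then translate the
# sequence of tir.call_extern operations into a command stream.
tir_mod, const_dict = lower_to_tir(func)
cmms_hex, constants_hex, scratch_size = tir_to_cs_translator.translate(tir_mod, const_dict)
```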
@@ -370,7 +372,6 @@ def translate_ethosu_conv2d(tir_call_extern: tvm.tir.Call) -> Tuple[vapi.NpuConv The vela object containing the params of ethosu_conv2d weights_zero_point : int The zero point of the weights - """ # We skip the first element as it is the call_extern function name serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_call_extern.args[1:]) @@ -417,25 +418,27 @@ def _create_npu_op_conv2d( return npu_conv2d_op, weights_zero_point -def translate_ethosu_depthwise_conv2d(tir_extern_call): - """This function will translate a tir extern_call - as produced by Relay to TIR compilation. +def translate_ethosu_depthwise_conv2d( + tir_call_extern: tvm.tir.Call, +) -> Tuple[vapi.NpuConvDepthWiseOperation, int]: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. Parameters ---------- - tir_extern_call : tvm.tir.Call - This should be a tir external call that has an agreed upon ordering - for NPU TIR Compiler. See Serial2DDepthwise in + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has agreed upon ordering + for TIR Compiler. See Serial2DDepthwise in tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. Returns ------- - ethosu.vela.api.NpuDepthWiseOperation + ethosu.vela.api.NpuConvDepthWiseOperation The vela object containing the params of ethosu_depthwise_conv2d weights_zero_point : int The zero point of the weights """ - serial_object = spec.create_serial_object(spec.Serial2DDepthwise, tir_extern_call.args[1:]) + serial_object = spec.create_serial_object(spec.Serial2DDepthwise, tir_call_extern.args[1:]) return _create_npu_op_depthwise_conv2d(serial_object) @@ -625,3 +628,52 @@ def _create_npu_dma_op(serial_copy): length=int(serial_copy.length.value), ) return vapi.NpuDmaOperation(src, dest) + + +def translate_ethosu_pooling(tir_call_extern: tvm.tir.Call) -> vapi.NpuPoolingOperation: + """This function will translate a TIR call_extern + as produced by NPU Relay to TIR compilation. + + Parameters + ---------- + tir_call_extern : tvm.tir.Call + This should be a TIR call_extern that has agreed upon ordering + for TIR Compiler. See SerialPooling in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. 
+ + Returns + ------- + ethosu.vela.api.NpuPoolingOperation + The vela object containing the params of ethosu_pooling + """ + serial_object = spec.create_serial_object(spec.SerialPooling, tir_call_extern.args[1:]) + return _create_npu_op_pooling(serial_object) + + +def _create_npu_op_pooling(serial_pooling: spec.SerialPooling): + pooling_type = serial_pooling.pooling_type + if pooling_type == "AVG": + npu_pooling_op = vapi.NpuPoolingOp.AVERAGE + elif pooling_type == "MAX": + npu_pooling_op = vapi.NpuPoolingOp.MAX + + npu_pooling_op = vapi.NpuPoolingOperation(npu_pooling_op) + npu_pooling_op.ifm = _create_npu_feature_map(serial_pooling.ifm) + npu_pooling_op.ofm = _create_npu_feature_map(serial_pooling.ofm) + npu_pooling_op.kernel = _create_npu_kernel(serial_pooling.pool_shape) + npu_pooling_op.padding = _create_npu_padding(serial_pooling.padding) + + npu_pooling_op.activation = _create_npu_activation(serial_pooling.activation) + if ( + npu_pooling_op.activation + and npu_pooling_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_pooling_op) + + npu_pooling_op.upscale = _create_npu_resampling_mode(serial_pooling.upscale) + + target_accel_config = vela_api.get_accelerator_config() + block_config = vela_api.get_optimal_block_config(npu_pooling_op, target_accel_config) + npu_pooling_op.block_config = block_config + + return npu_pooling_op diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index ca417942840d..a152235c702b 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -23,7 +23,7 @@ import tvm # type: ignore from tvm import relay -from tvm.relay.expr import Constant # type: ignore +from tvm.relay.expr import Constant, Call # type: ignore from tvm.relay.op.contrib.register import register_pattern_table # type: ignore from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant # type: ignore from tvm.relay.build_module import bind_params_by_name # type: ignore @@ -170,6 +170,16 @@ def check_padding(padding: List[int], bounds: List[int]): return not (top > topb or left > leftb or bottom > bottomb or right > rightb) +def check_pool_shape(pool_shape: tvm.ir.container.Array) -> bool: + if len(pool_shape) != 2: + return False + if pool_shape[1] > 256: + return False + if pool_shape[0] * pool_shape[1] > 256 * 256: + return False + return True + + class QnnConv2DParams: """ This class will parse a Call to a ethosu.qnn_conv2d composite function @@ -331,6 +341,123 @@ def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: return clip_or_req +class MaxPool2DParams: + """ + This class will parse a call to an ethosu.maxpool2d composite function + and extract the parameter information. + """ + + composite_name = "ethosu.maxpool2d" + # The hardware only supports padding up to the following values + padding_bounds = [127, 127, 128, 128] + + def __init__(self, func_body: Call): + clip = None + if str(func_body.op) == "clip": + clip = func_body + pool_op = clip.args[0] + else: + pool_op = func_body + + attrs = pool_op.attrs + self.ifm = TensorParams(pool_op.args[0], attrs.layout) + self.ofm = TensorParams(pool_op, attrs.layout) + self.pool_shape = attrs.pool_size + self.strides = attrs.strides + self.padding = attrs.padding + self.activation = clip + self.pooling_type = "MAX" + + def is_valid(self): + """ + This function checks whether MaxPool2D has compatible attributes with the NPU + """ + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params): + return False + if self.ifm.dtype != self.ofm.dtype: + return False + if not check_strides(self.strides): + return False + if not check_batch_size(self.ifm): + return False + if not check_padding(self.padding, self.padding_bounds): + return False + if not check_pool_shape(self.pool_shape): + return False + return True + + +def qnn_maxpool2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for nn.max_pool2d with optional fused RELU activation. + """ + pattern = is_op("nn.max_pool2d")(wildcard()) + pattern = pattern.optional(is_op("clip")) + return pattern + + +class AvgPool2DParams: + """ + This class will parse a call to an ethosu.avgpool2d composite function + and extract the parameter information. + """ + + composite_name = "ethosu.avgpool2d" + # The hardware only supports padding up to the following values + padding_bounds = [127, 127, 128, 128] + + def __init__(self, func_body: Call): + clip = None + if str(func_body.op) == "clip": + clip = func_body + cast2 = clip.args[0] + else: + cast2 = func_body + + avgpool = cast2.args[0] + cast1 = avgpool.args[0] + + attrs = avgpool.attrs + self.ifm = TensorParams(cast1.args[0], attrs.layout) + self.ofm = TensorParams(cast2, attrs.layout) + self.pool_shape = attrs.pool_size + self.strides = attrs.strides + self.padding = attrs.padding + self.activation = clip + self.pooling_type = "AVG" + + def is_valid(self): + """ + This function checks whether AvgPool2D has compatible attributes with the NPU + """ + tensor_params = [self.ifm, self.ofm] + if not check_valid_dtypes(tensor_params): + return False + if self.ifm.dtype != self.ofm.dtype: + return False + if not check_strides(self.strides): + return False + if not check_batch_size(self.ifm): + return False + if not check_padding(self.padding, self.padding_bounds): + return False + if not check_pool_shape(self.pool_shape): + return False + return True + + +def qnn_avgpool2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern: + """ + This function creates the pattern for nn.avg_pool2d with optional fused RELU activation. + """ + pattern = is_op("cast")(wildcard()) + pattern = is_op("nn.avg_pool2d")(pattern) + pattern = is_op("cast")(pattern) + pattern = pattern.optional(is_op("clip")) + return pattern
+ """ + pattern = is_op("cast")(wildcard()) + pattern = is_op("nn.avg_pool2d")(pattern) + pattern = is_op("cast")(pattern) + pattern = pattern.optional(is_op("clip")) + return pattern + + @register_pattern_table("ethosu") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -344,6 +471,16 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal qnn_depthwise_conv2d_pattern(), lambda pat: QnnDepthwiseConv2DParams(pat).is_valid(), ), + ( + MaxPool2DParams.composite_name, + qnn_maxpool2d_pattern(), + lambda pat: MaxPool2DParams(pat).is_valid(), + ), + ( + AvgPool2DParams.composite_name, + qnn_avgpool2d_pattern(), + lambda pat: AvgPool2DParams(pat).is_valid(), + ), ] diff --git a/src/relay/op/contrib/ethosu/convolution.cc b/src/relay/op/contrib/ethosu/convolution.cc index bad10bf66f3a..2cc6c3416e27 100644 --- a/src/relay/op/contrib/ethosu/convolution.cc +++ b/src/relay/op/contrib/ethosu/convolution.cc @@ -70,7 +70,7 @@ struct EthosuConv2DAttrs : public tvm::AttrsNode { .describe("The 2 dimensional kernel shape as (kernel_height, kernel_width).") .set_default(NullValue>()); TVM_ATTR_FIELD(ofm_channels) - .describe("The number of OFM channels.") + .describe("The number of the Output Feature Map channels.") .set_default(NullValue()); TVM_ATTR_FIELD(strides) .set_default(Array({1, 1})) @@ -179,7 +179,7 @@ RELAY_REGISTER_OP("contrib.ethosu.conv2d") .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized convolution operator. This Relay operator corresponds to the hardware-implemented quantized -convolution operation found on Ethos(TM)-U NPUs. It accepts either NHWC +convolution operation found on Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format for the input data (Input Feature Map, or IFM) and OHWI format for the kernel weights. @@ -201,7 +201,7 @@ of type uint8. For more detail, refer to the Technical Reference Manual linked a .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") .add_argument("weight", "Tensor", "The weight tensor.") .add_argument("scale_bias", "Tensor", "The packed per-channel weight scale and bias tensor.") - .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'.") + .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'.") .set_support_level(11) .add_type_rel("EthosuConv2D", EthosuConv2DRel); diff --git a/src/relay/op/contrib/ethosu/depthwise.cc b/src/relay/op/contrib/ethosu/depthwise.cc index fa73645d45de..5ff27de51b2f 100644 --- a/src/relay/op/contrib/ethosu/depthwise.cc +++ b/src/relay/op/contrib/ethosu/depthwise.cc @@ -186,7 +186,7 @@ RELAY_REGISTER_OP("contrib.ethosu.depthwise_conv2d") .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized depthwise operator. This Relay operator corresponds to the hardware-implemented quantized -depthwise operation found on Ethos(TM)-U NPUs. It accepts either NHWC or NHCWB16 format +depthwise operation found on Ethos(TM)-U NPU. It accepts either NHWC or NHCWB16 format for the input data (input feature map, or IFM) and OHWI format for the kernel weights. 
- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) @@ -201,7 +201,7 @@ for the input data (input feature map, or IFM) and OHWI format for the kernel we .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") .add_argument("weight", "Tensor", "The weight tensor.") .add_argument("scale_bias", "Tensor", "The packed per-channel weight scale and bias tensor.") - .add_argument("lut", "Tensor", "The look-up table values to use if activation = 'LUT'") + .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'") .set_support_level(11) .add_type_rel("EthosuDepthwiseConv2D", EthosuDepthwiseConv2DRel); diff --git a/src/relay/op/contrib/ethosu/pooling.cc b/src/relay/op/contrib/ethosu/pooling.cc new file mode 100644 index 000000000000..86f14f37a8d8 --- /dev/null +++ b/src/relay/op/contrib/ethosu/pooling.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/op/contrib/ethosu/pooling.cc + * \brief Pooling operator definitions for the Arm(R) Ethos(TM)-U NPU. + */ +#include <tvm/relay/op.h> + +#include "common.h" + +namespace tvm { +namespace relay { +namespace op { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes used by the Ethos(TM)-U NPU pooling operator */ +struct EthosuPoolingAttrs : public tvm::AttrsNode<EthosuPoolingAttrs> { + String pooling_type; + double ifm_scale; + int ifm_zero_point; + double ofm_scale; + int ofm_zero_point; + Array<IndexExpr> pool_shape; + IndexExpr ofm_channels; + Array<IndexExpr> strides; + Array<IndexExpr> padding; + String activation; + int clip_min; + int clip_max; + String upscale; + String ifm_layout; + String ofm_layout; + + TVM_DECLARE_ATTRS(EthosuPoolingAttrs, "relay.attrs.EthosuPoolingAttrs") { + TVM_ATTR_FIELD(pooling_type) + .describe("The type of the pooling.
'AVG' - average pool, 'MAX' - max pool."); + TVM_ATTR_FIELD(ifm_scale).describe("The quantization scale for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ifm_zero_point) + .describe("The quantization zero point for the Input Feature Map tensor."); + TVM_ATTR_FIELD(ofm_scale).describe("The quantization scale for the Output Feature Map tensor."); + TVM_ATTR_FIELD(ofm_zero_point) + .describe("The quantization zero point for the Output Feature Map tensor."); + TVM_ATTR_FIELD(pool_shape) + .describe("The 2 dimensional pool shape as (pool_shape_height, pool_shape_width).") + .set_default(NullValue<Array<IndexExpr>>()); + TVM_ATTR_FIELD(ofm_channels) + .describe("The number of the Output Feature Map channels.") + .set_default(NullValue<IndexExpr>()); + TVM_ATTR_FIELD(strides) + .set_default(Array<IndexExpr>({1, 1})) + .describe("The 2 dimensional strides as (stride_height, stride_width)."); + TVM_ATTR_FIELD(padding) + .describe("The 4 dimensional padding as (pad_top, pad_left, pad_bottom, pad_right).") + .set_default(Array<IndexExpr>({0, 0, 0, 0})); + TVM_ATTR_FIELD(activation) + .describe( + "The activation function to use. " + "'NONE' - no activation function. " + "'CLIP' - clip the output between clip_min and clip_max. " + "'TANH' - tanh activation function. " + "'SIGMOID' - sigmoid activation function. " + "'LUT' - use a look-up table to perform the activation function.") + .set_default("NONE"); + TVM_ATTR_FIELD(clip_min) + .describe("The minimum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(clip_max) + .describe("The maximum clipping value if activation = 'CLIP'.") + .set_default(0); + TVM_ATTR_FIELD(upscale) + .describe( + "The 2x2 upscaling mode to apply to the Input Feature Map tensor. " + "'NONE' - no upscaling. " + "'NEAREST' - upscale using nearest neighbour. " + "'ZEROS' - upscale using zeros.") + .set_default("NONE"); + TVM_ATTR_FIELD(ifm_layout) + .describe("The layout of the Input Feature Map tensor. Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + TVM_ATTR_FIELD(ofm_layout) + .describe("The layout of the Output Feature Map tensor.
Can be 'NHWC' or 'NHCWB16'.") + .set_default("NHWC"); + } +}; + +TVM_REGISTER_NODE_TYPE(EthosuPoolingAttrs); + +bool EthosuPoolingRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + int ifm_index = 0; + int result_index = 2; + ICHECK_EQ(types.size(), result_index + 1); + + const auto* ifm = types[ifm_index].as<TensorTypeNode>(); + if (ifm == nullptr) return false; + + const auto* param = attrs.as<EthosuPoolingAttrs>(); + ICHECK(param != nullptr) << "EthosuPoolingAttrs cannot be nullptr."; + + if (param->pooling_type != "AVG" && param->pooling_type != "MAX") { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: expected pooling_type 'AVG' or 'MAX' but was " + << param->pooling_type); + return false; + } + + if (ifm->dtype != DataType::UInt(8) && ifm->dtype != DataType::Int(8)) { + reporter->GetDiagCtx().EmitFatal( + Diagnostic::Error(reporter->GetSpan()) + << "Invalid operator: Expected pool type(uint8) or type(int8) for ifm but was " + << ifm->dtype); + return false; + } + + // Assign ofm type + auto ofm_shape = EthosuInferKernelOutput( + ifm->shape, param->ifm_layout, param->ofm_layout, param->pool_shape, param->ofm_channels, + Array<IndexExpr>({1, 1}), param->strides, param->padding); + reporter->Assign(types[result_index], TensorType(ofm_shape, ifm->dtype)); + return true; +} + +Expr MakeEthosuPooling(Expr ifm, Expr lut, String pooling_type, double ifm_scale, + int ifm_zero_point, double ofm_scale, int ofm_zero_point, + Array<IndexExpr> pool_shape, IndexExpr ofm_channels, + Array<IndexExpr> strides, Array<IndexExpr> padding, String activation, + int clip_min, int clip_max, String upscale, String ifm_layout, + String ofm_layout) { + auto attrs = make_object<EthosuPoolingAttrs>(); + attrs->pooling_type = std::move(pooling_type); + attrs->ifm_scale = ifm_scale; + attrs->ifm_zero_point = ifm_zero_point; + attrs->ofm_scale = ofm_scale; + attrs->ofm_zero_point = ofm_zero_point; + attrs->pool_shape = std::move(pool_shape); + attrs->ofm_channels = std::move(ofm_channels); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->activation = std::move(activation); + attrs->clip_min = clip_min; + attrs->clip_max = clip_max; + attrs->upscale = std::move(upscale); + attrs->ifm_layout = std::move(ifm_layout); + attrs->ofm_layout = std::move(ofm_layout); + static const Op& op = Op::Get("contrib.ethosu.pooling"); + return Call(op, {ifm, lut}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op._make.ethosu_pooling").set_body_typed(MakeEthosuPooling);
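On the Python side this global is reached through `relay.op._make.ethosu_pooling`, wrapped by `ethosu_pooling` in `op/pooling.py`. A minimal construction sketch with hypothetical shapes (it mirrors `make_ethosu_pooling` in the test infrastructure further below); running `InferType` exercises `EthosuPoolingRel` above:

```python
import tvm
from tvm import relay
from tvm.relay.backend.contrib.ethosu import op as ethosu_ops

ifm = relay.var("ifm", shape=(1, 8, 8, 4), dtype="int8")
call = ethosu_ops.ethosu_pooling(
    ifm,
    lut=relay.const([], dtype="int8"),
    pooling_type="MAX",
    ifm_scale=1.0,
    ifm_zero_point=0,
    ofm_scale=1.0,
    ofm_zero_point=0,
    pool_shape=(2, 2),
    ofm_channels=4,
    strides=(2, 2),
)
mod = tvm.IRModule.from_expr(relay.Function([ifm], call))
mod = relay.transform.InferType()(mod)  # checks dtype/pooling_type and assigns the OFM type
```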
 + +RELAY_REGISTER_OP("contrib.ethosu.pooling") + .describe(R"code(Arm(R) Ethos(TM)-U NPU 2D quantized pooling operator. + +This Relay operator corresponds to the hardware-implemented quantized +pooling operation found on the Ethos(TM)-U NPU. It accepts either NHWC +or NHCWB16 format for the input data (input feature map, or IFM). + +Reference: https://developer.arm.com/documentation/102420/0200/ + +- **ifm**: NHWC - (1, ifm_height, ifm_width, ifm_channels) + NHCWB16 - (1, ifm_height, ifm_channels // 16, ifm_width, 16) +- **ofm**: (1, ofm_height, ofm_width, ofm_channels) + +)code" TVM_ADD_FILELINE) + .set_attrs_type<EthosuPoolingAttrs>() + .set_num_inputs(2) + .add_argument("ifm", "Tensor", "The Input Feature Map tensor (IFM).") + .add_argument("lut", "Tensor", "The look-up table of values to use if activation = 'LUT'") + .set_support_level(11) + .add_type_rel("EthosuPooling", EthosuPoolingRel); + +} // namespace ethosu +} // namespace contrib +} // namespace op +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py index 01a7ceb9ed56..19f546a6f974 100644 --- a/tests/python/contrib/test_ethosu/infra.py +++ b/tests/python/contrib/test_ethosu/infra.py @@ -17,7 +17,6 @@ """ This module provides infrastructure to verify the correctness of the command stream produced. - Currently it will invoke vela to generate a vela-optimized tflite in which the command stream is contained as a custom operator. This class include methods to parse the custom operator to extract @@ -460,3 +459,51 @@ def make_ethosu_depthwise_conv2d( ofm_layout=ofm_layout, ) return depthwise + + +def get_pooling_args(call, include_buffers=False): + args = call.args + pooling_args = [] + + for i, arg in enumerate(args): + if isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + pooling_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: + pooling_args.append(arg.index) + else: + pooling_args.append(arg) + + return pooling_args + + +def make_ethosu_pooling( + ifm, + pooling_type, + pool_shape, + ofm_channels, + strides, + padding, + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", +): + pooling = ethosu_ops.ethosu_pooling( + ifm, + lut=relay.const([], dtype="int8"), + pooling_type=pooling_type, + ifm_scale=1, + ifm_zero_point=0, + ofm_scale=1, + ofm_zero_point=0, + pool_shape=pool_shape, + ofm_channels=ofm_channels, + strides=strides, + padding=padding, + activation=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + return pooling diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 4949d6814ab2..478a3c2bd521 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -254,5 +254,94 @@ def representative_dataset(): infra.verify_source(compiled_models, accel_type) + +@pytest.mark.parametrize( + "accel_type", + ACCEL_TYPES, +) +@pytest.mark.parametrize("pooling_type", ["MAX", "AVG"]) +@pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]]) +@pytest.mark.parametrize( + "pool_shape, strides, activation_function, padding", + [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")], +) +def test_ethosu_pooling( + accel_type, + ifm_shape, + pooling_type, + strides, + pool_shape, + activation_function, + padding, +): + dtype = "int8" + + def create_tflite_graph(): + tf.config.run_functions_eagerly(True) + + class Model(tf.Module): + @tf.function + def tf_function(self, x): + if pooling_type == "MAX": + op = tf.nn.max_pool(x, pool_shape, strides, padding) + elif pooling_type == "AVG": + op = tf.nn.avg_pool(x, pool_shape, strides, padding) + if
activation_function == "RELU": + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={"x": ifm_shape}, + dtype_dict={"x": dtype}, + ) + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethosu.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index b9a588d4aec0..fc03a98beb6b 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -313,7 +313,7 @@ def verify_linear(ext_func, conv2d_params): for test_case in test_cases: mod, conv_params = test_case[0](*test_case[1]) mod = ethosu.partition_for_ethosu(mod) - mod = legalize.LegalizeEthosUConv2D()(mod) + mod = legalize.LegalizeConv2D()(mod) verify_linear(mod["tvmgen_default_ethosu_main_0"], conv_params) @@ -349,7 +349,7 @@ def create_graph_single_unsupported_ifm_layout( with pytest.raises( tvm._ffi.base.TVMError, match="EthosUCodegenError: Unsupported Layout NCHW" ): - mod = legalize.LegalizeEthosUConv2D()(mod) + mod = legalize.LegalizeConv2D()(mod) @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)]) @@ -458,7 +458,102 @@ def verify(ext_func): mod = partition_ethosu_by_table(mod, depthwise_pattern_table) mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( - legalize.EthosuDepthwiseConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + legalize.DepthwiseConv2DRewriter(), mod["tvmgen_default_ethosu_main_0"] + ) + verify(mod["tvmgen_default_ethosu_main_0"]) + + +@pytest.mark.parametrize("pooling_type", ["MAX", "AVG"]) +@pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]]) +@pytest.mark.parametrize( + "pool_shape, strides, activation_function, padding", + [([1, 2], [1, 2], "NONE", "SAME"), ([2, 3], [2, 3], "RELU", "VALID")], +) +def test_tflite_pool2d_legalize( + ifm_shape, pooling_type, strides, pool_shape, activation_function, padding +): + dtype = "int8" + + def create_tflite_graph(): + class 
Model(tf.Module): + @tf.function + def tf_function(self, x): + if pooling_type == "MAX": + op = tf.nn.max_pool(x, pool_shape, strides, padding) + elif pooling_type == "AVG": + op = tf.nn.avg_pool(x, pool_shape, strides, padding) + if activation_function == "RELU": + op = tf.nn.relu(op) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + tf.TensorSpec(ifm_shape, dtype=tf.float32) + ) + + # Convert the model + def representative_dataset(): + for _ in range(100): + data = np.random.rand(*tuple(ifm_shape)) + yield [data.astype(np.float32)] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + return tflite_model + + def verify(ext_func): + ofm_shape = infra.compute_ofm_shape(ifm_shape, padding, pool_shape, strides) + op = ext_func.body + assert list(op.args[0].checked_type.shape) == ifm_shape + assert op.args[0].checked_type.dtype == dtype + assert list(op.checked_type.shape) == ofm_shape + assert op.checked_type.dtype == dtype + assert op.attrs.pooling_type == pooling_type + assert list(op.attrs.strides) == strides + assert list(op.attrs.padding) == infra.compute_padding_shape( + ifm_shape, ofm_shape, padding, pool_shape, strides + ) + assert list(op.attrs.pool_shape) == pool_shape + assert op.attrs.ofm_channels == ifm_shape[3] + if activation_function == "RELU": + assert str(op.attrs.activation) == "CLIP" + + if pooling_type == "MAX": + rewriter = legalize.MaxPoolingRewriter() + pattern_table = [ + ( + ethosu.MaxPool2DParams.composite_name, + ethosu.qnn_maxpool2d_pattern(), + lambda pat: ethosu.MaxPool2DParams(pat).is_valid(), + ), + ] + elif pooling_type == "AVG": + rewriter = legalize.AvgPoolingRewriter() + pattern_table = [ + ( + ethosu.AvgPool2DParams.composite_name, + ethosu.qnn_avgpool2d_pattern(), + lambda pat: ethosu.AvgPool2DParams(pat).is_valid(), + ), + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + mod, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={"x": ifm_shape}, + dtype_dict={"x": dtype}, + ) + mod = partition_ethosu_by_table(mod, pattern_table) + + mod["tvmgen_default_ethosu_main_0"] = dataflow_pattern.rewrite( + rewriter, mod["tvmgen_default_ethosu_main_0"] ) verify(mod["tvmgen_default_ethosu_main_0"]) diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py new file mode 100644 index 000000000000..099b9d60c428 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py @@ -0,0 +1,158 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
diff --git a/tests/python/contrib/test_ethosu/test_replace_pooling.py b/tests/python/contrib/test_ethosu/test_replace_pooling.py
new file mode 100644
index 000000000000..099b9d60c428
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_replace_pooling.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+
+pytest.importorskip("ethosu.vela")
+
+import tvm
+from tvm import relay
+from tvm.relay.testing import run_opt_pass
+from tvm.relay.backend.contrib.ethosu.tir import spec
+from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
+from .infra import make_ethosu_pooling, get_pooling_args
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, ofm_channels, ifm_layout, ofm_layout",
+    [
+        ((1, 5, 9, 3), 3, "NHWC", "NHWC"),
+        ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHCWB16"),
+        ((1, 8, 3, 9, 16), 40, "NHCWB16", "NHWC"),
+        ((1, 8, 9, 40), 40, "NHWC", "NHCWB16"),
+    ],
+)
+@pytest.mark.parametrize("pooling_type", ["AVG", "MAX"])
+@pytest.mark.parametrize("activation", ["NONE", "CLIP"])
+def test_pooling_single(
+    ifm_shape,
+    ofm_channels,
+    ifm_layout,
+    ofm_layout,
+    pooling_type,
+    activation,
+):
+    pool_shape = (3, 2)
+    strides = (1, 2)
+    padding = (1, 1, 1, 0)
+    ifm = relay.var("ifm", shape=ifm_shape, dtype="int8")
+    pooling = make_ethosu_pooling(
+        ifm,
+        pooling_type,
+        pool_shape,
+        ofm_channels,
+        strides,
+        padding,
+        activation,
+        ifm_layout,
+        ofm_layout,
+    )
+    func = relay.Function(relay.analysis.free_vars(pooling), pooling)
+    func = run_opt_pass(func, relay.transform.InferType())
+    mod, _ = lower_to_tir(func)
+    data = []
+
+    def _visit(stmt):
+        if isinstance(stmt, tvm.tir.Call):
+            data.append(get_pooling_args(stmt))
+
+    tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit)
+    if ifm_layout == "NHWC":
+        ifm_stride_c = 1
+        ifm_stride_w = ifm_shape[3]
+        ifm_stride_h = ifm_shape[2] * ifm_shape[3]
+        ofm_height = (ifm_shape[1] - pool_shape[0] + padding[0] + padding[0]) // strides[0] + 1
+        ofm_width = (ifm_shape[2] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1
+    else:
+        ifm_stride_w = 16
+        ifm_stride_c = 16 * ifm_shape[3]
+        ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3]
+        ofm_height = (ifm_shape[1] - pool_shape[0] + padding[0] + padding[0]) // strides[0] + 1
+        ofm_width = (ifm_shape[3] - pool_shape[1] + padding[1] + padding[1]) // strides[1] + 1
+
+    if ofm_layout == "NHWC":
+        ofm_stride_c = 1
+        ofm_stride_w = ofm_channels if ofm_width > 1 else 1
+        ofm_stride_h = ofm_channels * ofm_width if ofm_height > 1 else 1
+    else:
+        ofm_stride_w = 16
+        ofm_stride_c = 16 * ofm_width
+        ofm_stride_h = 16 * ofm_width * ((ofm_channels - 1) // 16 + 1)
+
+    serial_pooling = spec.SerialPooling(
+        ifm=spec.SerialFeatureMap(
+            data_type="int8",
+            height=ifm_shape[1],
+            width=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
+            channels=ofm_channels,
+            tile_height_0=ifm_shape[1],
+            tile_height_1=0,
+            tile_width_0=ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3],
+            tile_address_0=0,
+            tile_address_1=0,
+            tile_address_2=0,
+            tile_address_3=0,
+            scale=1.0,
+            zero_point=0,
+            layout=ifm_layout,
+            stride_h=ifm_stride_h,
+            stride_w=ifm_stride_w,
+            stride_c=ifm_stride_c,
+        ),
+        ofm=spec.SerialFeatureMap(
+            data_type="int8",
+            height=ofm_height,
+            width=ofm_width,
+            channels=ofm_channels,
+            tile_height_0=ofm_height,
+            tile_height_1=0,
+            tile_width_0=ofm_width,
+            tile_address_0=0,
+            tile_address_1=0,
+            tile_address_2=0,
+            tile_address_3=0,
+            scale=1.0,
+            zero_point=0,
+            layout=ofm_layout,
+            stride_h=ofm_stride_h,
+            stride_w=ofm_stride_w,
+            stride_c=ofm_stride_c,
+        ),
+        pooling_type=pooling_type,
+        pool_shape=spec.SerialKernel(
+            width=pool_shape[1],
+            height=pool_shape[0],
+            stride_w=strides[1],
+            stride_h=strides[0],
+            dilation_w=1,
+            dilation_h=1,
+        ),
+        padding=spec.SerialPadding(
+            top=padding[0], left=padding[1], bottom=padding[2], right=padding[3]
+        ),
+        activation=spec.SerialActivation(
+            op=activation,
+            clip_min=10 if activation == "CLIP" else 0,
+            clip_max=100 if activation == "CLIP" else 0,
+        ),
+        upscale="NONE",
+    )
+
+    assert data[0] == ["ethosu_pooling"] + list(serial_pooling)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
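The stride and shape expectations in test_pooling_single above (and the NHCWB16 cases in the type-inference tests later in this diff) follow from how the two feature-map layouts are addressed: NHWC is plain row-major, while NHCWB16 stores a tensor as [N, H, ceil(C/16), W, 16], grouping channels into bricks of 16 that are interleaved with the W axis. A compact restatement, distilled from the test expectations rather than from NPU documentation:

    import math

    def nhwc_to_nhcwb16(shape):
        # Brick layout: [N, H, W, C] -> [N, H, ceil(C/16), W, 16].
        n, h, w, c = shape
        return (n, h, math.ceil(c / 16), w, 16)

    def feature_map_strides(shape, layout):
        # Element strides for one step along H, one step along W and, for
        # NHCWB16, one brick of 16 channels; these mirror the ifm_stride_*
        # values asserted in test_pooling_single.
        if layout == "NHWC":
            n, h, w, c = shape
            return {"h": w * c, "w": c, "c": 1}
        else:  # "NHCWB16"
            n, h, cb, w, b = shape
            return {"h": b * cb * w, "w": b, "c": b * w}

    assert nhwc_to_nhcwb16((1, 56, 72, 55)) == (1, 56, 4, 72, 16)
    assert feature_map_strides((1, 8, 3, 9, 16), "NHCWB16") == {"h": 432, "w": 16, "c": 144}

The same arithmetic pins down the translator testcase further below: a 1x5x9x3 NHWC IFM with a 3x2 pool, (1, 2) strides and (top, left, bottom, right) padding of (1, 1, 1, 0) produces (5 + 2 - 3) // 1 + 1 = 5 rows and (9 + 1 - 2) // 2 + 1 = 5 columns, matching the 1x5x5x3 OFM buffer hard-coded in SingleEthosuPooling.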
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index 8240b392a1cf..f4b83a4577cc 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -634,7 +634,7 @@ def populate_ethosu_copy_calls(stmt):
     for test_case in test_cases:
         ethosu_copy_calls = extract_ethosu_copy_extern_calls(test_case["tir_module"])
         for idx, ethosu_copy_call in enumerate(ethosu_copy_calls):
-            npu_dma_op = tir_to_cs_translator.translate_ethosu_tir_extern_call(ethosu_copy_call)
+            npu_dma_op = tir_to_cs_translator.translate_ethosu_tir_call_extern(ethosu_copy_call)
             assert npu_dma_op.src.address.buffer_var.name == test_case["ref"][idx]["src"]
             assert npu_dma_op.dest.address.buffer_var.name == test_case["ref"][idx]["dest"]
             assert npu_dma_op.src.length == test_case["ref"][idx]["length"]
@@ -675,7 +675,7 @@ def test_assign_addresses():
         },
     ]
 
-    def extract_extern_calls(mod):
+    def extract_call_extern_list(mod):
         """This function will obtain all ethosu_conv2d calls from a NPU TIR module
 
         Parameters
@@ -825,10 +825,10 @@ def check_buffer(address, region, length, buffer_var):
         buffer_info = tir_to_cs_translator.extract_buffer_info(
             test_case["tir_module"], test_case["param_dict"]
         )
-        extern_calls = extract_extern_calls(test_case["tir_module"])
+        extern_calls = extract_call_extern_list(test_case["tir_module"])
         _npu_ops = list()
         for extern_call in extern_calls:
-            _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_extern_call(extern_call))
+            _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_call_extern(extern_call))
         npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops)
         _npu_ops, constant_tensor, scratch_size = tir_to_cs_translator.assign_addresses(
             buffer_info, _npu_ops
@@ -842,5 +842,76 @@ def check_buffer(address, region, length, buffer_var):
         assert np.prod(constant_tensor_read_mask) == 1
 
 
+# fmt: off
+"""An ethosu_pooling TIR testcase for the translator"""
+@tvm.script.ir_module
+class SingleEthosuPooling:
+    @T.prim_func
+    def main(placeholder: T.handle, placeholder_3: T.handle, ethosu_write: T.handle) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        placeholder_4 = T.match_buffer(placeholder, [1, 5, 9, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        ethosu_write_2 = T.match_buffer(ethosu_write, [1, 5, 5, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1)
+        # body
+        T.evaluate(T.call_extern("ethosu_pooling", "int8", 5, 9, 3, 5, 0, 9, T.load("int8", placeholder_4.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 27, 3, 1, "int8", 5, 5, 3, 5, 0, 5, T.load("int8", ethosu_write_2.data, 0), 0, 0, 0, T.float32(1.0), 0, "NHWC", 15, 3, 1, "AVG", 2, 3, 2, 1, 1, 1, 1, 1, 1, 0, "CLIP", 10, 100, "NONE", dtype="int8"))
+    __tvm_meta__ = None
+# fmt: on
+
+
+def test_translate_ethosu_pooling():
+    def extract_ethosu_pooling_extern_call(mod):
+        # There should only be a single function
+        assert len(mod.functions.items()) == 1
+        primfunc = mod.functions.items()[0][1]
+
+        ethosu_pooling_calls = list()
+
+        def populate_ethosu_pooling_calls(stmt):
+            if (
+                isinstance(stmt, tvm.tir.Call)
+                and stmt.op.name == "tir.call_extern"
+                and stmt.args[0] == "ethosu_pooling"
+            ):
+                ethosu_pooling_calls.append(stmt)
+
+        stmt_functor.post_order_visit(primfunc.body, populate_ethosu_pooling_calls)
+        return ethosu_pooling_calls[0]
+
+    pooling_call = extract_ethosu_pooling_extern_call(SingleEthosuPooling)
+    npu_op = tir_to_cs_translator.translate_ethosu_pooling(pooling_call)
+
+    assert npu_op.ifm.data_type == vapi.NpuDataType.INT8
+    assert npu_op.ifm.shape == vapi.NpuShape3D(5, 9, 3)
+    assert npu_op.ifm.tiles.height_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_0
+    assert npu_op.ifm.tiles.height_1 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).height_1
+    assert npu_op.ifm.tiles.width_0 == vapi.NpuTileBox(5, 0, 9, [0, 0, 0, 0]).width_0
+    assert npu_op.ifm.quantization == vapi.NpuQuantization(1.0, 0)
+    assert npu_op.ifm.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ifm.strides == vapi.NpuShape3D(27, 3, 1)
+    # Compare OFM
+    assert npu_op.ofm.data_type == vapi.NpuDataType.INT8
+    assert npu_op.ofm.shape == vapi.NpuShape3D(5, 5, 3)
+    assert npu_op.ofm.tiles.height_0 == vapi.NpuTileBox(5, 0, 5, [0, 0, 0, 0]).height_0
+    assert npu_op.ofm.tiles.height_1 == vapi.NpuTileBox(5, 0, 5, [0, 0, 0, 0]).height_1
+    assert npu_op.ofm.tiles.width_0 == vapi.NpuTileBox(5, 0, 5, [0, 0, 0, 0]).width_0
+    assert npu_op.ofm.quantization == vapi.NpuQuantization(1.0, 0)
+    assert npu_op.ofm.layout == vapi.NpuLayout.NHWC
+    assert npu_op.ofm.strides == vapi.NpuShape3D(15, 3, 1)
+    # Compare pooling_type
+    assert npu_op.sub_op_type == vapi.NpuPoolingOp.AVERAGE
+    # Compare kernel and padding
+    assert (
+        npu_op.kernel.__dict__
+        == vapi.NpuKernel(w=2, h=3, stride_x=2, stride_y=1, dilation_x=1, dilation_y=1).__dict__
+    )
+    assert npu_op.padding == vapi.NpuPadding(top=1, left=1, bottom=1, right=0)
+    # Compare activation
+    assert npu_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU
+    assert npu_op.activation.min == 10
+    assert npu_op.activation.max == 100
+    # Compare ifm upscaling
+    assert npu_op.ifm_upscale == vapi.NpuResamplingMode.NONE
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/tests/python/contrib/test_ethosu/test_type_inference.py b/tests/python/contrib/test_ethosu/test_type_inference.py
index 47fddad773b2..9b041392c732 100644
--- a/tests/python/contrib/test_ethosu/test_type_inference.py
+++ b/tests/python/contrib/test_ethosu/test_type_inference.py
@@ -18,10 +18,11 @@
 
 pytest.importorskip("ethosu.vela")
 
-from tvm import relay
+from tvm import relay, TVMError
 from tvm.relay.testing import run_opt_pass
 from .infra import make_ethosu_conv2d
 from .infra import make_ethosu_depthwise_conv2d
+from .infra import make_ethosu_pooling
 
 
 @pytest.mark.parametrize(
@@ -54,9 +55,9 @@ def test_ethosu_conv2d_type_inference(
         ifm_layout=ifm_layout,
         ofm_layout=ofm_layout,
     )
-    f = relay.Function([ifm], conv2d)
-    f = run_opt_pass(f, relay.transform.InferType())
-    assert tuple(f.body.checked_type.shape) == ofm_shape
+    func = relay.Function([ifm], conv2d)
+    func = run_opt_pass(func, relay.transform.InferType())
+    assert tuple(func.body.checked_type.shape) == ofm_shape
 
 
 @pytest.mark.parametrize(
@@ -87,9 +88,86 @@ def test_ethosu_depthwise_conv2d_type_inference(
         ifm_layout=ifm_layout,
         ofm_layout=ofm_layout,
     )
-    f = relay.Function([ifm], depthwise_conv2d)
-    f = run_opt_pass(f, relay.transform.InferType())
-    assert tuple(f.body.checked_type.shape) == ofm_shape
+    func = relay.Function([ifm], depthwise_conv2d)
+    func = run_opt_pass(func, relay.transform.InferType())
+    assert tuple(func.body.checked_type.shape) == ofm_shape
+
+
+@pytest.mark.parametrize(
+    "ifm_shape, ifm_layout", [((1, 56, 72, 55), "NHWC"), ((1, 56, 4, 72, 16), "NHCWB16")]
+)
+@pytest.mark.parametrize(
+    "ofm_shape, ofm_layout", [((1, 56, 38, 55), "NHWC"), ((1, 56, 4, 38, 16), "NHCWB16")]
+)
+def test_ethosu_pooling_type_inference(
+    ifm_shape,
+    ifm_layout,
+    ofm_shape,
+    ofm_layout,
+):
+    dtype = "int8"
+    ifm = relay.var("ifm", shape=ifm_shape, dtype=dtype)
+    pooling_type = "AVG"
+    pool_shape = (3, 2)
+    ofm_channels = 55
+    strides = (1, 2)
+    padding = (0, 1, 2, 3)
+    pooling = make_ethosu_pooling(
+        ifm,
+        pooling_type,
+        pool_shape,
+        ofm_channels,
+        strides,
+        padding,
+        ifm_layout=ifm_layout,
+        ofm_layout=ofm_layout,
+    )
+    func = relay.Function([ifm], pooling)
+    func = run_opt_pass(func, relay.transform.InferType())
+    assert tuple(func.body.checked_type.shape) == ofm_shape
+    assert func.body.checked_type.dtype == dtype
+
+
+def test_ethosu_pooling_invalid_pooling_type():
+    invalid_pooling_type = "A"
+    dtype = "int8"
+    ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype=dtype)
+    pool_shape = (3, 2)
+    ofm_channels = 55
+    strides = (1, 2)
+    padding = (0, 1, 2, 3)
+    pooling = make_ethosu_pooling(
+        ifm,
+        invalid_pooling_type,
+        pool_shape,
+        ofm_channels,
+        strides,
+        padding,
+    )
+    func = relay.Function([ifm], pooling)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
+
+
+def test_ethosu_pooling_invalid_dtype():
+    invalid_dtype = "int32"
+    ifm = relay.var("ifm", shape=[1, 56, 72, 55], dtype=invalid_dtype)
+    pooling_type = "MAX"
+    pool_shape = (3, 2)
+    ofm_channels = 55
+    strides = (1, 2)
+    padding = (0, 1, 2, 3)
+    pooling = make_ethosu_pooling(
+        ifm,
+        pooling_type,
+        pool_shape,
+        ofm_channels,
+        strides,
+        padding,
+    )
+    func = relay.Function([ifm], pooling)
+    with pytest.raises(TVMError):
+        run_opt_pass(func, relay.transform.InferType())
 
 
 if __name__ == "__main__":
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 147c420cc902..2ef84d7f1a6f 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -497,8 +497,8 @@ def test_compile_tflite_module_with_external_codegen_ethosu(
     # The number of c_source_files depends on the number of fused subgraphs that
     # get offloaded to the NPU, e.g. conv2d->depthwise_conv2d->conv2d gets offloaded
     # as a single subgraph if both of these operators are supported by the NPU.
-    # Currently there are two source files for CPU execution and two offload graphs
-    assert len(c_source_files) == 4
+    # Currently there are two source files for CPU execution and one offload graph
+    assert len(c_source_files) == 3
 
 
 @mock.patch("tvm.relay.build")