diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py
index d63b69b2e259..d41d8854d7d1 100644
--- a/python/tvm/topi/hexagon/qnn/__init__.py
+++ b/python/tvm/topi/hexagon/qnn/__init__.py
@@ -28,3 +28,4 @@ from .nn import *
 from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule
 from .adaptive_avg_pool1d import *
+from .global_avg_pool2d import *
diff --git a/python/tvm/topi/hexagon/qnn/global_avg_pool2d.py b/python/tvm/topi/hexagon/qnn/global_avg_pool2d.py
new file mode 100755
index 000000000000..1c171be8976e
--- /dev/null
+++ b/python/tvm/topi/hexagon/qnn/global_avg_pool2d.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Assumptions:
+1) The input is in NCHW layout. Squeezenet is the only model that calls
+   nn.global_avg_pool2d, and the only layout it uses is 'NCHW'.
+2) Both the input and output dtypes are uint8, and the quantization
+   parameters are provided to the op.
+3) The input is always assumed to be a multiple of the fixed chunk 32c8h8w.
+"""
+
+from tvm import te
+from tvm import tir
+from ..utils import get_layout_transform_fn, get_fixed_point_value, saturate
+
+
+def global_avg_pool2d_u8(
+    data: te.Tensor,
+    odtype: str,
+    input_zero_point: int,
+    input_scale: float,
+    output_zero_point: int,
+    output_scale: float,
+):
+    """Compute for quantized global_avg_pool2d."""
+    input_b, input_c, input_h, input_w = data.shape
+    oshape = (input_b, input_c) + (1, 1)
+
+    if input_h * input_w < 256:
+        bits = "16"
+    else:
+        bits = "32"
+
+    if odtype == "uint8":
+        temp_dtype = "uint" + bits
+    elif odtype == "int8":
+        temp_dtype = "int" + bits
+    else:
+        raise RuntimeError(f"Unsupported output dtype '{odtype}'")
+
+    pool_area = input_h * input_w
+    rh_r = te.reduce_axis((0, input_h), name="rh_r")
+    rw_r = te.reduce_axis((0, input_w), name="rw_r")
+
+    scale_with_area = input_scale / (output_scale * int(pool_area))
+    scale_fixed_point, rsh = get_fixed_point_value(scale_with_area, "int16")
+    corr = (output_zero_point << rsh) - input_zero_point * pool_area * scale_fixed_point
+
+    sum_compute = te.compute(
+        oshape,
+        lambda n, c, h, w: te.sum(
+            data[n, c, h + rh_r, w + rw_r].astype(temp_dtype), axis=[rh_r, rw_r]
+        ),
+        name="sum",
+    )
+
+    avg_compute = te.compute(
+        oshape,
+        lambda n, c, h, w: saturate(
+            ((sum_compute[n, c, h, w] * scale_fixed_point) + corr) >> rsh, odtype
+        ).astype(odtype),
+        name="global_avg_pool2d",
+    )
+
+    return avg_compute
+
+
+def stir_global_avg_pool2d_u8_schedule(outs: te.Tensor, ins: te.Tensor, input_layout: str):
+    """Schedule for quantized global_avg_pool2d."""
+    func = te.create_prim_func([ins, outs])
+    s = tir.Schedule(func)
+
+    sum_block = s.get_block("sum")
+
+    # The input is a multiple of the fixed chunk, but the output is NxCx1x1.
+    # Hence, transform_layout is applied only to the input.
+    input_transformed_layout = get_layout_transform_fn(input_layout)
+    s.transform_layout(sum_block, buffer=("read", 0), index_map=input_transformed_layout)
+
+    return s
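The quantized kernel above avoids a per-element division: the combined scale input_scale / (output_scale * pool_area) is converted once into a fixed-point multiplier plus a right shift, and the two zero points are folded into the single correction term corr. A minimal NumPy sketch of the same arithmetic, independent of TVM; to_fixed_point here is a hypothetical stand-in for get_fixed_point_value, and the quantization parameters are made-up example values:

import numpy as np

def to_fixed_point(flp):
    """Return (scale_fp, rsh) with flp ~= scale_fp * 2 ** -rsh, scale_fp in int16 range."""
    rsh = 0
    while round(flp * 2 ** (rsh + 1)) < 2 ** 15 - 1 and rsh < 30:
        rsh += 1
    return int(round(flp * 2 ** rsh)), rsh

def requantized_avg(x_q, in_zp, in_scale, out_zp, out_scale):
    area = x_q.shape[2] * x_q.shape[3]
    scale_fp, rsh = to_fixed_point(in_scale / (out_scale * area))
    corr = (out_zp << rsh) - in_zp * area * scale_fp
    acc = x_q.astype(np.int64).sum(axis=(2, 3), keepdims=True)  # the "sum" stage
    val = (acc * scale_fp + corr) >> rsh                        # multiply, correct, shift
    return np.clip(val, 0, 255).astype(np.uint8)                # saturate to uint8

x_q = np.random.randint(0, 256, size=(1, 32, 8, 8), dtype=np.uint8)
in_scale, in_zp, out_scale, out_zp = 0.02, 128, 0.005, 64
x_f = (x_q.astype(np.float32) - in_zp) * in_scale
ref = np.clip(np.round(x_f.mean(axis=(2, 3), keepdims=True) / out_scale + out_zp), 0, 255)
got = requantized_avg(x_q, in_zp, in_scale, out_zp, out_scale)
assert np.abs(got.astype(np.int32) - ref).max() <= 1  # floor-vs-round can differ by one step

This also shows why the kernel's input_h * input_w < 256 check is enough for a 16-bit accumulator: at most 255 uint8 values are summed, and 255 * 255 = 65025 still fits in uint16.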
+""" + +from tvm import te +from tvm import tir +from ..utils import get_layout_transform_fn, get_fixed_point_value, saturate + + +def global_avg_pool2d_u8( + data: te.Tensor, + odtype: str, + input_zero_point: int, + input_scale: float, + output_zero_point: int, + output_scale: float, +): + """global_avg_pool2d""" + input_b, input_c, input_h, input_w = data.shape + oshape = (input_b, input_c) + (1, 1) + + if input_h * input_w < 256: + bits = "16" + else: + bits = "32" + + if odtype == "uint8": + temp_dtype = "uint" + bits + elif odtype == "int8": + temp_dtype = "int" + bits + else: + raise RuntimeError(f"Unsupported output dtype, {odtype}'") + + pool_area = input_h * input_w + rh_r = te.reduce_axis((0, input_h), name="rh_r") + rw_r = te.reduce_axis((0, input_w), name="rw_r") + + scale_with_area = input_scale / (output_scale * int(pool_area)) + scale_fixed_point, rsh = get_fixed_point_value(scale_with_area, "int16") + corr = (output_zero_point << rsh) - input_zero_point * pool_area * scale_fixed_point + + sum_compute = te.compute( + oshape, + lambda n, c, h, w: te.sum( + data[n, c, h + rh_r, w + rw_r].astype(temp_dtype), axis=[rh_r, rw_r] + ), + name="sum", + ) + + avg_compute = te.compute( + oshape, + lambda n, c, h, w: saturate( + ((sum_compute[n, c, h, w] * scale_fixed_point) + corr) >> rsh, odtype + ).astype(odtype), + name="global_avg_pool2d", + ) + + return avg_compute + + +def stir_global_avg_pool2d_u8_schedule(outs: te.Tensor, ins: te.Tensor, input_layout: str): + """Schedule""" + func = te.create_prim_func([ins, outs]) + s = tir.Schedule(func) + + sum_block = s.get_block("sum") + + # Input is multiple of fixed chunk but output is NxCx1x1 + # Hence transform_layout is only applied on input + input_transformed_layout = get_layout_transform_fn(input_layout) + s.transform_layout(sum_block, buffer=("read", 0), index_map=input_transformed_layout) + + return s diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py index 5f86e706af50..6b17b64489a9 100644 --- a/python/tvm/topi/hexagon/slice_ops/__init__.py +++ b/python/tvm/topi/hexagon/slice_ops/__init__.py @@ -36,3 +36,4 @@ from .tanh import tanh_te_compute, tanhf16_schedule from .dwconv2d import * from .depth_to_space import d2s_compute, d2s_schedule +from .global_avg_pool2d import * diff --git a/python/tvm/topi/hexagon/slice_ops/global_avg_pool2d.py b/python/tvm/topi/hexagon/slice_ops/global_avg_pool2d.py new file mode 100755 index 000000000000..30222c11bb54 --- /dev/null +++ b/python/tvm/topi/hexagon/slice_ops/global_avg_pool2d.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Assumptions: +1) The input is in NCHW layout. Squeezenet is the only model that calls + nn.global_avg_pool2d and the only layout it uses is 'NCHW'. 
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
index 5aeed9aa4fde..78ed21e8a13b 100644
--- a/python/tvm/topi/hexagon/utils.py
+++ b/python/tvm/topi/hexagon/utils.py
@@ -136,6 +136,14 @@ def ncw_32c64w_2d(n, c, w):
     return [n, c // 32, w // 64, te.AXIS_SEPARATOR, c % 32, w % 64]
 
 
+def nchw_32c8h8w_2d(n, c, h, w):
+    return [n, c // 32, h // 8, w // 8, te.AXIS_SEPARATOR, c % 32, h % 8, w % 8]
+
+
+def nchw_32c8h4w_2d(n, c, h, w):
+    return [n, c // 32, h // 8, w // 4, te.AXIS_SEPARATOR, c % 32, h % 8, w % 4]
+
+
 def get_layout_transform_fn(layout):
     """Return index map function as per the layout string"""
     if layout == "nhwc-8h2w32c2w-2d":
@@ -180,6 +188,10 @@ def get_layout_transform_fn(layout):
         return ohwi32o_1d
     if layout == "ncw-32c64w-2d":
         return ncw_32c64w_2d
+    if layout == "nchw-32c8h8w-2d":
+        return nchw_32c8h8w_2d
+    if layout == "nchw-32c8h4w-2d":
+        return nchw_32c8h4w_2d
     raise RuntimeError(f"Unexpected layout '{layout}'")
 
 
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index fcb811fce742..e81c24694ef9 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -277,6 +277,19 @@ def transform_numpy(arr_np, current_layout: str, new_layout: str):
 
         raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
 
+    if current_layout == "nchw":
+        if new_layout in ["nchw-32c8h8w-2d", "nchw-32c8h8w-1d"]:
+            n, c, h, w = arr_np.shape
+            return arr_np.reshape([n, c // 32, 32, h // 8, 8, w // 8, 8]).transpose(
+                0, 1, 3, 5, 2, 4, 6
+            )
+        if new_layout in ["nchw-32c8h4w-2d", "nchw-32c8h4w-1d"]:
+            n, c, h, w = arr_np.shape
+            return arr_np.reshape([n, c // 32, 32, h // 8, 8, w // 4, 4]).transpose(
+                0, 1, 3, 5, 2, 4, 6
+            )
+        raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
+
     raise RuntimeError(f"Unexpected current_layout '{current_layout}'")
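The index maps and the NumPy reference transform describe the same blocked layout, with te.AXIS_SEPARATOR marking where the flat chunk begins; both 32c8h8w uint8 and 32c8h4w float16 fill a 2048-byte chunk, which is presumably why the two dtypes use different chunk shapes. A quick standalone check that the reshape/transpose in transform_numpy places element (n, c, h, w) exactly where the nchw_32c8h8w_2d index map prescribes:

import numpy as np

arr = np.arange(1 * 64 * 16 * 16).reshape(1, 64, 16, 16)
n, c, h, w = arr.shape
crouton = arr.reshape([n, c // 32, 32, h // 8, 8, w // 8, 8]).transpose(0, 1, 3, 5, 2, 4, 6)

for nn, cc, hh, ww in [(0, 0, 0, 0), (0, 37, 11, 9), (0, 63, 15, 15)]:
    # Matches [n, c // 32, h // 8, w // 8, te.AXIS_SEPARATOR, c % 32, h % 8, w % 8]
    assert arr[nn, cc, hh, ww] == crouton[nn, cc // 32, hh // 8, ww // 8, cc % 32, hh % 8, ww % 8]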
diff --git a/tests/python/contrib/test_hexagon/topi/slice_op/test_global_avg_pool2d.py b/tests/python/contrib/test_hexagon/topi/slice_op/test_global_avg_pool2d.py
new file mode 100755
index 000000000000..3f7e999c7bca
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/slice_op/test_global_avg_pool2d.py
@@ -0,0 +1,167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test code for float16 and uint8 global_avg_pool2d."""
+
+import numpy as np
+
+import tvm
+from tvm import te
+from tvm.topi.testing import adaptive_pool
+import tvm.topi.hexagon.qnn as qn
+import tvm.topi.hexagon.slice_ops as sl
+from tvm.contrib.hexagon import allocate_hexagon_array
+from ...infrastructure import transform_numpy, quantize_np, get_hexagon_target
+
+
+SCALE_M_VAL = None
+ZERO_POINT_M_VAL = None
+SCALE_VAL = None
+ZERO_POINT_VAL = None
+
+
+class TestGlobalPool2D:
+    (input_shape,) = tvm.testing.parameters(
+        ([1, 32, 8, 8],),
+        ([1, 1056, 16, 16],),
+    )
+
+    # The fixed chunk layout is nchw-32c8h8w-2d for uint8 and nchw-32c8h4w-2d
+    # for float16; it might be changed later as an optimization.
+    # Since the output shape is NxCx1x1, which is not a multiple of the
+    # fixed chunk, output_layout is plain NCHW.
+    input_layout, output_layout, pool_type, layout, dtype = tvm.testing.parameters(
+        ("nchw-32c8h8w-2d", "nchw", "avg", "NCHW", "uint8"),
+        ("nchw-32c8h4w-2d", "nchw", "avg", "NCHW", "float16"),
+    )
+
+    @tvm.testing.fixture
+    def expected_output_np(
+        self,
+        input_np,
+        pool_type,
+        layout,
+    ):
+        """Generate the expected output."""
+        ref_np = adaptive_pool(
+            input_np,
+            (1, 1),
+            pool_type,
+            layout,
+        )
+        return ref_np
+
+    @tvm.testing.fixture
+    def input_np(self, input_shape, dtype):
+        if dtype in ("uint8", "int8"):
+            dtype = "float32"
+        return np.random.random(input_shape).astype(dtype)
+
+    @tvm.testing.fixture
+    def quantize_input_np(self, input_np, dtype):
+        if dtype in ("uint8", "int8"):
+            global ZERO_POINT_VAL, SCALE_VAL
+            input_np_quantized, SCALE_VAL, ZERO_POINT_VAL = quantize_np(input_np, dtype)
+            return input_np_quantized
+
+    @tvm.testing.fixture
+    def transformed_input_np(self, input_np, quantize_input_np, input_layout, layout, dtype):
+        if dtype == "float16":
+            return transform_numpy(input_np, layout.lower(), input_layout)
+        if dtype in ("uint8", "int8"):
+            return transform_numpy(quantize_input_np, layout.lower(), input_layout)
+
+        raise RuntimeError(f"Unsupported data type '{dtype}'")
+
+    @tvm.testing.fixture
+    def quantize_expected_output_np(self, expected_output_np, dtype):
+        if dtype in ("uint8", "int8"):
+            global ZERO_POINT_M_VAL, SCALE_M_VAL
+            out_ref_quantized, SCALE_M_VAL, ZERO_POINT_M_VAL = quantize_np(
+                expected_output_np, dtype
+            )
+
+            # Since output_layout is nchw, no transformation is needed.
+            return out_ref_quantized
+
+    @tvm.testing.requires_hexagon
+    def test_global_pool2d(
+        self,
+        dtype,
+        input_shape,
+        input_layout,
+        transformed_input_np,
+        expected_output_np,
+        quantize_expected_output_np,
+        hexagon_session,
+    ):
+        a_tensor = te.placeholder(input_shape, name="a_tensor", dtype=dtype)
+
+        if dtype == "float16":
+            m_tensor = sl.global_avg_pool2d(a_tensor)
+            tir_schedule = sl.stir_global_avg_pool2d_schedule(m_tensor, a_tensor, input_layout)
+        elif dtype in ["uint8", "int8"]:
+            m_tensor = qn.global_avg_pool2d_u8(
+                a_tensor,
+                dtype,
+                ZERO_POINT_VAL,
+                SCALE_VAL,
+                ZERO_POINT_M_VAL,
+                SCALE_M_VAL,
+            )
+            tir_schedule = qn.stir_global_avg_pool2d_u8_schedule(m_tensor, a_tensor, input_layout)
+
+        sch = tir_schedule.mod
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(
+                sch,
+                [a_tensor, m_tensor],
+                get_hexagon_target("v69"),
+                name="global_pool2d",
+            )
+
+        input_axis_separator = [4]
+
+        a_data_nd = allocate_hexagon_array(
+            hexagon_session.device,
+            data=transformed_input_np,
+            dtype=dtype,
+            axis_separators=input_axis_separator,
+            mem_scope="global.vtcm",
+        )
+
+        m_data_nd = allocate_hexagon_array(
+            hexagon_session.device,
+            expected_output_np.shape,
+            dtype=dtype,
+        )
+
+        mod = hexagon_session.load_module(func)
+        mod(a_data_nd, m_data_nd)
+
+        # Convert nd to np
+        m_data_np = m_data_nd.numpy()
+
+        if dtype == "float16":
+            np.testing.assert_allclose(expected_output_np, m_data_np, rtol=1e-3, atol=1e-3)
+        elif dtype in ["int8", "uint8"]:
+            np.testing.assert_allclose(quantize_expected_output_np, m_data_np, atol=1)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
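The atol=1 tolerance on the quantized comparison matches the floor-versus-round mismatch demonstrated earlier: the fixed-point kernel may land one quantization step away from the float reference. The test obtains its (scale, zero_point) pairs from the quantize_np helper in the test infrastructure; a minimal emulation of what such a helper has to return (the real implementation may pick the range differently):

import numpy as np

def quantize_np_sketch(arr_np, dtype="uint8"):
    """Hypothetical stand-in for infrastructure.quantize_np: affine float -> uint8/int8."""
    qmin, qmax = (0, 255) if dtype == "uint8" else (-128, 127)
    fmin, fmax = min(arr_np.min(), 0.0), max(arr_np.max(), 0.0)
    scale = (fmax - fmin) / (qmax - qmin)
    zero_point = int(round(qmin - fmin / scale))
    quant = np.clip(np.round(arr_np / scale) + zero_point, qmin, qmax).astype(dtype)
    return quant, scale, zero_point

data = np.random.random((1, 32, 8, 8)).astype("float32")
q, scale, zp = quantize_np_sketch(data)
dequant = (q.astype("float32") - zp) * scale
assert np.abs(dequant - data).max() <= scale  # off by at most one quantization step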