diff --git a/python/tvm/topi/hexagon/qnn/__init__.py b/python/tvm/topi/hexagon/qnn/__init__.py
index 25d1e6d1854d..ef9c025ba5b2 100644
--- a/python/tvm/topi/hexagon/qnn/__init__.py
+++ b/python/tvm/topi/hexagon/qnn/__init__.py
@@ -18,7 +18,7 @@
 """ Computes and schedules for Hexagon quantized ops """
 
 from .avg_pool2d import qnn_avg_pool2d_compute, qnn_avg_pool2d_schedule
-
+from .qadd_qsub_qmul import *
 from .dequantize import (
     dequantize_compute,
     dequantize_schedule,
diff --git a/python/tvm/topi/hexagon/qnn/qadd_qsub_qmul.py b/python/tvm/topi/hexagon/qnn/qadd_qsub_qmul.py
new file mode 100755
index 000000000000..043ad313bdef
--- /dev/null
+++ b/python/tvm/topi/hexagon/qnn/qadd_qsub_qmul.py
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+
+"""Compute and schedule for quantized add, multiply, and subtract ops
+
+Please note the following assumption made by the implementation:
+
+1) The input shapes will be multiples of the crouton layout except for the axis that needs broadcasting."""
+
+from tvm import te
+from tvm import tir
+from ..utils import get_layout_transform_fn, get_fixed_point_value
+
+
+def broadcast_axis(tensor_A, tensor_B):
+    """Determine which axes of the two inputs need broadcasting"""
+    A_broadcast = []
+    B_broadcast = []
+
+    for i in range(len(tensor_A.shape)):
+        if tensor_A.shape[i] == tensor_B.shape[i]:
+            A_broadcast.append(1)
+            B_broadcast.append(1)
+        elif tensor_A.shape[i] == 1:
+            A_broadcast.append(0)
+            B_broadcast.append(1)
+        elif tensor_B.shape[i] == 1:
+            A_broadcast.append(1)
+            B_broadcast.append(0)
+    return A_broadcast, B_broadcast
+
+
+def saturate(x: te.Tensor, dtype: str):
+    """Saturate value for the specified data type"""
+    return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype)))
+
+
+def get_int_scale(
+    scale_A: float,
+    scale_B: float,
+    scale_M: float,
+    zero_point_A: int,
+    zero_point_B: int,
+    zero_point_M: int,
+    op: str,
+):
+    """
+    Get the fixed-point value and exp_scale_factor from topi.hexagon.utils.get_fixed_point_value.
+    Also, depending on the op, this function uses the exp_scale_factor (log2 of the scale factor)
+    to adjust the output's zero point.
+    """
+
+    C_recip = 1 / scale_M
+
+    if op == "qmul":
+        scale = scale_A * scale_B * C_recip
+        scale_fixed_point, rsh = get_fixed_point_value(scale, "int16")
+
+        # We need to adjust the output's zero point value since the compute for the op is
+        # multiplied by a scaling factor.
+        # The scaling factor is 2^x, where x is the exp_scale_factor assigned to rsh here.
+        # Since the floating-point scale value is multiplied by 2^rsh when it is converted
+        # into a fixed-point number, we left shift zero_point_M by rsh in the compute so that
+        # it stays on the same scale.
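+        # Illustrative derivation for qmul: the requantization equation is
+        #   Q_M = zero_point_M
+        #         + (scale_A * scale_B / scale_M) * (Q_A - zero_point_A) * (Q_B - zero_point_B)
+        # Approximating (scale_A * scale_B / scale_M) by scale_fixed_point / 2^rsh turns this into
+        #   Q_M = (scale_fixed_point * (Q_A - zero_point_A) * (Q_B - zero_point_B)
+        #          + (zero_point_M << rsh)) >> rsh
+        # which is the integer-only form used in qmultiply_broadcast_compute below
+        # (saturation to the output dtype is applied on top of this), with corr = zero_point_M << rsh.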
+
+        corr = zero_point_M << rsh
+
+        return scale_fixed_point, rsh, corr
+
+    a_scale_f = scale_A * C_recip
+    b_scale_f = scale_B * C_recip
+    scale_fixed_point_a, rsh_a = get_fixed_point_value(a_scale_f, "int16")
+    scale_fixed_point_b, rsh_b = get_fixed_point_value(b_scale_f, "int16")
+
+    # Here we have two exp_scale_factors, rsh_a and rsh_b.
+    # To avoid complexity, we want to use a common exp_scale_factor and
+    # we pick the lower of the two.
+
+    # Since one of scale_fixed_point_a or scale_fixed_point_b has already been multiplied
+    # by 2^max(rsh_a, rsh_b) in topi.hexagon.utils.get_fixed_point_value,
+    # we undo that by right shifting that scale_fixed_point value
+    # by the difference between rsh_a and rsh_b.
+
+    # This results in a common exp_scale_factor for both scale_fixed_point_a
+    # and scale_fixed_point_b.
+
+    # We also set rsh here, which is used to adjust zero_point_M and to compute the corr value;
+    # the corr term comes from the original equation of the op's compute.
+
+    if rsh_a > rsh_b:
+        scale_fixed_point_a = scale_fixed_point_a >> (rsh_a - rsh_b)
+        rsh = rsh_b
+    else:
+        scale_fixed_point_b = scale_fixed_point_b >> (rsh_b - rsh_a)
+        rsh = rsh_a
+
+    if op == "qadd":
+        corr = (zero_point_M << rsh) - (
+            zero_point_A * scale_fixed_point_a + zero_point_B * scale_fixed_point_b
+        )
+    else:
+        corr = (zero_point_M << rsh) - (
+            zero_point_A * scale_fixed_point_a - zero_point_B * scale_fixed_point_b
+        )
+
+    return scale_fixed_point_a, scale_fixed_point_b, rsh, corr
+
+
+def qadd_broadcast_compute(
+    tensor_A: te.Tensor,
+    tensor_B: te.Tensor,
+    output_shape: list,
+    zero_point_A: int,
+    scale_A: float,
+    zero_point_B: int,
+    scale_B: float,
+    zero_point_M: int,
+    scale_M: float,
+    dtype: str,
+):
+    """Compute quantized add with broadcasting"""
+    A_broadcast, B_broadcast = broadcast_axis(tensor_A, tensor_B)
+    n_a, h_a, w_a, c_a = A_broadcast
+    n_b, h_b, w_b, c_b = B_broadcast
+
+    scale_a, scale_b, rsh, corr = get_int_scale(
+        scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qadd"
+    )
+
+    return te.compute(
+        output_shape,
+        lambda n, h, w, c: saturate(
+            (
+                (
+                    (tensor_A[n * n_a, h * h_a, w * w_a, c * c_a] * scale_a)
+                    + (tensor_B[n * n_b, h * h_b, w * w_b, c * c_b] * scale_b)
+                    + corr
+                )
+                >> rsh
+            ),
+            dtype,
+        ).astype(dtype),
+    )
+
+
+def qsubtract_broadcast_compute(
+    tensor_A: te.Tensor,
+    tensor_B: te.Tensor,
+    output_shape: list,
+    zero_point_A: int,
+    scale_A: float,
+    zero_point_B: int,
+    scale_B: float,
+    zero_point_M: int,
+    scale_M: float,
+    dtype: str,
+):
+    """Compute quantized subtract with broadcasting"""
+    A_broadcast, B_broadcast = broadcast_axis(tensor_A, tensor_B)
+    n_a, h_a, w_a, c_a = A_broadcast
+    n_b, h_b, w_b, c_b = B_broadcast
+
+    scale_a, scale_b, rsh, corr = get_int_scale(
+        scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qsub"
+    )
+
+    return te.compute(
+        output_shape,
+        lambda n, h, w, c: saturate(
+            (
+                (
+                    (tensor_A[n * n_a, h * h_a, w * w_a, c * c_a] * scale_a)
+                    - (tensor_B[n * n_b, h * h_b, w * w_b, c * c_b] * scale_b)
+                    + corr
+                )
+                >> rsh
+            ),
+            dtype,
+        ).astype(dtype),
+    )
+
+
+def qmultiply_broadcast_compute(
+    tensor_A: te.Tensor,
+    tensor_B: te.Tensor,
+    output_shape: list,
+    zero_point_A: int,
+    scale_A: float,
+    zero_point_B: int,
+    scale_B: float,
+    zero_point_M: int,
+    scale_M: float,
+    dtype: str,
+):
+    """Compute quantized multiply with broadcasting"""
+    A_broadcast, B_broadcast = broadcast_axis(tensor_A, tensor_B)
+    n_a, h_a, w_a, c_a = A_broadcast
+    n_b, h_b, 
w_b, c_b = B_broadcast + + scale_int, rsh, corr = get_int_scale( + scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qmul" + ) + + return te.compute( + output_shape, + lambda n, h, w, c: saturate( + ( + ( + scale_int + * (tensor_A[n * n_a, h * h_a, w * w_a, c * c_a] - zero_point_A) + * (tensor_B[n * n_b, h * h_b, w * w_b, c * c_b] - zero_point_B) + + corr + ) + >> rsh + ), + dtype, + ).astype(dtype), + ) + + +def tir_schedule_quant( + out_M: te.Tensor, + tensor_A: te.Tensor, + tensor_B: te.Tensor, + output_layout: str, + tensor_A_layout: str, + tensor_B_layout: str, +): + """Schedule for output layout nhwc-8h8w32c-2d""" + func = te.create_prim_func([tensor_A, tensor_B, out_M]) + + s = tir.Schedule(func) + + block = s.get_block("compute") + + if tensor_A_layout == "nhwc-8h8w32c-2d": + tensor_A_transformed_layout = get_layout_transform_fn(tensor_A_layout) + s.transform_layout(block, buffer=tensor_A.name, index_map=tensor_A_transformed_layout) + + if tensor_B_layout == "nhwc-8h8w32c-2d": + tensor_B_transformed_layout = get_layout_transform_fn(tensor_B_layout) + s.transform_layout(block, buffer=tensor_B.name, index_map=tensor_B_transformed_layout) + + output_transformed_layout = get_layout_transform_fn(output_layout) + s.transform_layout(block, buffer=out_M.name, index_map=output_transformed_layout) + + n, h, w, c = s.get_loops(block) + + h_o, h_i = s.split(h, [None, 8]) + w_o, w_i = s.split(w, [None, 8]) + c_o, c_i = s.split(c, [None, 32]) + wio, wii = s.split(w_i, [None, 4]) + + s.reorder(n, h_o, w_o, c_o, h_i, wio, wii, c_i) + + return s diff --git a/tests/python/contrib/test_hexagon/topi/test_add_subtract_multiply.py b/tests/python/contrib/test_hexagon/topi/test_add_subtract_multiply.py index 606aa628d009..fe70745143a9 100755 --- a/tests/python/contrib/test_hexagon/topi/test_add_subtract_multiply.py +++ b/tests/python/contrib/test_hexagon/topi/test_add_subtract_multiply.py @@ -22,7 +22,8 @@ import tvm from tvm import te import tvm.topi.hexagon.slice_ops as sl -from ..infrastructure import allocate_hexagon_array, transform_numpy +import tvm.topi.hexagon.qnn as qn +from ..infrastructure import allocate_hexagon_array, transform_numpy, quantize_np @tvm.testing.fixture @@ -38,34 +39,77 @@ def expected_output_np(input_np_A, input_np_B, op_name): @tvm.testing.fixture def input_np_A(input_shape_A, dtype): + if dtype == "uint8" or dtype == "int8": + dtype = "float32" return np.random.random(input_shape_A).astype(dtype) @tvm.testing.fixture def input_np_B(input_shape_B, dtype): + if dtype == "uint8" or dtype == "int8": + dtype = "float32" return np.random.random(input_shape_B).astype(dtype) @tvm.testing.fixture -def transformed_input_np_A(input_np_A, input_A_layout): - return transform_numpy(input_np_A, "nhwc", input_A_layout) +def quantize_input_np_A(input_np_A, dtype): + if dtype == "uint8" or dtype == "int8": + global zero_point_A_val, scale_A_val + input_np_A_quantized, scale_A_val, zero_point_A_val = quantize_np(input_np_A, dtype) + return input_np_A_quantized @tvm.testing.fixture -def transformed_input_np_B(input_np_B, input_B_layout): - return transform_numpy(input_np_B, "nhwc", input_B_layout) +def quantize_input_np_B(input_np_B, dtype): + if dtype == "uint8" or dtype == "int8": + global zero_point_B_val, scale_B_val + input_np_B_quantized, scale_B_val, zero_point_B_val = quantize_np(input_np_B, dtype) + return input_np_B_quantized @tvm.testing.fixture -def transformed_expected_output_np(expected_output_np, output_layout): - return transform_numpy(expected_output_np, 
"nhwc", output_layout) +def transformed_input_np_A(input_np_A, quantize_input_np_A, input_A_layout, dtype): + if dtype == "float16": + return transform_numpy(input_np_A, "nhwc", input_A_layout) + if dtype == "uint8" or dtype == "int8": + return transform_numpy(quantize_input_np_A, "nhwc", input_A_layout) + + raise RuntimeError(f"Unsupported data type '{dtype}'") + + +@tvm.testing.fixture +def transformed_input_np_B(input_np_B, quantize_input_np_B, input_B_layout, dtype): + if dtype == "float16": + return transform_numpy(input_np_B, "nhwc", input_B_layout) + if dtype == "uint8" or dtype == "int8": + return transform_numpy(quantize_input_np_B, "nhwc", input_B_layout) + + raise RuntimeError(f"Unsupported data type '{dtype}'") + + +@tvm.testing.fixture +def transformed_expected_output_np(expected_output_np, output_layout, dtype): + if dtype == "float16": + return transform_numpy(expected_output_np, "nhwc", output_layout) + if dtype == "uint8" or dtype == "int8": + global zero_point_M_val, scale_M_val + out_ref_quantized, scale_M_val, zero_point_M_val = quantize_np(expected_output_np, dtype) + return transform_numpy(out_ref_quantized, "nhwc", output_layout) + + raise RuntimeError(f"Unsupported data type '{dtype}'") def hexagon_wrapper_allocation( - device, layout, axis_separators, tensor_shape=None, data=None, transformed_data=None, dtype=None + device, + layout, + axis_separators, + tensor_shape=None, + data_original=None, + transformed_data=None, + dtype=None, ): """Input layout can either be nhwc-8h2w32c2w-2d or nhwc""" - if layout == "nhwc-8h2w32c2w-2d": + if layout == "nhwc-8h2w32c2w-2d" or layout == "nhwc-8h8w32c-2d": data_nd = allocate_hexagon_array( device, tensor_shape=tensor_shape, @@ -77,7 +121,7 @@ def hexagon_wrapper_allocation( elif layout == "nhwc": data_nd = allocate_hexagon_array( device, - data=data, + data=data_original, ) return data_nd @@ -136,6 +180,86 @@ class TestAddSubtractMultiplyBroadcast2d: "nhwc-8h2w32c2w-2d", "float16", ), + # broadcast all axes in one input + ( + [1, 48, 56, 32], + [1, 1, 1, 1], + "nhwc-8h2w32c2w-2d", + "nhwc", + "nhwc-8h2w32c2w-2d", + "float16", + ), + ( + [1, 48, 32, 64], + [1, 48, 32, 64], + "nhwc-8h8w32c-2d", + "nhwc-8h8w32c-2d", + "nhwc-8h8w32c-2d", + "uint8", + ), + # broadcast axis 2 in one input + ( + [1, 48, 32, 64], + [1, 48, 1, 64], + "nhwc-8h8w32c-2d", + "nhwc", + "nhwc-8h8w32c-2d", + "uint8", + ), + # broadcast axis 1 in one input + ( + [1, 48, 32, 64], + [1, 1, 32, 64], + "nhwc-8h8w32c-2d", + "nhwc", + "nhwc-8h8w32c-2d", + "uint8", + ), + # broadcast axis 3 in one input + ( + [1, 8, 8, 32], + [1, 8, 8, 1], + "nhwc-8h8w32c-2d", + "nhwc", + "nhwc-8h8w32c-2d", + "uint8", + ), + # broadcast both inputs + ( + [1, 56, 1, 128], + [1, 1, 64, 1], + "nhwc", + "nhwc", + "nhwc-8h8w32c-2d", + "uint8", + ), + # broadcast both inputs + ( + [1, 48, 1, 1], + [1, 1, 32, 32], + "nhwc", + "nhwc", + "nhwc-8h8w32c-2d", + "uint8", + ), + # broadcast both inputs + ( + [1, 48, 1, 32], + [1, 1, 32, 1], + "nhwc", + "nhwc", + "nhwc-8h8w32c-2d", + "uint8", + ), + # broadcast all axes in one input + ( + [1, 48, 56, 32], + [1, 1, 1, 1], + "nhwc-8h8w32c-2d", + "nhwc", + "nhwc-8h8w32c-2d", + "uint8", + ), ) op_name = tvm.testing.parameter("add", "subtract", "multiply") @@ -148,6 +272,8 @@ def test_transform( input_shape_B, input_np_A, input_np_B, + quantize_input_np_A, + quantize_input_np_B, transformed_input_np_A, transformed_input_np_B, expected_output_np, @@ -158,23 +284,50 @@ def test_transform( input_B_layout, op_name, ): + output_shape = 
expected_output_np.shape
         target_hexagon = tvm.target.hexagon("v69")
 
         A = te.placeholder(input_shape_A, name="A", dtype=dtype)
         B = te.placeholder(input_shape_B, name="B", dtype=dtype)
 
-        if op_name == "add":
-            M = sl.add_broadcast_compute(A, B)
-        elif op_name == "subtract":
-            M = sl.subtract_broadcast_compute(A, B)
-        elif op_name == "multiply":
-            M = sl.multiply_broadcast_compute(A, B)
-
-        tir_schedule = sl.tir_broadcast_schedule(
-            M, A, B, output_layout, input_A_layout, input_B_layout, op_name
-        )
+        if dtype == "float16":
+            if op_name == "add":
+                M = sl.add_broadcast_compute(A, B)
+            elif op_name == "subtract":
+                M = sl.subtract_broadcast_compute(A, B)
+            elif op_name == "multiply":
+                M = sl.multiply_broadcast_compute(A, B)
+            tir_schedule = sl.tir_broadcast_schedule(
+                M, A, B, output_layout, input_A_layout, input_B_layout, op_name
+            )
+        elif dtype == "uint8" or dtype == "int8":
+            args = [
+                A,
+                B,
+                output_shape,
+                zero_point_A_val,
+                scale_A_val,
+                zero_point_B_val,
+                scale_B_val,
+                zero_point_M_val,
+                scale_M_val,
+                dtype,
+            ]
+            if op_name == "add":
+                M = qn.qadd_broadcast_compute(*args)
+            elif op_name == "subtract":
+                M = qn.qsubtract_broadcast_compute(*args)
+            elif op_name == "multiply":
+                M = qn.qmultiply_broadcast_compute(*args)
+            tir_schedule = qn.tir_schedule_quant(
+                M, A, B, output_layout, input_A_layout, input_B_layout
+            )
+
 
         sch = tir_schedule.mod
 
         input_axis_separator = [4]
-        if output_layout == "nhwc-8h2w32c2w-2d":
+        if output_layout in (
+            "nhwc-8h2w32c2w-2d",
+            "nhwc-8h8w32c-2d",
+        ):
             output_axis_separator = [4]
         else:
             raise RuntimeError(f"Unexpected layout '{output_layout}'")
@@ -187,19 +340,26 @@ def test_transform(
             name="slice_op_with_transform",
         )
 
-        output_shape = expected_output_np.shape
+        if dtype == "float16":
+            in_data_np_A = input_np_A
+            in_data_np_B = input_np_B
+        elif dtype == "int8" or dtype == "uint8":
+            in_data_np_A = quantize_input_np_A
+            in_data_np_B = quantize_input_np_B
+        else:
+            raise RuntimeError(f"Unsupported dtype '{dtype}'")
 
         A_data_nd = hexagon_wrapper_allocation(
             hexagon_session.device,
             layout=input_A_layout,
-            data=input_np_A,
+            data_original=in_data_np_A,
             transformed_data=transformed_input_np_A,
             axis_separators=input_axis_separator,
         )
         B_data_nd = hexagon_wrapper_allocation(
             hexagon_session.device,
             layout=input_B_layout,
-            data=input_np_B,
+            data_original=in_data_np_B,
             transformed_data=transformed_input_np_B,
             axis_separators=input_axis_separator,
         )
@@ -218,8 +378,15 @@ def test_transform(
         # convert nd to np and reshape to fixed chunk size layout
         if output_layout == "nhwc-8h2w32c2w-2d":
             M_data_np = M_data_nd.numpy().reshape([b, h // 8, w // 4, c // 32, 8, 2, 32, 2])
+        elif output_layout == "nhwc-8h8w32c-2d":
+            M_data_np = M_data_nd.numpy().reshape([b, h // 8, w // 8, c // 32, 8, 8, 32])
 
-        np.testing.assert_allclose(transformed_expected_output_np, M_data_np, rtol=1e-3, atol=1e-3)
+        if dtype == "float16":
+            np.testing.assert_allclose(
+                transformed_expected_output_np, M_data_np, rtol=1e-3, atol=1e-3
+            )
+        elif dtype == "int8" or dtype == "uint8":
+            np.testing.assert_allclose(transformed_expected_output_np, M_data_np, rtol=1, atol=1)
 
 
 if __name__ == "__main__":
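
For reference, the integer-only expression used by qadd_broadcast_compute can be checked with a small, self-contained NumPy sketch. The to_fixed_point helper and the example scales, zero points, and rsh value below are illustrative assumptions, not the values produced by topi.hexagon.utils.get_fixed_point_value:

# Minimal NumPy sketch of the fixed-point requantization used by qadd_broadcast_compute.
# Example quantization parameters and the to_fixed_point helper are assumptions for illustration.
import numpy as np


def to_fixed_point(flp, rsh):
    """Approximate a float as an integer fixed-point value scaled by 2**rsh."""
    return int(round(flp * (1 << rsh)))


# Assumed example quantization parameters for A, B, and the output M.
scale_A, zero_point_A = 0.02, 10
scale_B, zero_point_B = 0.03, 5
scale_M, zero_point_M = 0.05, 7
rsh = 14  # a common exp_scale_factor, chosen by hand here

scale_a = to_fixed_point(scale_A / scale_M, rsh)
scale_b = to_fixed_point(scale_B / scale_M, rsh)
corr = (zero_point_M << rsh) - (zero_point_A * scale_a + zero_point_B * scale_b)

A = np.array([12, 130, 255], dtype=np.int32)  # quantized uint8 inputs, widened for the math
B = np.array([7, 64, 200], dtype=np.int32)

# Integer-only compute, mirroring the te.compute expression for "qadd";
# the clip to [0, 255] plays the role of saturate() for uint8.
M_fixed = np.clip((A * scale_a + B * scale_b + corr) >> rsh, 0, 255).astype(np.uint8)

# Floating-point reference: dequantize, add, requantize.
ref = np.clip(
    np.round(
        (scale_A * (A - zero_point_A) + scale_B * (B - zero_point_B)) / scale_M + zero_point_M
    ),
    0,
    255,
).astype(np.uint8)

print(M_fixed, ref)  # the two should agree to within about one quantization unit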