From 9783e887a728c671ee36654a25d3195639ee863c Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Tue, 21 Jun 2022 17:09:02 +0800 Subject: [PATCH] [cherry pick #43088 #40664] Add float16 to fake quantize/dequantize OP (#43689) * cherry pick #43088 #40664 * fix clang format --- paddle/fluid/operators/fake_dequantize_op.cu | 7 +- paddle/fluid/operators/fake_quantize_op.cu | 15 +- paddle/fluid/operators/fake_quantize_op.cu.h | 86 ++- .../unittests/test_fake_dequantize_op.py | 75 ++- .../tests/unittests/test_fake_quantize_op.py | 523 ++++++++---------- 5 files changed, 366 insertions(+), 340 deletions(-) diff --git a/paddle/fluid/operators/fake_dequantize_op.cu b/paddle/fluid/operators/fake_dequantize_op.cu index 582f0627b2044..bfa663d59b95c 100644 --- a/paddle/fluid/operators/fake_dequantize_op.cu +++ b/paddle/fluid/operators/fake_dequantize_op.cu @@ -17,10 +17,13 @@ limitations under the License. */ namespace ops = paddle::operators; using CUDA = paddle::platform::CUDADeviceContext; +using float16 = paddle::platform::float16; REGISTER_OP_CUDA_KERNEL(fake_dequantize_max_abs, ops::FakeDequantizeMaxAbsKernel, - ops::FakeDequantizeMaxAbsKernel); + ops::FakeDequantizeMaxAbsKernel, + ops::FakeDequantizeMaxAbsKernel); REGISTER_OP_CUDA_KERNEL( fake_channel_wise_dequantize_max_abs, ops::FakeChannelWiseDequantizeMaxAbsKernel, - ops::FakeChannelWiseDequantizeMaxAbsKernel); + ops::FakeChannelWiseDequantizeMaxAbsKernel, + ops::FakeChannelWiseDequantizeMaxAbsKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 5416ae11c2b56..42361407a0f0a 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -19,17 +19,22 @@ namespace ops = paddle::operators; using CUDA = paddle::platform::CUDADeviceContext; using float16 = paddle::platform::float16; REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max, - ops::FakeQuantizeAbsMaxKernel); + ops::FakeQuantizeAbsMaxKernel, + ops::FakeQuantizeAbsMaxKernel); REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max, ops::FakeQuantizeDequantizeAbsMaxKernel, ops::FakeQuantizeDequantizeAbsMaxKernel); -REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max, - ops::FakeChannelWiseQuantizeAbsMaxKernel); +REGISTER_OP_CUDA_KERNEL( + fake_channel_wise_quantize_abs_max, + ops::FakeChannelWiseQuantizeAbsMaxKernel, + ops::FakeChannelWiseQuantizeAbsMaxKernel); REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max, - ops::FakeQuantizeRangeAbsMaxKernel); + ops::FakeQuantizeRangeAbsMaxKernel, + ops::FakeQuantizeRangeAbsMaxKernel); REGISTER_OP_CUDA_KERNEL( fake_quantize_moving_average_abs_max, - ops::FakeQuantizeMovingAverageAbsMaxKernel); + ops::FakeQuantizeMovingAverageAbsMaxKernel, + ops::FakeQuantizeMovingAverageAbsMaxKernel); REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleKernel, ops::MovingAverageAbsMaxScaleKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.cu.h b/paddle/fluid/operators/fake_quantize_op.cu.h index ae448b7ff2c8b..e809c55368403 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu.h +++ b/paddle/fluid/operators/fake_quantize_op.cu.h @@ -17,6 +17,7 @@ limitations under the License. */ #endif // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_ #include + #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/fake_quantize_op.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" @@ -24,6 +25,16 @@ limitations under the License. */ namespace paddle { namespace operators { +template +struct QuantizeDataType { + using type = T; +}; + +template <> +struct QuantizeDataType { + using type = float; +}; + template __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; @@ -87,10 +98,12 @@ __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, int tid = threadIdx.x; int channel_size = n / c; const T* in_c = in + blockIdx.x * channel_size; - extern __shared__ T shared_max_data[]; + extern __shared__ char* shared_max_data_tmp[]; + auto shared_max_data = reinterpret_cast(shared_max_data_tmp); T local_max_data = T(0); for (int i = tid; i < channel_size; i += blockDim.x) { - T tmp = fabs(in_c[i]); + T tmp = static_cast( + fabs(static_cast::type>(in_c[i]))); if (tmp > local_max_data) { local_max_data = tmp; } @@ -112,7 +125,8 @@ template __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, const int cin, const int cout, T* out) { - extern __shared__ T shared_max_data[]; + extern __shared__ char* shared_max_data_tmp[]; + auto shared_max_data = reinterpret_cast(shared_max_data_tmp); int cout_wh_size = n / cin; int wh_size = n / (cin * cout); @@ -121,7 +135,8 @@ __global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n, const T* in_current = in + tid * cout_wh_size + bid * wh_size; T local_max_data = T(0); for (int i = 0; i < wh_size; i++) { - T tmp = fabs(in_current[i]); + T tmp = static_cast( + fabs(static_cast::type>(in_current[i]))); if (tmp > local_max_data) { local_max_data = tmp; } @@ -203,14 +218,18 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; - T s = scale[0]; - T inv_s = inverse(s); + using ComputeDataType = typename QuantizeDataType::type; + + ComputeDataType s = static_cast(scale[0]); + ComputeDataType inv_s = inverse(s); + ComputeDataType bin_cnt_t = static_cast(bin_cnt); + for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T x = in[i]; - T v = x > s ? s : x; + ComputeDataType x = static_cast(in[i]); + ComputeDataType v = x > s ? s : x; v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out[i] = round(v); + v = bin_cnt_t * inv_s * v; + out[i] = static_cast(round(v)); } } @@ -221,17 +240,19 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; - T s = scale[0]; - T inv_s = inverse(s); - T bin_cnt_t = static_cast(bin_cnt); + using ComputeDataType = typename QuantizeDataType::type; + + ComputeDataType s = static_cast(scale[0]); + ComputeDataType inv_s = inverse(s); + ComputeDataType bin_cnt_t = static_cast(bin_cnt); for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T x = in[i]; + ComputeDataType x = static_cast(in[i]); x = x > s ? s : x; x = x < -s ? -s : x; x = bin_cnt_t * inv_s * x; - x = static_cast(round(static_cast(x))); - out[i] = (x * s) / bin_cnt_t; + x = round(x); + out[i] = static_cast((x * s) / bin_cnt_t); } } @@ -285,15 +306,18 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, const T* in_c = in + blockIdx.x * channel_size; T* out_c = out + blockIdx.x * channel_size; - T s = scale[blockIdx.x]; - T inv_s = inverse(s); + using ComputeDataType = typename QuantizeDataType::type; + + ComputeDataType s = static_cast(scale[blockIdx.x]); + ComputeDataType inv_s = inverse(s); + ComputeDataType bin_cnt_t = static_cast(bin_cnt); for (int64_t i = tid; i < channel_size; i += blockDim.x) { - T x = in_c[i]; - T v = x > s ? s : x; + ComputeDataType x = static_cast(in_c[i]); + ComputeDataType v = x > s ? s : x; v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out_c[i] = round(v); + v = bin_cnt_t * inv_s * v; + out_c[i] = static_cast(round(v)); } } @@ -303,14 +327,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN( const T* in, const T* scale, const int bin_cnt, const int64_t n, const int nScale, const int quant_stride, T* out) { int64_t idx = blockDim.x * blockIdx.x + threadIdx.x; + using ComputeDataType = typename QuantizeDataType::type; + ComputeDataType bin_cnt_t = static_cast(bin_cnt); for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) { - T s = scale[(i / quant_stride) % nScale]; - T inv_s = inverse(s); - T x = in[i]; - T v = x > s ? s : x; + ComputeDataType s = + static_cast(scale[(i / quant_stride) % nScale]); + ComputeDataType inv_s = inverse(s); + ComputeDataType x = static_cast(in[i]); + ComputeDataType v = x > s ? s : x; v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out[i] = round(v); + v = bin_cnt_t * inv_s * v; + out[i] = static_cast(round(v)); } } @@ -376,7 +403,8 @@ __global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale, scale_arr[idx] = cur; T max = last_scale[0]; out_scale[0] = max < cur ? cur : max; - if (fabs(removed - max) < 1e-6) { + if (fabs(static_cast::type>(removed - max)) < + 1e-6) { need_find_max[0] = 1; out_size[0] = it > window_size ? window_size : it; } else { diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py index 728e178845c9b..8dfe776530b9b 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py @@ -18,6 +18,7 @@ import numpy as np import math from op_test import OpTest +import paddle.fluid.core as core def quantize_max_abs(x, max_range): @@ -76,22 +77,25 @@ def channel_wise_dequantize_max_abs(x, class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest): def set_args(self): self.quant_bits = [8, 8] - self.data_type = "float32" self.activation_scale = 0.7861 + def set_dtype(self): + self.dtype = np.float32 + def setUp(self): self.set_args() + self.set_dtype() self.op_type = "fake_channel_wise_dequantize_max_abs" - x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + x = np.random.randn(4, 3, 64, 64).astype(self.dtype) yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], 1) ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, 1, self.activation_scale) self.inputs = { 'X': yq, - 'Scales': [("scales0", np.array(scales).astype(self.data_type)), - ("scales1", np.array( - [self.activation_scale]).astype(self.data_type))] + 'Scales': + [("scales0", np.array(scales).astype(self.dtype)), + ("scales1", np.array([self.activation_scale]).astype(self.dtype))] } self.attrs = {'quant_bits': self.quant_bits} self.outputs = {'Out': ydq} @@ -100,16 +104,28 @@ def test_check_output(self): self.check_output() +class TestFakeChannelWiseDequantizeMaxAbsOpTwoScalesFloat16( + TestFakeChannelWiseDequantizeMaxAbsOpTwoScales): + def set_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output(atol=1e-2) + + class TestFakeChannelWiseDequantizeMaxAbsOpOneScale(OpTest): def set_args(self): self.quant_bits = [8] - self.data_type = "float32" self.quant_axis = 0 + def set_dtype(self): + self.dtype = np.float32 + def setUp(self): self.set_args() + self.set_dtype() self.op_type = "fake_channel_wise_dequantize_max_abs" - x = np.random.randn(4, 3, 64, 64).astype(self.data_type) + x = np.random.randn(4, 3, 64, 64).astype(self.dtype) yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0], self.quant_axis) ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits, @@ -117,7 +133,7 @@ def setUp(self): self.inputs = { 'X': yq, - 'Scales': [("scales0", np.array(scales).astype(self.data_type))] + 'Scales': [("scales0", np.array(scales).astype(self.dtype))] } self.attrs = { 'quant_bits': self.quant_bits, @@ -133,24 +149,44 @@ class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1( TestFakeChannelWiseDequantizeMaxAbsOpOneScale): def set_args(self): self.quant_bits = [8] - self.data_type = "float32" self.quant_axis = 1 +class TestFakeChannelWiseDequantizeMaxAbsOpOneScaleFloat16( + TestFakeChannelWiseDequantizeMaxAbsOpOneScale): + def set_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output(atol=1e-2) + + +class TestFakeChannelWiseDequantizeMaxAbsOpOneScale1Float16( + TestFakeChannelWiseDequantizeMaxAbsOpOneScale1): + def set_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output(atol=1e-2) + + class TestFakeDequantizeMaxAbsOp(OpTest): def set_args(self): self.num_bits = 8 self.max_range = math.pow(2, self.num_bits - 1) - 1 - self.data_type = "float32" + + def set_dtype(self): + self.dtype = np.float32 def setUp(self): self.set_args() + self.set_dtype() self.op_type = "fake_dequantize_max_abs" - x = np.random.randn(31, 65).astype(self.data_type) + x = np.random.randn(31, 65).astype(self.dtype) yq, scale = quantize_max_abs(x, self.max_range) ydq = dequantize_max_abs(yq, scale, self.max_range) - self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.data_type)} + self.inputs = {'X': yq, 'Scale': np.array(scale).astype(self.dtype)} self.attrs = {'max_range': self.max_range} self.outputs = {'Out': ydq} @@ -159,17 +195,22 @@ def test_check_output(self): class TestFakeDequantizeMaxAbsOpDouble(TestFakeDequantizeMaxAbsOp): - def set_args(self): - self.num_bits = 8 - self.max_range = math.pow(2, self.num_bits - 1) - 1 - self.data_type = "float64" + def set_dtype(self): + self.dtype = np.float64 class TestFakeDequantizeMaxAbsOp5Bits(TestFakeDequantizeMaxAbsOp): def set_args(self): self.num_bits = 5 self.max_range = math.pow(2, self.num_bits - 1) - 1 - self.data_type = "float32" + + +class TestFakeDequantizeMaxAbsOpFloat16(TestFakeDequantizeMaxAbsOp): + def set_dtype(self): + self.dtype = np.float16 + + def test_check_output(self): + self.check_output(atol=1e-2) class TestChannelWiseDequantizeOp(OpTest): diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index 2be61d1218560..df458f97d5996 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,364 +15,313 @@ from __future__ import print_function import unittest +import itertools import numpy as np import math from op_test import OpTest -import paddle.fluid.core as core -class TestFakeQuantizeOp(OpTest): - def setUp(self): - self.op_type = "fake_quantize_abs_max" - self.attrs = {'bit_length': 8} - self.inputs = {'X': np.random.random((124, 240)).astype("float32"), } - scale = np.max(np.abs(self.inputs['X'])).astype("float32") - self.outputs = { - 'Out': np.round(self.inputs['X'] / scale * ( - (1 << (self.attrs['bit_length'] - 1)) - 1)), - 'OutScale': np.array(scale).astype("float32"), - } +# numpy.round has different behavior in comparision to c++ round function +# so we use round_c instead of numpy.round to align the output data +def round_c_single_element(val): + dtype = type(val) + if val >= 0: + return dtype(np.floor(val + 0.5)) + return dtype(np.ceil(val - 0.5)) - def test_check_output(self): - self.check_output() +round_c = np.vectorize(round_c_single_element) -class TestFakeQuantizeOp1(OpTest): - def setUp(self): - self.op_type = "fake_quantize_abs_max" - self.attrs = {'bit_length': 8} - self.inputs = {'X': np.zeros((10, 10)).astype("float32"), } - scale = np.max(np.abs(self.inputs['X'])).astype("float32") - inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale - self.outputs = { - 'Out': np.round(self.inputs['X'] * inv_scale * ( - (1 << (self.attrs['bit_length'] - 1)) - 1)), - 'OutScale': np.array(scale).astype("float32"), - } - def test_check_output(self): - self.check_output() +def get_compute_type(dtype): + assert dtype in [np.float16, np.float32, np.float64] + if dtype == np.float16: + return np.float32 + return dtype -class TestFakeQuantizeOp2(OpTest): +class TestFakeQuantizeAbsMaxOp(OpTest): def setUp(self): - self.op_type = "fake_quantize_abs_max" + self.op_type = 'fake_quantize_abs_max' self.attrs = {'bit_length': 8} - self.inputs = {'X': np.full((10, 10), 1e-40).astype("float32"), } - scale = np.max(np.abs(self.inputs['X'])).astype("float32") - inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale - self.outputs = { - 'Out': np.round(self.inputs['X'] * inv_scale * ( - (1 << (self.attrs['bit_length'] - 1)) - 1)), - 'OutScale': np.array(scale).astype("float32"), - } - def test_check_output(self): + def _fake_quantize_abs_max(self, dtype, input_shape, distribution): + input_data = distribution(input_shape).astype(dtype) + compute_type = get_compute_type(dtype) + scale = np.max(np.abs(input_data)) + bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 + inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale + output_data = round_c(input_data.astype(compute_type) * inv_scale * bnt) + self.inputs = {'X': input_data} + self.outputs = {'Out': output_data, 'OutScale': scale} + self.dtype = dtype self.check_output() + def test_fake_quantize_abs_max(self): + self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random) -class TestFakeChannelWiseQuantizeOp(OpTest): - def setUp(self): - self.set_arg() - assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1." + def test_fake_quantize_abs_max_float16(self): + self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random) - self.op_type = "fake_channel_wise_quantize_abs_max" - self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis} + def test_fake_quantize_abs_max_underflow(self): + self._fake_quantize_abs_max(np.float32, (10, 10), np.zeros) - scales = [] - outputs = self.inputs['X'].copy() - bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 - if self.quant_axis == 0: - for i in range(self.inputs['X'].shape[0]): - scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32") - scales.append(scale_v) - outputs[i] = np.round(outputs[i] / scale_v * bnt) - elif self.quant_axis == 1: - for i in range(self.inputs['X'].shape[1]): - scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype( - "float32") - scales.append(scale_v) - outputs[:, i] = np.round(outputs[:, i] / scale_v * bnt) + def test_fake_quantize_abs_max_underflow2(self): + self._fake_quantize_abs_max(np.float32, (10, 10), + lambda shape: np.full(shape, 1e-40)) - self.outputs = { - 'Out': outputs, - 'OutScale': np.array(scales).astype("float32"), - } - def set_arg(self): - self.quant_axis = 0 - self.inputs = { - 'X': np.random.random((20, 15, 6, 6)).astype("float32"), - } +class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest): + def setUp(self): + self.op_type = 'fake_channel_wise_quantize_abs_max' + self.attrs = {'bit_length': 8} - def test_check_output(self): + def _fake_channel_wise_quantize_abs_max(self, dtype, input_shape, + quant_axis, distribution): + assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.' + input_data = distribution(input_shape).astype(dtype) + compute_type = get_compute_type(dtype) + bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 + compute_axis = tuple( + i for i in range(len(input_shape)) if i != quant_axis) + scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) + output_data = round_c(bnt * input_data.astype(compute_type) / + scale_broadcast) + if quant_axis == 1: + scale_broadcast = np.transpose(scale_broadcast, + (1, ) + compute_axis) + scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0] + self.inputs = {'X': input_data} + self.outputs = {'Out': output_data, 'OutScale': scale} + self.dtype = dtype + self.attrs['quant_axis'] = quant_axis self.check_output() - -class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp): - def set_quant_axis(self): - self.quant_axis = 1 - self.inputs = { - 'X': np.random.random((15, 20, 5, 5)).astype("float32"), - } - - -class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp): - def set_quant_axis(self): - self.quant_axis = 0 - self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } - - -class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp): - def set_quant_axis(self): - self.quant_axis = 1 - self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } + def test_fake_channel_wise_quantize_abs_max(self): + dtype_options = [np.float32, np.float16] + input_shape_quant_axis_options = [[(20, 15, 6, 6), 0], + [(15, 20, 5, 5), 1], [(30, 15), 0], + [(30, 15), 1]] + for dtype, input_shape_quant_axis in itertools.product( + dtype_options, input_shape_quant_axis_options): + input_shape, quant_axis = input_shape_quant_axis + with self.subTest( + dtype=dtype, input_shape=input_shape, + quant_axis=quant_axis): + self._fake_channel_wise_quantize_abs_max( + dtype, input_shape, quant_axis, np.random.random) class TestFakeQuantizeRangeAbsMaxOp(OpTest): def setUp(self): - self.op_type = "fake_quantize_range_abs_max" - self.attrs = { - 'bit_length': int(5), - 'window_size': int(1), - 'is_test': False - } - x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10 - x = x.astype("float32") + self.op_type = 'fake_quantize_range_abs_max' + self.attrs = {'bit_length': 5, 'window_size': 1} + + def _fake_quantize_range_abs_max(self, + dtype, + input_shape, + distribution, + is_test=False): + input_data = distribution(input_shape).astype(dtype) + compute_type = get_compute_type(dtype) + bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 + in_scale = np.zeros(1).astype(dtype) + out_scale = np.zeros(self.attrs['window_size']).astype(dtype) + out_scale[0] = np.max(np.abs(input_data)) + if is_test: + out_scale[0] = in_scale[0] = out_scale[0] - 1.0 + clip_data = np.clip(input_data, -in_scale, in_scale) + else: + clip_data = input_data + output_data = round_c( + clip_data.astype(compute_type) / out_scale[0] * bnt) self.inputs = { - 'X': x, - 'Iter': np.zeros(1).astype("int64"), - 'InScale': np.zeros(1).astype("float32") + 'X': input_data, + 'Iter': np.zeros(1).astype(np.int64), + 'InScale': in_scale } - scale = np.max(np.abs(self.inputs['X'])).astype("float32") - - out_scales = np.zeros(self.attrs['window_size']).astype("float32") - out_scales[0] = scale self.outputs = { - 'Out': np.round(self.inputs['X'] / scale * ( - (1 << (self.attrs['bit_length'] - 1)) - 1)), - 'OutScale': scale, - 'OutScales': out_scales, + 'Out': output_data, + 'OutScale': out_scale[0], + 'OutScales': out_scale } - - def test_check_output(self): + self.dtype = dtype + self.attrs['is_test'] = is_test self.check_output() + def test_fake_quantize_range_abs_max(self): + dtype_options = [np.float32, np.float16] + is_test_options = [False, True] + for dtype, is_test in itertools.product(dtype_options, is_test_options): + self.attrs['bit_length'] = 8 if is_test else 5 + with self.subTest(dtype=dtype, is_test=is_test): + self._fake_quantize_range_abs_max( + dtype, (8, 16, 7, 7), + lambda shape: (np.random.random(shape) - 0.5) * 10, + is_test=is_test) + class TestMovingAverageAbsMaxScaleOp(OpTest): def setUp(self): - self.op_type = "moving_average_abs_max_scale" + self.op_type = 'moving_average_abs_max_scale' self.attrs = {'moving_rate': float(0.9), 'is_test': False} - accum = np.zeros(1).astype("float32") - accum[0] = 1 - state = np.zeros(1).astype("float32") - state[0] = 1 - x = np.random.random((8, 16, 7, 7)).astype("float32") - self.inputs = { - 'X': x, - 'InAccum': accum, - 'InState': state, - } - out = x - out_accum = np.zeros(1).astype("float32") - out_state = np.zeros(1).astype("float32") - out_scale = np.zeros(1).astype("float32") - out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max( - np.abs(self.inputs['X'])).astype("float32") - out_state[0] = self.attrs['moving_rate'] * state[0] + 1 + def _moving_average_abs_max_scale(self, dtype, input_shape, distribution): + input_data = distribution(input_shape).astype(dtype) + in_accum = np.ones(1).astype(dtype) + in_state = np.ones(1).astype(dtype) + out_accum = self.attrs['moving_rate'] * in_accum[0] + np.max( + np.abs(input_data)) + out_state = self.attrs['moving_rate'] * in_state[0] + 1.0 out_scale = out_accum / out_state + self.inputs = { + 'X': input_data, + 'InAccum': in_accum, + 'InState': in_state + } self.outputs = { - 'Out': out, + 'Out': input_data, 'OutAccum': out_accum, 'OutState': out_state, - 'OutScale': out_scale, + 'OutScale': out_scale } - - def test_check_output(self): + self.dtype = dtype self.check_output() - -class TestFakeQuantizeRangeAbsMaxOp2(OpTest): - def setUp(self): - self.op_type = "fake_quantize_range_abs_max" - self.attrs = { - 'bit_length': int(8), - 'window_size': int(1), - 'is_test': True - } - x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10 - x = x.astype("float32") - scale = np.array([np.max(np.abs(x)).astype("float32") - 1.0]) - out_scales = np.zeros(self.attrs['window_size']).astype("float32") - out_scales[0] = scale - self.inputs = { - 'X': x, - 'Iter': np.zeros(1).astype("int64"), - 'InScale': scale.astype("float32") - } - xs = np.clip(x, -scale, scale) - qs = np.round(xs / scale * ((1 << (self.attrs['bit_length'] - 1)) - 1)) - self.outputs = { - 'Out': qs, - 'OutScale': scale.astype("float32"), - 'OutScales': out_scales, - } - - def test_check_output(self): - self.check_output(no_check_set=set(['OutScale', 'OutScales'])) + def test_moving_average_abs_max(self): + self._moving_average_abs_max_scale(np.float32, (8, 16, 7, 7), + np.random.random) -class TestMovingOpBase(OpTest): +class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest): def setUp(self): - self.init_type() - self.attrs = { - 'bit_length': int(5), - 'moving_rate': float(0.9), - 'is_test': False - } - accum = np.zeros(1).astype("float32") - accum[0] = 1 - state = np.zeros(1).astype("float32") - state[0] = 1 - scale = np.zeros(1).astype("float32") - scale[0] = 0.001 + self.op_type = 'fake_quantize_moving_average_abs_max' + self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False} + + def _fake_quantize_moving_average_abs_max(self, + dtype, + input_shape, + distribution, + dequantize=False, + with_gradient=False): + input_data = distribution(input_shape).astype(dtype) + compute_type = get_compute_type(dtype) + bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 + in_accum = np.ones(1).astype(dtype) + in_state = np.ones(1).astype(dtype) + in_scale = np.array([0.001]).astype(dtype) + out_accum = np.zeros(1).astype(dtype) + out_state = np.zeros(1).astype(dtype) + out_scale = np.zeros(1).astype(dtype) + out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max( + np.abs(input_data)) + out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0 + out_scale = out_accum / out_state + round_data = round_c(input_data.astype(compute_type) / out_scale * bnt) + if dequantize: + output_data = (round_data * out_scale / bnt).astype(dtype) + self.op_type = 'fake_quantize_dequantize_moving_average_abs_max' + else: + output_data = round_data.astype(dtype) self.inputs = { - 'X': np.random.random((8, 16, 7, 7)).astype("float32"), - 'InScale': scale, - 'InAccum': accum, - 'InState': state, + 'X': input_data, + 'InScale': in_scale, + 'InAccum': in_accum, + 'InState': in_state } - - out_accum = np.zeros(1).astype("float32") - out_state = np.zeros(1).astype("float32") - out_scale = np.zeros(1).astype("float32") - out_accum[0] = self.attrs['moving_rate'] * accum[0] + np.max( - np.abs(self.inputs['X'])).astype("float32") - out_state[0] = self.attrs['moving_rate'] * state[0] + 1 - out_scale = out_accum / out_state - out_data = self.calc_output(out_scale) self.outputs = { - 'Out': out_data, + 'Out': output_data, 'OutAccum': out_accum, 'OutState': out_state, - 'OutScale': out_scale, + 'OutScale': out_scale } - - def init_type(self): - self.op_type = "fake_quantize_moving_average_abs_max" - - def calc_output(self, out_scale): - return np.round(self.inputs['X'] / out_scale * ( - (1 << (self.attrs['bit_length'] - 1)) - 1)) - - def test_check_output(self): + self.dtype = dtype self.check_output() + if with_gradient: + gradient = [ + np.ones(input_data.shape) / np.product(input_data.shape) + ] + self.check_grad(['X'], 'Out', user_defined_grads=gradient) + def test_fake_quantize_moving_average_abs_max(self): + self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7), + np.random.random) -class TestFakeQuantDequantMovingOp(TestMovingOpBase): - def init_type(self): - self.op_type = "fake_quantize_dequantize_moving_average_abs_max" - - def calc_output(self, out_scale): - range_v = (1 << (self.attrs['bit_length'] - 1)) - 1 - return np.round(self.inputs['X'] / out_scale * - range_v) * out_scale / range_v + def test_fake_quantize_moving_average_abs_max_float16(self): + self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7), + np.random.random) - def test_check_grad(self): - x = self.inputs["X"] - gradient = [np.ones(x.shape) / np.product(x.shape)] - self.check_grad(["X"], "Out", user_defined_grads=gradient) + def test_fake_quantize_dequantize_moving_average_abs_max(self): + self._fake_quantize_moving_average_abs_max( + np.float32, (8, 16, 7, 7), + np.random.random, + dequantize=True, + with_gradient=True) -class TestFakeQuantDequantAbsOp(OpTest): +class TestFakeQuantizeDequantizeAbsMaxOp(OpTest): def setUp(self): - self.op_type = "fake_quantize_dequantize_abs_max" + self.op_type = 'fake_quantize_dequantize_abs_max' self.attrs = {'bit_length': 8} - self.inputs = {'X': np.random.random((124, 240)).astype("float32"), } - scale = np.max(np.abs(self.inputs['X'])).astype("float32") - out_data = self.calc_output(scale) + + def _fake_quantize_dequantize_abs_max(self, dtype, input_shape, + distribution): + input_data = distribution(input_shape).astype(dtype) + scale = np.max(np.abs(input_data)).astype(dtype) + bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 + output_data = round_c(input_data / scale * bnt) * scale / bnt + self.inputs = {'X': input_data} self.outputs = { - 'Out': out_data, - 'OutScale': np.array(scale).astype("float32"), + 'Out': output_data, + 'OutScale': np.array(scale).astype(dtype) } - - def calc_output(self, scale): - range_v = (1 << (self.attrs['bit_length'] - 1)) - 1 - return np.round(self.inputs['X'] / scale * range_v) * scale / range_v - - def test_check_output(self): + self.dtype = dtype self.check_output() + gradient = [np.ones(input_data.shape) / np.product(input_data.shape)] + self.check_grad(['X'], 'Out', user_defined_grads=gradient) - def test_check_grad(self): - x = self.inputs["X"] - gradient = [np.ones(x.shape) / np.product(x.shape)] - self.check_grad(["X"], "Out", user_defined_grads=gradient) + def test_fake_quantize_dequantize_abs_max(self): + self._fake_quantize_dequantize_abs_max(np.float32, (124, 240), + np.random.random) -class TestChannelWiseFakeQuantDequantOp(OpTest): +class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest): def setUp(self): - self.set_arg() - assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1." - - self.op_type = "fake_channel_wise_quantize_dequantize_abs_max" - self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis} - - scales = [] - outputs = self.inputs['X'].copy() - range_v = (1 << (self.attrs['bit_length'] - 1)) - 1 - if self.quant_axis == 0: - for i in range(self.inputs['X'].shape[0]): - scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32") - scales.append(scale_v) - outputs[i] = np.round(outputs[i] * range_v / - scale_v) * scale_v / range_v - elif self.quant_axis == 1: - for i in range(self.inputs['X'].shape[1]): - scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype( - "float32") - scales.append(scale_v) - outputs[:, i] = np.round(outputs[:, i] * range_v / - scale_v) * scale_v / range_v - - self.outputs = { - 'Out': outputs, - 'OutScale': np.array(scales).astype("float32"), - } - - def set_arg(self): - self.quant_axis = 0 - self.inputs = { - 'X': np.random.random((3, 4, 64, 64)).astype("float32"), - } + self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max' + self.attrs = {'bit_length': 8} - def test_check_output(self): + def _fake_channel_wise_quantize_dequantize_abs_max( + self, dtype, input_shape, quant_axis, distribution): + assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.' + input_data = distribution(input_shape).astype(dtype) + compute_type = get_compute_type(dtype) + bnt = (1 << (self.attrs['bit_length'] - 1)) - 1 + output_data = input_data.copy().astype(compute_type) + compute_axis = tuple( + i for i in range(len(input_shape)) if i != quant_axis) + scale_broadcast = np.amax(input_data, axis=compute_axis, keepdims=True) + output_data = round_c(bnt * output_data / + scale_broadcast) * scale_broadcast / bnt + if quant_axis == 1: + scale_broadcast = np.transpose(scale_broadcast, + (1, ) + compute_axis) + scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0] + self.inputs = {'X': input_data} + self.outputs = {'Out': output_data, 'OutScale': scale} + self.dtype = dtype + self.attrs['quant_axis'] = quant_axis self.check_output() + gradient = [np.ones(input_data.shape) / np.product(input_data.shape)] + self.check_grad(['X'], 'Out', user_defined_grads=gradient) - def test_check_grad(self): - x = self.inputs["X"] - gradient = [np.ones(x.shape) / np.product(x.shape)] - self.check_grad(["X"], "Out", user_defined_grads=gradient) - - -class TestChannelWiseFakeQuantDequantOp1(TestChannelWiseFakeQuantDequantOp): - def set_arg(self): - self.quant_axis = 1 - self.inputs = { - 'X': np.random.random((15, 20, 5, 5)).astype("float32"), - } - - -class TestChannelWiseFakeQuantDequantOp2(TestChannelWiseFakeQuantDequantOp): - def set_arg(self): - self.quant_axis = 0 - self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } - - -class TestChannelWiseFakeQuantDequantOp3(TestChannelWiseFakeQuantDequantOp): - def set_arg(self): - self.quant_axis = 1 - self.inputs = {'X': np.random.random((30, 15)).astype("float32"), } + def test_channel_wise_fake_quant_dequant_abs_max(self): + input_shape_quant_axis_options = [[(3, 4, 64, 64), 0], + [(15, 20, 5, 5), 1], [(30, 15), 0], + [(30, 15), 1]] + for input_shape, quant_axis in input_shape_quant_axis_options: + with self.subTest(input_shape=input_shape, quant_axis=quant_axis): + self._fake_channel_wise_quantize_dequantize_abs_max( + np.float32, input_shape, quant_axis, np.random.random) def quantize_max_abs(x, max_range): @@ -514,5 +463,5 @@ def test_check_output(self): self.check_output() -if __name__ == "__main__": +if __name__ == '__main__': unittest.main()