From f88af205619d0f5821c307a7e03f4643380c0966 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Mon, 21 Jun 2021 14:38:02 +0800 Subject: [PATCH] Combine amp and qat (#33484) * Combine amp and qat * add unit test --- paddle/fluid/imperative/amp_auto_cast.cc | 17 +- paddle/fluid/operators/fake_quantize_op.cu | 44 ++-- .../fluid/contrib/slim/tests/CMakeLists.txt | 1 + .../slim/tests/test_imperative_qat_amp.py | 222 ++++++++++++++++++ python/paddle/fluid/dygraph/amp/auto_cast.py | 2 + 5 files changed, 267 insertions(+), 19 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index b4154737e0fbc..647b7cb34f659 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -141,7 +141,7 @@ static inline std::shared_ptr CastToFP32( } static inline framework::proto::VarType::Type GetPromoteType( - const NameVarBaseMap& ins) { + const std::string& op_type, const NameVarBaseMap& ins) { auto dst_type = framework::proto::VarType::FP16; for (const auto& pair : ins) { for (const auto& var : pair.second) { @@ -151,6 +151,18 @@ static inline framework::proto::VarType::Type GetPromoteType( } } } + + // NOTE(juncai): moving_average_abs_max_scale only consider the + // dtype of input(X) + if (op_type == "moving_average_abs_max_scale") { + for (const auto& pair : ins) { + if (pair.first == "X" && + pair.second.front()->DataType() == framework::proto::VarType::FP16) { + dst_type = framework::proto::VarType::FP16; + } + } + } + return dst_type; } @@ -183,7 +195,8 @@ NameVarBaseMap AutoCastInputs(const std::string& op_type, } return new_ins; } else { - auto dst_type = GetPromoteType(ins); + auto dst_type = GetPromoteType(op_type, ins); + // NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32. if (dst_type == framework::proto::VarType::FP16 && AmpOperators::Instance().GetMutableUnsupportedFp16Ops()->count( diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 78052179f6be7..583ff157a0d39 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -25,18 +25,19 @@ __global__ void FindAbsMaxKernel(const T* in, const int n, T* out) { int bid = threadIdx.x + blockIdx.x * blockDim.x; int tid = threadIdx.x; - extern __shared__ T shared_max_data[]; + extern __shared__ char* shared_max_data_tmp[]; + auto shared_max_data = reinterpret_cast(shared_max_data_tmp); if (gridDim.x > 1) { shared_max_data[tid] = T(0); for (int i = bid; i < n; i += blockDim.x * gridDim.x) { - T tmp = fabs(in[i]); + T tmp = abs(in[i]); if (tmp > shared_max_data[tid]) { shared_max_data[tid] = tmp; } } } else { if (bid < n) { - shared_max_data[tid] = fabs(in[bid]); + shared_max_data[tid] = abs(in[bid]); } else { shared_max_data[tid] = T(0); } @@ -73,6 +74,8 @@ struct FindAbsMaxFunctor { }; template struct FindAbsMaxFunctor; +template struct FindAbsMaxFunctor; template __global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n, @@ -213,13 +216,16 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, int tid = threadIdx.x; T s = scale[0]; - T inv_s = inverse(s); + T bin_cnt_t = static_cast(bin_cnt); + for (int i = bid; i < n; i += blockDim.x * gridDim.x) { T x = in[i]; - T v = x > s ? s : x; - v = v < -s ? -s : v; - v = bin_cnt * inv_s * v; - out[i] = round(v) * s / bin_cnt; + x = x > s ? s : x; + x = x < -s ? -s : x; + x = (bin_cnt_t / s) * x; + + x = static_cast(round(static_cast(x))); + out[i] = (x * s) / bin_cnt_t; } } @@ -261,9 +267,6 @@ struct ClipAndFakeQuantDequantFunctor { } }; -template struct ClipAndFakeQuantDequantFunctor; - // ChannelClipAndQuantKernel for quant_axis is 0 template __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale, @@ -423,8 +426,10 @@ struct FindMovingAverageAbsMaxFunctor { memory::Copy(platform::CPUPlace(), &scale, gpu_place, cur_scale, sizeof(T), ctx.stream()); ctx.Wait(); - state = rate * state + 1; - accum = rate * accum + scale; + + T rate_t = static_cast(rate); + state = rate_t * state + static_cast(1.0); + accum = rate_t * accum + scale; scale = accum / state; memory::Copy(gpu_place, out_accum->mutable_data(gpu_place), @@ -527,10 +532,12 @@ template struct ChannelClipFakeQuantDequantFunctor); REGISTER_OP_CUDA_KERNEL(fake_quantize_dequantize_abs_max, - ops::FakeQuantizeDequantizeAbsMaxKernel); + ops::FakeQuantizeDequantizeAbsMaxKernel, + ops::FakeQuantizeDequantizeAbsMaxKernel); REGISTER_OP_CUDA_KERNEL(fake_channel_wise_quantize_abs_max, ops::FakeChannelWiseQuantizeAbsMaxKernel); REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max, @@ -539,12 +546,15 @@ REGISTER_OP_CUDA_KERNEL( fake_quantize_moving_average_abs_max, ops::FakeQuantizeMovingAverageAbsMaxKernel); REGISTER_OP_CUDA_KERNEL(moving_average_abs_max_scale, - ops::MovingAverageAbsMaxScaleKernel); + ops::MovingAverageAbsMaxScaleKernel, + ops::MovingAverageAbsMaxScaleKernel); REGISTER_OP_CUDA_KERNEL( fake_quantize_dequantize_moving_average_abs_max, - ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel); + ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel, + ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel); REGISTER_OP_CUDA_KERNEL(stright_throuth_estimator_grad, - ops::StrightThroughEstimatorGradKernel); + ops::StrightThroughEstimatorGradKernel, + ops::StrightThroughEstimatorGradKernel); REGISTER_OP_CUDA_KERNEL( fake_channel_wise_quantize_dequantize_abs_max, ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel); diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 20c60dc58b78d..5a4f7c0a1fd1a 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -127,6 +127,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_post_training_quantization_lstm_model) list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2) + list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp) endif() if(LINUX AND WITH_MKLDNN) diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py new file mode 100644 index 0000000000000..d1bf76f472465 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_amp.py @@ -0,0 +1,222 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +from __future__ import print_function + +import os +import numpy as np +import random +import shutil +import time +import unittest +import logging + +import paddle +import paddle.fluid as fluid +from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware +from paddle.fluid.log_helper import get_logger +from paddle.dataset.common import download + +from imperative_test_utils import fix_model_dict, ImperativeLenet + +os.environ["CPU_NUM"] = "1" +if paddle.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) + +_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class TestImperativeQatAmp(unittest.TestCase): + """ + Test the combination of qat and amp. + """ + + @classmethod + def setUpClass(cls): + timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) + cls.root_path = os.path.join(os.getcwd(), + "imperative_qat_amp_" + timestamp) + cls.save_path = os.path.join(cls.root_path, "model") + + cls.download_path = 'dygraph_int8/download' + cls.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + + cls.download_path) + + cls.lenet_url = "https://paddle-inference-dist.cdn.bcebos.com/int8/unittest_model_data/lenet_pretrained.tar.gz" + cls.lenet_md5 = "953b802fb73b52fae42896e3c24f0afb" + + seed = 1 + np.random.seed(seed) + paddle.static.default_main_program().random_seed = seed + paddle.static.default_startup_program().random_seed = seed + + @classmethod + def tearDownClass(cls): + try: + shutil.rmtree(cls.root_path) + except Exception as e: + print("Failed to delete {} due to {}".format(cls.root_path, str(e))) + + def cache_unzipping(self, target_folder, zip_path): + if not os.path.exists(target_folder): + cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, + zip_path) + os.system(cmd) + + def download_model(self, data_url, data_md5, folder_name): + download(data_url, self.download_path, data_md5) + file_name = data_url.split('/')[-1] + zip_path = os.path.join(self.cache_folder, file_name) + print('Data is downloaded at {0}'.format(zip_path)) + + data_cache_folder = os.path.join(self.cache_folder, folder_name) + self.cache_unzipping(data_cache_folder, zip_path) + return data_cache_folder + + def set_vars(self): + self.qat = ImperativeQuantAware() + + self.train_batch_num = 30 + self.train_batch_size = 32 + self.test_batch_num = 100 + self.test_batch_size = 32 + self.eval_acc_top1 = 0.99 + + def model_train(self, model, batch_num=-1, batch_size=32, use_amp=False): + model.train() + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=batch_size) + adam = paddle.optimizer.Adam( + learning_rate=0.001, parameters=model.parameters()) + scaler = paddle.amp.GradScaler(init_loss_scaling=500) + + for batch_id, data in enumerate(train_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + + if use_amp: + with paddle.amp.auto_cast(): + out = model(img) + acc = fluid.layers.accuracy(out, label) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.mean(loss) + scaled_loss = scaler.scale(avg_loss) + scaled_loss.backward() + + scaler.minimize(adam, scaled_loss) + adam.clear_gradients() + else: + out = model(img) + acc = fluid.layers.accuracy(out, label) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = fluid.layers.mean(loss) + avg_loss.backward() + + adam.minimize(avg_loss) + model.clear_gradients() + + if batch_id % 100 == 0: + _logger.info("Train | step {}: loss = {:}, acc= {:}".format( + batch_id, avg_loss.numpy(), acc.numpy())) + + if batch_num > 0 and batch_id + 1 >= batch_num: + break + + def model_test(self, model, batch_num=-1, batch_size=32, use_amp=False): + model.eval() + + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + + acc_top1_list = [] + for batch_id, data in enumerate(test_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = paddle.to_tensor(x_data) + label = paddle.to_tensor(y_data) + + with paddle.amp.auto_cast(use_amp): + out = model(img) + acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5) + + acc_top1_list.append(float(acc_top1.numpy())) + if batch_id % 100 == 0: + _logger.info("Test | At step {}: acc1 = {:}, acc5 = {:}".format( + batch_id, acc_top1.numpy(), acc_top5.numpy())) + + if batch_num > 0 and batch_id + 1 >= batch_num: + break + + acc_top1 = sum(acc_top1_list) / len(acc_top1_list) + return acc_top1 + + def test_ptq(self): + start_time = time.time() + + self.set_vars() + + params_path = self.download_model(self.lenet_url, self.lenet_md5, + "lenet") + params_path += "/lenet_pretrained/lenet.pdparams" + + with fluid.dygraph.guard(): + model = ImperativeLenet() + model_state_dict = paddle.load(params_path) + model.set_state_dict(model_state_dict) + + _logger.info("Test fp32 model") + fp32_acc_top1 = self.model_test(model, self.test_batch_num, + self.test_batch_size) + + self.qat.quantize(model) + + use_amp = True + self.model_train(model, self.train_batch_num, self.train_batch_size, + use_amp) + + _logger.info("Test int8 model") + int8_acc_top1 = self.model_test(model, self.test_batch_num, + self.test_batch_size, use_amp) + + _logger.info('fp32_acc_top1: %f, int8_acc_top1: %f' % + (fp32_acc_top1, int8_acc_top1)) + self.assertTrue( + int8_acc_top1 > fp32_acc_top1 - 0.01, + msg='fp32_acc_top1: %f, int8_acc_top1: %f' % + (fp32_acc_top1, int8_acc_top1)) + + input_spec = [ + paddle.static.InputSpec( + shape=[None, 1, 28, 28], dtype='float32') + ] + paddle.jit.save(layer=model, path=self.save_path, input_spec=input_spec) + print('Quantized model saved in {%s}' % self.save_path) + + end_time = time.time() + print("total time: %ss" % (end_time - start_time)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 4ff08337875c0..b14b2be7394a1 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -29,6 +29,8 @@ 'matmul', 'matmul_v2', 'mul', + 'fake_quantize_dequantize_abs_max', + 'fake_quantize_dequantize_moving_average_abs_max', } # The set of ops that support fp16 calculation and are considered numerically-