diff --git a/paddle2onnx/mapper/nn/pool2d.cc b/paddle2onnx/mapper/nn/pool2d.cc index 0996bca8a..0672490b7 100755 --- a/paddle2onnx/mapper/nn/pool2d.cc +++ b/paddle2onnx/mapper/nn/pool2d.cc @@ -116,12 +116,20 @@ void Pool2dMapper::AdaptivePool(const std::vector& input_info, onnx_pool_type = iter->second[0]; } - std::shared_ptr* node_ptr; - auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); - auto node = helper_->MakeNode(onnx_pool_type, {input}); - helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, - output_info[0].dtype); + std::shared_ptr node(nullptr); + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) + { + node = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name}); + } + else + { + auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, + P2ODataType::FP32); + node = helper_->MakeNode(onnx_pool_type, {input}); + helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, + output_info[0].dtype); + } + std::vector kernel_size = {kernel_h, kernel_w}; AddAttribute(node, "kernel_shape", kernel_size); std::vector strides = {stride_h, stride_w}; @@ -165,8 +173,12 @@ void Pool2dMapper::NoAdaptivePool(const std::vector& input_info, int64_t max_ksize = *std::max_element(std::begin(k_size_), std::end(k_size_)); int64_t max_pads = *std::max_element(std::begin(pads_), std::end(pads_)); - auto input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); + std::string input_x = input_info[0].name; + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) == kNoNeedCastTypesOpSet7.end()) + { + input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype, + P2ODataType::FP32); + } if (max_ksize <= max_pads) { std::vector onnx_paddings = {0, 0, pads_[0], pads_[1], 0, 0, pads_[2], pads_[3]}; @@ -199,9 +211,17 @@ void Pool2dMapper::NoAdaptivePool(const std::vector& input_info, auto iter 
= op_mapper_.find(pooling_type_); onnx_pool_type = iter->second[0]; } - auto node = helper_->MakeNode(onnx_pool_type, {input_x}); - helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, - output_info[0].dtype); + std::shared_ptr node(nullptr); + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) + { + node = helper_->MakeNode(onnx_pool_type, {input_x}, {output_info[0].name}); + } + else + { + node = helper_->MakeNode(onnx_pool_type, {input_x}); + helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, + output_info[0].dtype); + } AddAttribute(node, "kernel_shape", k_size_); AddAttribute(node, "strides", strides_); @@ -317,11 +337,18 @@ void Pool2dMapper::Opset7() { auto iter = op_mapper_.find(pooling_type_); onnx_pool_type = iter->second[1]; } - auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); - auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0); - helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32, - output_info[0].dtype); + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) + { + auto output = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name}); + } + else + { + auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, + P2ODataType::FP32); + auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0); + helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32, + output_info[0].dtype); + } } else if (adaptive_) { AdaptivePool(input_info, output_info); } else { diff --git a/paddle2onnx/mapper/nn/pool2d.h b/paddle2onnx/mapper/nn/pool2d.h index 02f92cc98..9fd9df489 100644 --- a/paddle2onnx/mapper/nn/pool2d.h +++ b/paddle2onnx/mapper/nn/pool2d.h @@ -63,6 +63,7 @@ class Pool2dMapper : public Mapper { const std::vector& output_info); void NoAdaptivePool(const std::vector& input_info, const std::vector& output_info); + const 
std::unordered_set kNoNeedCastTypesOpSet7{P2ODataType::FP16, P2ODataType::FP32}; bool ceil_mode_; bool global_pooling_; bool adaptive_; diff --git a/paddle2onnx/mapper/nn/pool3d.cc b/paddle2onnx/mapper/nn/pool3d.cc index fb6916fa1..2da09abd3 100644 --- a/paddle2onnx/mapper/nn/pool3d.cc +++ b/paddle2onnx/mapper/nn/pool3d.cc @@ -56,12 +56,21 @@ void Pool3dMapper::AdaptivePool(const std::vector& input_info, auto iter = op_mapper_.find(pooling_type_); onnx_pool_type = iter->second[0]; } - std::shared_ptr* node_ptr; - auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); - auto node = helper_->MakeNode(onnx_pool_type, {input}); - helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, - output_info[0].dtype); + + std::shared_ptr node; + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) + { + node = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name}); + } + else + { + auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, + P2ODataType::FP32); + node = helper_->MakeNode(onnx_pool_type, {input}); + helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, + output_info[0].dtype); + } + std::vector kernel_size = {kernel_d, kernel_h, kernel_w}; AddAttribute(node, "kernel_shape", kernel_size); std::vector strides = {stride_d, stride_h, stride_w}; @@ -109,8 +118,13 @@ void Pool3dMapper::NoAdaptivePool(const std::vector& input_info, int64_t max_ksize = *std::max_element(std::begin(k_size_), std::end(k_size_)); int64_t max_pads = *std::max_element(std::begin(pads_), std::end(pads_)); - auto input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); + auto input_x = input_info[0].name; + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) == kNoNeedCastTypesOpSet7.end()) + { + input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype, + P2ODataType::FP32); + } + if (max_ksize <= 
max_pads) { std::vector onnx_paddings = {0, 0, pads_[0], pads_[1], pads_[2], 0, 0, pads_[3], pads_[4], pads_[5]}; @@ -143,9 +157,17 @@ void Pool3dMapper::NoAdaptivePool(const std::vector& input_info, auto iter = op_mapper_.find(pooling_type_); onnx_pool_type = iter->second[0]; } - auto node = helper_->MakeNode(onnx_pool_type, {input_x}); - helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, - output_info[0].dtype); + std::shared_ptr node(nullptr); + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) + { + node = helper_->MakeNode(onnx_pool_type, {input_x}, {output_info[0].name}); + } + else + { + node = helper_->MakeNode(onnx_pool_type, {input_x}); + helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, + output_info[0].dtype); + } AddAttribute(node, "kernel_shape", k_size_); AddAttribute(node, "strides", strides_); @@ -247,11 +269,19 @@ void Pool3dMapper::Opset7() { auto iter = op_mapper_.find(pooling_type_); onnx_pool_type = iter->second[1]; } - auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, - P2ODataType::FP32); - auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0); - helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32, - output_info[0].dtype); + + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) + { + auto output = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name}); + } + else + { + auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype, + P2ODataType::FP32); + auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0); + helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32, + output_info[0].dtype); + } } else if (adaptive_) { AdaptivePool(input_info, output_info); } else { diff --git a/paddle2onnx/mapper/nn/pool3d.h b/paddle2onnx/mapper/nn/pool3d.h index 9be0c901d..5aeb3adf0 100644 --- a/paddle2onnx/mapper/nn/pool3d.h +++ 
b/paddle2onnx/mapper/nn/pool3d.h @@ -50,6 +50,7 @@ class Pool3dMapper : public Mapper { const std::vector& output_info); void NoAdaptivePool(const std::vector& input_info, const std::vector& output_info); + const std::unordered_set kNoNeedCastTypesOpSet7{P2ODataType::FP16, P2ODataType::FP32}; bool ceil_mode_; bool global_pooling_; bool adaptive_; diff --git a/paddle2onnx/mapper/tensor/fill_constant.cc b/paddle2onnx/mapper/tensor/fill_constant.cc index 22e4c488b..f7abba095 100644 --- a/paddle2onnx/mapper/tensor/fill_constant.cc +++ b/paddle2onnx/mapper/tensor/fill_constant.cc @@ -25,11 +25,12 @@ int32_t FillConstantMapper::GetMinOpset(bool verbose) { auto onnx_dtype = GetOnnxDtype(out_info[0].dtype); if (onnx_dtype != ONNX_NAMESPACE::TensorProto::INT32 && onnx_dtype != ONNX_NAMESPACE::TensorProto::INT64 && + onnx_dtype != ONNX_NAMESPACE::TensorProto::FLOAT16 && onnx_dtype != ONNX_NAMESPACE::TensorProto::FLOAT && onnx_dtype != ONNX_NAMESPACE::TensorProto::DOUBLE && onnx_dtype != ONNX_NAMESPACE::TensorProto::BOOL ) { - Error() << "Only support int32/int64/float32/float64/bool data type in " + Error() << "Only support int32/int64/float16/float32/float64/bool data type in " "fill_constant operator." 
<< std::endl; return -1; @@ -79,9 +80,8 @@ void FillConstantMapper::Opset7() { float value = GetFillValue(); if (HasInput("ValueTensor")) { auto value_info = GetInput("ValueTensor"); - auto value_tensor = helper_->AutoCast(value_info[0].name, value_info[0].dtype, out_info[0].dtype); auto out = helper_->Constant(shape, GetOnnxDtype(out_info[0].dtype), float(0.0)); - helper_->MakeNode("Add", {out, value_tensor}, {out_info[0].name}); + helper_->MakeNode("Add", {out, value_info[0].name}, {out_info[0].name}); } else { helper_->Constant(out_info[0].name, shape, GetOnnxDtype(out_info[0].dtype), value); } @@ -149,9 +149,7 @@ void FillConstantMapper::Opset9() { } if (value_is_tensor) { auto value_info = GetInput("ValueTensor"); - std::string cast_value = helper_->AutoCast( - value_info[0].name, value_info[0].dtype, out_info[0].dtype); - helper_->MakeNode("Add", {out, cast_value}, {out_info[0].name}); + helper_->MakeNode("Add", {out, value_info[0].name}, {out_info[0].name}); } else { helper_->MakeNode("Identity", {out}, {out_info[0].name}); } diff --git a/paddle2onnx/mapper/tensor/matmul.cc b/paddle2onnx/mapper/tensor/matmul.cc index fd8bb1264..48b66f62b 100644 --- a/paddle2onnx/mapper/tensor/matmul.cc +++ b/paddle2onnx/mapper/tensor/matmul.cc @@ -20,7 +20,7 @@ REGISTER_MAPPER(matmul, MatmulMapper) std::string MatmulMapper::GetTrans(std::vector& input_info) { std::string castd_name = input_info[0].name; - if (input_info[0].dtype == P2ODataType::FP64) { + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) == kNoNeedCastTypesOpSet7.end()) { castd_name = helper_->AutoCast(input_info[0].name, input_info[0].dtype, P2ODataType::FP32); } @@ -43,11 +43,30 @@ void MatmulMapper::Opset7() { if (transpose_Y_) { input_y = GetTrans(input_y_info); } - if (fabs(alpha_ - 1.0) < 1e-6) { + + if (kNoNeedCastTypesOpSet7.find(input_x_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) + { + if (fabs(alpha_ - 1.0) < 1e-6) + { + auto node = helper_->MakeNode("MatMul", {input_x, input_y}, 
{output_info[0].name}); + } + else + { + auto mutmul_node = helper_->MakeNode("MatMul", {input_x, input_y}); + std::string scale_node = + helper_->Constant({1}, GetOnnxDtype(input_x_info[0].dtype), alpha_); + auto mul_node = + helper_->MakeNode("Mul", {mutmul_node->output(0), scale_node}, {output_info[0].name}); + } + } + else if (fabs(alpha_ - 1.0) < 1e-6) + { auto node = helper_->MakeNode("MatMul", {input_x, input_y}); helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, input_y_info[0].dtype); - } else { + } + else + { auto mutmul_node = helper_->MakeNode("MatMul", {input_x, input_y}); std::string scale_node = helper_->Constant({1}, GetOnnxDtype(input_x_info[0].dtype), alpha_); diff --git a/paddle2onnx/mapper/tensor/matmul.h b/paddle2onnx/mapper/tensor/matmul.h index 16957701f..b29226e72 100644 --- a/paddle2onnx/mapper/tensor/matmul.h +++ b/paddle2onnx/mapper/tensor/matmul.h @@ -34,6 +34,7 @@ class MatmulMapper : public Mapper { private: std::string GetTrans(std::vector& input_info); + const std::unordered_set kNoNeedCastTypesOpSet7{P2ODataType::FP16, P2ODataType::FP32, P2ODataType::INT32, P2ODataType::INT64}; bool transpose_X_ = false; bool transpose_Y_ = false; float alpha_ = 1.0; diff --git a/paddle2onnx/mapper/tensor/matmul_v2.cc b/paddle2onnx/mapper/tensor/matmul_v2.cc index f78a26f42..db3af7f58 100644 --- a/paddle2onnx/mapper/tensor/matmul_v2.cc +++ b/paddle2onnx/mapper/tensor/matmul_v2.cc @@ -22,8 +22,13 @@ namespace paddle2onnx { REGISTER_MAPPER(matmul_v2, MatmulV2Mapper) std::string MatmulV2Mapper::GetTrans(std::vector& input_info) { - std::string castd_name = helper_->AutoCast( - input_info[0].name, input_info[0].dtype, P2ODataType::FP32); + std::string castd_name = input_info[0].name; + if (kNoNeedCastTypesOpSet7.find(input_info[0].dtype) == kNoNeedCastTypesOpSet7.end()) + { + castd_name = helper_->AutoCast( + input_info[0].name, input_info[0].dtype, P2ODataType::FP32); + } + std::vector perm = Arange(0, input_info[0].Rank()); 
std::swap(perm[perm.size() - 1], perm[perm.size() - 2]); auto transpose_node = helper_->MakeNode("Transpose", {castd_name}); @@ -43,9 +48,16 @@ void MatmulV2Mapper::Opset7() { if (trans_y_) { input_y = GetTrans(input_y_info); } - auto node = helper_->MakeNode("MatMul", {input_x, input_y}); - helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, - input_y_info[0].dtype); + if (kNoNeedCastTypesOpSet7.find(input_y_info[0].dtype) != kNoNeedCastTypesOpSet7.end()) + { + auto node = helper_->MakeNode("MatMul", {input_x, input_y}, {output_info[0].name}); + } + else + { + auto node = helper_->MakeNode("MatMul", {input_x, input_y}); + helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32, + input_y_info[0].dtype); + } } } // namespace paddle2onnx diff --git a/paddle2onnx/mapper/tensor/matmul_v2.h b/paddle2onnx/mapper/tensor/matmul_v2.h index bb3762a34..bc342e6d1 100644 --- a/paddle2onnx/mapper/tensor/matmul_v2.h +++ b/paddle2onnx/mapper/tensor/matmul_v2.h @@ -33,6 +33,7 @@ class MatmulV2Mapper : public Mapper { private: std::string GetTrans(std::vector& input_info); + const std::unordered_set kNoNeedCastTypesOpSet7{P2ODataType::FP16, P2ODataType::FP32, P2ODataType::INT32, P2ODataType::INT64}; bool trans_x_ = false; bool trans_y_ = false; }; diff --git a/paddle2onnx/parser/parser.cc b/paddle2onnx/parser/parser.cc index d2c743a2d..c2c03019c 100755 --- a/paddle2onnx/parser/parser.cc +++ b/paddle2onnx/parser/parser.cc @@ -815,7 +815,6 @@ void PaddleParser::GetGlobalBlockInputOutputInfo() { } int32_t PaddleDataTypeSize(int32_t paddle_dtype) { - Assert(paddle_dtype != FP16, "Float16 is not supported."); if (paddle_dtype == P2ODataType::BOOL) { return sizeof(bool); } else if (paddle_dtype == P2ODataType::INT8) { @@ -828,6 +827,8 @@ int32_t PaddleDataTypeSize(int32_t paddle_dtype) { return sizeof(int64_t); } else if (paddle_dtype == P2ODataType::FP32) { return sizeof(float); + } else if (paddle_dtype == P2ODataType::FP16) { + return 
sizeof(int16_t); } else if (paddle_dtype == P2ODataType::FP64) { return sizeof(double); } else if (paddle_dtype == P2ODataType::UINT8) { diff --git a/tests/run.sh b/tests/run.sh index fe36c05fc..8adec38a9 100755 --- a/tests/run.sh +++ b/tests/run.sh @@ -60,7 +60,8 @@ ignore="test_auto_scan_multiclass_nms.py test_unsqueeze.py \ test_quantize_model.py \ test_quantize_model_minist.py \ - test_quantize_model_speedup.py" + test_quantize_model_speedup.py \ + test_resnet_fp16.py" bug=0 # Install Python Packet @@ -69,6 +70,7 @@ $PY_CMD -m pip install pytest $PY_CMD -m pip install onnx onnxruntime tqdm filelock $PY_CMD -m pip install paddlepaddle==2.6.0 $PY_CMD -m pip install six hypothesis +$PY_CMD -m pip install numpy==1.26.4 export ENABLE_DEV=ON diff --git a/tests/test_resnet_fp16.py b/tests/test_resnet_fp16.py new file mode 100644 index 000000000..82988b957 --- /dev/null +++ b/tests/test_resnet_fp16.py @@ -0,0 +1,77 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import numpy as np +import onnxruntime + +import paddle +import paddle2onnx +from paddle.inference import PrecisionType, PlaceType, convert_to_mixed_precision + +def test_resnet_fp16_convert(): + # download resnet model + if not os.path.exists("ResNet50_infer"): + os.system("wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/inference/ResNet50_infer.tar && tar -xf ResNet50_infer.tar && rm -rf ResNet50_infer.tar") + + # generate fp16 model + path = "ResNet50_infer" + src_model = os.path.join(path,"inference.pdmodel") + src_params = os.path.join(path,"inference.pdiparams") + dst_model = os.path.join(path,"inference_fp16.pdmodel") + dst_params = os.path.join(path,"inference_fp16.pdiparams") + + convert_to_mixed_precision( + src_model, # fp32 model path + src_params, # fp32 params path + dst_model, # mixed-precision model path + dst_params, # mixed-precision params path + PrecisionType.Half, + PlaceType.GPU, + False + ) + + # paddle.set_device("gpu") + paddle.enable_static() + path_fp16 = os.path.join(path, "inference_fp16") + exe = paddle.static.Executor(paddle.CUDAPlace(0)) + [inference_program, feed_target_names, fetch_targets] = paddle.static.load_inference_model(path_fp16, exe) + + # infer paddle fp16 + np.random.seed(10) + tensor_img = np.array(np.random.random((1, 3, 224, 224)), dtype=np.float16) + results = exe.run(inference_program, + feed={feed_target_names[0]: tensor_img}, + fetch_list=fetch_targets) + + # convert to onnx + input_spec = [paddle.static.InputSpec(shape=[-1, 3, 224, 224], dtype='float16', name='inputs')] + model_file = path_fp16 + ".pdmodel" + params_file = path_fp16 + ".pdiparams" + paddle2onnx.export(model_file, params_file, "./resnet_fp16.onnx", export_fp16_model=True) # export ONNX model + + # validate precision + onnx_file_name = "./resnet_fp16.onnx" + ort_session = onnxruntime.InferenceSession(onnx_file_name) + + ort_inputs = {ort_session.get_inputs()[0].name: tensor_img} + ort_outputs = ort_session.run(None, ort_inputs) + + #
assert + np.testing.assert_allclose( + results[0], ort_outputs[0], rtol=2e-02, atol=2e-05 + ) + +if __name__ == "__main__": + test_resnet_fp16_convert()