From c928bc7bea6112100927180a53cef4b8b0015999 Mon Sep 17 00:00:00 2001
From: lizexu123 <39205361+lizexu123@users.noreply.github.com>
Date: Mon, 11 Mar 2024 10:24:44 +0800
Subject: [PATCH] add inference api: exp_specify_tensorrt_subgraph_precision
 (#62402)

add inference api: exp_specify_tensorrt_subgraph_precision (#62402)

---
 paddle/fluid/inference/analysis/argument.h    |   9 ++
 .../inference/analysis/ir_pass_manager.cc     |   9 ++
 .../ir_passes/tensorrt_subgraph_pass.cc       |  40 ++++-
 paddle/fluid/inference/api/analysis_config.cc |  24 +++
 .../fluid/inference/api/analysis_predictor.cc |   3 +
 .../inference/api/paddle_analysis_config.h    |  22 +++
 paddle/fluid/pybind/inference_api.cc          |   2 +
 .../test_trt_ops_fp16_mix_precision.py        | 144 ++++++++++++++++++
 8 files changed, 252 insertions(+), 1 deletion(-)
 create mode 100644 test/ir/inference/test_trt_ops_fp16_mix_precision.py
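For review convenience, a minimal sketch of the intended Python usage of the new API (not part of the patch itself; the model files and parameter names below are placeholders, and the full working flow is in the new unit test at the end of this patch):

    from paddle.inference import Config, PrecisionType, create_predictor

    # Placeholder model files; any saved static-graph inference model works.
    config = Config("infer_model.pdmodel", "infer_model.pdiparams")
    config.enable_use_gpu(256, 0, PrecisionType.Float32)
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False,
    )
    # Subgraphs that own parameter conv2d_1.w_0 run in fp16, subgraphs that
    # own conv2d_2.w_0 run in bfp16 (TensorRT >= 9.0); nothing is forced to
    # int8 here.
    config.exp_specify_tensorrt_subgraph_precision(
        ["conv2d_1.w_0"], [], ["conv2d_2.w_0"]
    )
    predictor = create_predictor(config)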
#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" - #include #include #include @@ -476,9 +475,47 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( } auto precision_mode = static_cast(Get("trt_precision_mode")); + auto trt_params_run_fp16 = + Get>("trt_parameter_run_fp16"); + auto trt_params_run_int8 = + Get>("trt_parameter_run_int8"); + auto trt_params_run_bfp16 = + Get>("trt_parameter_run_bfp16"); + + for (const auto ¶ : parameters) { + if (std::find(trt_params_run_fp16.begin(), + trt_params_run_fp16.end(), + para) != trt_params_run_fp16.end()) { + precision_mode = phi::DataType::FLOAT16; + break; + } + } + bool enable_fp16 = false; if (precision_mode == phi::DataType::FLOAT16) enable_fp16 = true; auto enable_int8 = Get("enable_int8"); + + for (const auto ¶ : parameters) { + if (std::find(trt_params_run_int8.begin(), + trt_params_run_int8.end(), + para) != trt_params_run_int8.end()) { + enable_int8 = true; + precision_mode = phi::DataType::INT8; + break; + } + } + + for (const auto ¶ : parameters) { + if (std::find(trt_params_run_bfp16.begin(), + trt_params_run_bfp16.end(), + para) != trt_params_run_bfp16.end()) { + precision_mode = phi::DataType::BFLOAT16; + break; + } + } + bool enable_bfp16 = false; + if (precision_mode == phi::DataType::BFLOAT16) enable_bfp16 = true; + auto use_calib_mode = Get("use_calib_mode"); auto &subgraph_nodes = *framework::ir::Agent(node).subgraph(); auto min_input_shape = @@ -724,6 +761,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( op_desc->SetAttr("calibration_data", calibration_data); op_desc->SetAttr("enable_int8", enable_int8); op_desc->SetAttr("enable_fp16", enable_fp16); + op_desc->SetAttr("enbale_bfp16", enable_bfp16); op_desc->SetAttr("use_calib_mode", use_calib_mode); op_desc->SetAttr("engine_key", engine_key); op_desc->SetAttr("calibration_engine_key", calibration_engine_key); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 5ab33c65208a3c..d97e41f0b1e131 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -462,6 +462,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(tensorrt_min_subgraph_size_); CP_MEMBER(tensorrt_precision_mode_); CP_MEMBER(trt_mark_output_); + CP_MEMBER(trt_parameters_run_fp16_); + CP_MEMBER(trt_parameters_run_int8_); + CP_MEMBER(trt_parameters_run_bfp16_); CP_MEMBER(trt_forbid_dynamic_op_) CP_MEMBER(trt_output_tensor_names_); CP_MEMBER(trt_disabled_ops_); @@ -880,6 +883,21 @@ void AnalysisConfig::Exp_DisableTensorRtSubgraph( var_name_not_trt.end()); } +void AnalysisConfig::Exp_SpecifyTensorRTSubgraphPrecision( + const std::vector &trt_parameters_run_fp16, + const std::vector &trt_parameters_run_int8, + const std::vector &trt_parameters_run_bfp16) { + trt_parameters_run_fp16_.insert(trt_parameters_run_fp16_.end(), + trt_parameters_run_fp16.begin(), + trt_parameters_run_fp16.end()); + trt_parameters_run_int8_.insert(trt_parameters_run_int8_.end(), + trt_parameters_run_int8.begin(), + trt_parameters_run_int8.end()); + trt_parameters_run_bfp16_.insert(trt_parameters_run_bfp16_.end(), + trt_parameters_run_bfp16.begin(), + trt_parameters_run_bfp16.end()); +} + void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; } void AnalysisConfig::SetTensorRtOptimizationLevel(int level) { @@ -1135,6 +1153,12 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << tensorrt_max_batchsize_; ss << tensorrt_min_subgraph_size_; ss << 
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 5ab33c65208a3c..d97e41f0b1e131 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -462,6 +462,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(tensorrt_min_subgraph_size_);
   CP_MEMBER(tensorrt_precision_mode_);
   CP_MEMBER(trt_mark_output_);
+  CP_MEMBER(trt_parameters_run_fp16_);
+  CP_MEMBER(trt_parameters_run_int8_);
+  CP_MEMBER(trt_parameters_run_bfp16_);
   CP_MEMBER(trt_forbid_dynamic_op_)
   CP_MEMBER(trt_output_tensor_names_);
   CP_MEMBER(trt_disabled_ops_);
@@ -880,6 +883,21 @@ void AnalysisConfig::Exp_DisableTensorRtSubgraph(
                           var_name_not_trt.end());
 }
 
+void AnalysisConfig::Exp_SpecifyTensorRTSubgraphPrecision(
+    const std::vector<std::string> &trt_parameters_run_fp16,
+    const std::vector<std::string> &trt_parameters_run_int8,
+    const std::vector<std::string> &trt_parameters_run_bfp16) {
+  trt_parameters_run_fp16_.insert(trt_parameters_run_fp16_.end(),
+                                  trt_parameters_run_fp16.begin(),
+                                  trt_parameters_run_fp16.end());
+  trt_parameters_run_int8_.insert(trt_parameters_run_int8_.end(),
+                                  trt_parameters_run_int8.begin(),
+                                  trt_parameters_run_int8.end());
+  trt_parameters_run_bfp16_.insert(trt_parameters_run_bfp16_.end(),
+                                   trt_parameters_run_bfp16.begin(),
+                                   trt_parameters_run_bfp16.end());
+}
+
 void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; }
 
 void AnalysisConfig::SetTensorRtOptimizationLevel(int level) {
@@ -1135,6 +1153,12 @@ std::string AnalysisConfig::SerializeInfoCache() {
   ss << tensorrt_max_batchsize_;
   ss << tensorrt_min_subgraph_size_;
   ss << trt_mark_output_;
+  for (auto &name : trt_parameters_run_fp16_) ss << name.c_str();
+  ss << ";";
+  for (auto &name : trt_parameters_run_int8_) ss << name.c_str();
+  ss << ";";
+  for (auto &name : trt_parameters_run_bfp16_) ss << name.c_str();
+  ss << ";";
   ss << trt_forbid_dynamic_op_;
 
   ss << use_dlnne_;
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 961c0e350be388..8be9fa420318c7 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1759,6 +1759,9 @@ void AnalysisPredictor::PrepareArgument() {
   argument_->SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
   argument_->SetTRTMarkOutput(config_.trt_mark_output_);
   argument_->SetTRTOutputTensorNames(config_.trt_output_tensor_names_);
+  argument_->SetTRTParameterRunFp16(config_.trt_parameters_run_fp16_);
+  argument_->SetTRTParameterRunInt8(config_.trt_parameters_run_int8_);
+  argument_->SetTRTParameterRunBfp16(config_.trt_parameters_run_bfp16_);
   argument_->SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
   argument_->SetTRTExcludeVarNames(config_.trt_exclude_var_names_);
   argument_->SetTRTForbidDynamicOp(config_.trt_forbid_dynamic_op_);
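Also worth noting for reviewers: Exp_SpecifyTensorRTSubgraphPrecision appends to the three member vectors via insert() rather than replacing their contents, so repeated calls accumulate names. A short sketch, reusing the config object from the usage example above:

    config.exp_specify_tensorrt_subgraph_precision(["conv2d_1.w_0"], [], [])
    config.exp_specify_tensorrt_subgraph_precision(["conv2d_2.w_0"], [], [])
    # The fp16 list now holds both conv2d_1.w_0 and conv2d_2.w_0; this patch
    # adds no API for clearing the lists again.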
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 2c5b254ea1c142..251f390b9afdac 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -810,9 +810,27 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   void Exp_DisableTensorRtOPs(const std::vector<std::string>& ops);
 
+  ///
+  /// \brief Prevent the specified TensorRT subgraphs from running in Paddle-TRT.
+  /// NOTE: just experimental, not an official stable API, easy to be broken.
+  ///
   void Exp_DisableTensorRtSubgraph(
       const std::vector<std::string>& var_name_not_trt);
 
+  ///
+  /// \brief Specify the precision of TensorRT subgraphs: fp16, int8, or
+  /// bfp16 (bfp16 requires TensorRT version >= 9.0).
+  /// NOTE: just experimental, not an official stable API, easy to be broken.
+  ///
+  void Exp_SpecifyTensorRTSubgraphPrecision(
+      const std::vector<std::string>& trt_parameters_fp16,
+      const std::vector<std::string>& trt_parameters_int8,
+      const std::vector<std::string>& trt_parameters_bfp16);
+
+  ///
+  /// \brief Prevent dynamic-shape OPs from running in Paddle-TRT.
+  /// NOTE: just experimental, not an official stable API, easy to be broken.
+  ///
   void Exp_DisableTensorRTDynamicShapeOPs(bool trt_forbid_dynamic_op);
 
   ///
@@ -1289,6 +1307,10 @@ struct PD_INFER_DECL AnalysisConfig {
   std::vector<std::string> trt_output_tensor_names_{};
   std::vector<std::string> trt_exclude_var_names_{};
+  std::vector<std::string> trt_parameters_run_fp16_{};
+  std::vector<std::string> trt_parameters_run_int8_{};
+  std::vector<std::string> trt_parameters_run_bfp16_{};
+
   std::string tensorrt_transformer_posid_{""};
   std::string tensorrt_transformer_maskid_{""};
   bool trt_use_dla_{false};
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 69cb7303ea4e85..e5c3ffd15bb72e 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -937,6 +937,8 @@ void BindAnalysisConfig(py::module *m) {
       .def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs)
       .def("exp_disable_tensorrt_subgraph",
           &AnalysisConfig::Exp_DisableTensorRtSubgraph)
+      .def("exp_specify_tensorrt_subgraph_precision",
+           &AnalysisConfig::Exp_SpecifyTensorRTSubgraphPrecision)
       .def("exp_disable_tensorrt_dynamic_shape_ops",
           &AnalysisConfig::Exp_DisableTensorRTDynamicShapeOPs)
       .def("enable_tensorrt_dla",
diff --git a/test/ir/inference/test_trt_ops_fp16_mix_precision.py b/test/ir/inference/test_trt_ops_fp16_mix_precision.py
new file mode 100644
index 00000000000000..f950f3bca8bf40
--- /dev/null
+++ b/test/ir/inference/test_trt_ops_fp16_mix_precision.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle import nn, static
+from paddle.inference import Config, PrecisionType, create_predictor
+
+paddle.enable_static()
+
+
+class SimpleNet(nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = nn.Conv2D(
+            in_channels=4,
+            out_channels=4,
+            kernel_size=3,
+            stride=2,
+            padding=0,
+        )
+        self.relu1 = nn.ReLU()
+        self.conv2 = nn.Conv2D(
+            in_channels=4,
+            out_channels=2,
+            kernel_size=3,
+            stride=2,
+            padding=0,
+        )
+        self.relu2 = nn.ReLU()
+        self.conv3 = nn.Conv2D(
+            in_channels=2,
+            out_channels=1,
+            kernel_size=3,
+            stride=2,
+            padding=0,
+        )
+        self.relu3 = nn.ReLU()
+        self.flatten = nn.Flatten()
+        self.fc = nn.Linear(729, 10)
+        self.softmax = nn.Softmax()
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.relu1(x)
+        x = self.conv2(x)
+        x = self.relu2(x)
+        x = self.conv3(x)
+        x = self.relu3(x)
+        x = self.flatten(x)
+        x = self.fc(x)
+        x = self.softmax(x)
+        return x
+
+
+class TestTRTSubgraphMixPrecision(unittest.TestCase):
+    def setUp(self):
+        self.place = paddle.CUDAPlace(0)
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.path = os.path.join(self.temp_dir.name, 'mix_precision', '')
+        self.model_prefix = self.path + 'infer_model'
+
+    def tearDown(self):
+        shutil.rmtree(self.path)
+
+    def build_model(self):
+        image = static.data(
+            name='img', shape=[None, 4, 224, 224], dtype='float32'
+        )
+        predict = SimpleNet()(image)
+        exe = paddle.static.Executor(self.place)
+        exe.run(paddle.static.default_startup_program())
+        paddle.static.save_inference_model(
+            self.model_prefix, [image], [predict], exe
+        )
+
+    def init_predictor(self):
+        config = Config(
+            self.model_prefix + '.pdmodel', self.model_prefix + '.pdiparams'
+        )
+        config.enable_use_gpu(256, 0, PrecisionType.Float32)
+        config.exp_disable_tensorrt_ops(["relu_1.tmp_0"])
+        config.enable_tensorrt_engine(
+            workspace_size=1 << 30,
+            max_batch_size=1,
+            min_subgraph_size=3,
+            precision_mode=PrecisionType.Float32,
+            use_static=False,
+            use_calib_mode=False,
+        )
+
+        # Subgraphs owning conv2d_1.w_0 run in fp16; subgraphs owning
+        # conv2d_2.w_0 run in bfp16; no parameter is forced to int8.
+        config.exp_specify_tensorrt_subgraph_precision(
+            ["conv2d_1.w_0"], [""], ["conv2d_2.w_0"]
+        )
+
+        config.enable_memory_optim()
+        # config.disable_glog_info()
+        config.set_tensorrt_optimization_level(0)
+        self.assertEqual(config.tensorrt_optimization_level(), 0)
+        predictor = create_predictor(config)
+        return predictor
+
+    def infer(self, predictor, img):
+        input_names = predictor.get_input_names()
+        for i, name in enumerate(input_names):
+            input_tensor = predictor.get_input_handle(name)
+            input_tensor.reshape(img[i].shape)
+            input_tensor.copy_from_cpu(img[i].copy())
+
+        predictor.run()
+        results = []
+        output_names = predictor.get_output_names()
+        for i, name in enumerate(output_names):
+            output_tensor = predictor.get_output_handle(name)
+            output_data = output_tensor.copy_to_cpu()
+            results.append(output_data)
+        return results
+
+    def test_mix_precision(self):
+        self.build_model()
+        predictor = self.init_predictor()
+        img = np.ones((1, 4, 224, 224), dtype=np.float32)
+        results = self.infer(predictor, img=[img])
+
+
+if __name__ == '__main__':
+    unittest.main()