From 8b3a9a2eb38594b9f2626f8e3141898b0da1585d Mon Sep 17 00:00:00 2001 From: minghaipeng Date: Fri, 14 Jul 2023 03:00:02 +0000 Subject: [PATCH 1/2] [Paddle-TRT] add assign op --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 3 +- .../inference/tensorrt/convert/assign_op.cc | 67 +++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 13 +- test/ir/inference/test_trt_convert_assign.py | 171 ++++++++++++++++++ 5 files changed, 245 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/assign_op.cc create mode 100644 test/ir/inference/test_trt_convert_assign.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index ad6ab5a7ef0eb..a85524676feca 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2915,6 +2915,7 @@ USE_TRT_CONVERTER(take_along_axis) USE_TRT_CONVERTER(skip_groupnorm_act) USE_TRT_CONVERTER(preln_groupnorm_act) USE_TRT_CONVERTER(cumsum) +USE_TRT_CONVERTER(assign) #if IS_TRT_VERSION_GE(8522) USE_TRT_CONVERTER(flash_multihead_matmul) USE_TRT_CONVERTER(cross_multihead_matmul) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 1064362df3878..6483e0619a963 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -106,7 +106,8 @@ list( expand_v2_op.cc cumsum_op.cc temporal_shift_op.cc - einsum_op.cc) + einsum_op.cc + assign_op.cc) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7) list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc diff --git a/paddle/fluid/inference/tensorrt/convert/assign_op.cc b/paddle/fluid/inference/tensorrt/convert/assign_op.cc new file mode 100644 index 0000000000000..2f3e30115fb9a --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/assign_op.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class AssignOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(3) << "convert an assign op to tensorrt";
+    framework::OpDesc op_desc(op, nullptr);
+
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
+
+    auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input);
+
+    switch (dtype) {
+      case 0:  // BOOL = 0
+        layer->setOutputType(0, nvinfer1::DataType::kBOOL);
+        layer->getOutput(0)->setType(nvinfer1::DataType::kBOOL);
+        break;
+      case 2:  // INT32 = 2
+      case 3:  // INT64 = 3; tensorrt subgraphs have no int64, fall back to int32
+        layer->setOutputType(0, nvinfer1::DataType::kINT32);
+        layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
+        break;
+      case 4:  // FP16 = 4
+        layer->setOutputType(0, nvinfer1::DataType::kHALF);
+        layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
+        break;
+      case 5:  // FP32 = 5
+        layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
+        layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
+        break;
+      default:
+        LOG(ERROR) << "Unsupported fluid data type (" << dtype
+                   << "); it cannot be mapped to an nvinfer DataType";
+        break;
+    }
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "assign", {output_name}, test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(assign, AssignOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index ee20184c86194..c682e2f2e8982 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -695,13 +695,6 @@ struct SimpleOpTypeSetTeller : public Teller {
       auto* input_var_desc = block->FindVar(input_var_name);
       auto* index_var_desc = block->FindVar(index_var_name);
 
-      // The index input must be int32 datatype.
-      if (index_var_desc->GetDataType() !=
-          paddle::framework::proto::VarType_Type::VarType_Type_INT32) {
-        VLOG(3) << "take_along_axis op Index input data type must be int32";
-        return false;
-      }
-
       const auto input_shape = input_var_desc->GetShape();
       const auto index_shape = index_var_desc->GetShape();
       if (input_shape.size() != index_shape.size()) {
@@ -2903,7 +2896,8 @@ struct SimpleOpTypeSetTeller : public Teller {
       "preln_groupnorm_act",
       "temporal_shift",
       "grid_sampler",
-      "cumsum"};
+      "cumsum",
+      "assign"};
 
   std::unordered_set<std::string> teller_set{
       "matrix_multiply",
@@ -3065,7 +3059,8 @@ struct SimpleOpTypeSetTeller : public Teller {
       "preln_groupnorm_act",
       "temporal_shift",
       "grid_sampler",
-      "cumsum"};
+      "cumsum",
+      "assign"};
 };
 
 struct GenericPluginTeller : public Teller {
diff --git a/test/ir/inference/test_trt_convert_assign.py b/test/ir/inference/test_trt_convert_assign.py
new file mode 100644
index 0000000000000..386077c0f3f65
--- /dev/null
+++ b/test/ir/inference/test_trt_convert_assign.py
@@ -0,0 +1,171 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import TrtLayerAutoScanTest + +import paddle.inference as paddle_infer +from paddle.framework import convert_np_dtype_to_dtype_ + + +class TrtConvertAssignTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + if attrs[0]['dtype'] not in [0, 1, 2, 3, 4, 5]: + return False + compile_version = paddle_infer.get_trt_compile_version() + runtime_version = paddle_infer.get_trt_runtime_version() + if ( + compile_version[0] * 1000 + + compile_version[1] * 100 + + compile_version[2] * 10 + < 8400 + ): + return False + if ( + runtime_version[0] * 1000 + + runtime_version[1] * 100 + + runtime_version[2] * 10 + < 8400 + ): + return False + return True + + def sample_program_configs(self): + def generate_input(type): + if self.dims == 0: + return np.ones([]).astype(type) + elif self.dims == 1: + return np.ones([1]).astype(type) + else: + return np.ones([1, 3, 64, 64]).astype(type) + + for dims in [0, 1, 4]: + self.dims = dims + for dtype in [ + np.bool_, + np.int32, + np.float32, + np.float64, + np.int64, + ]: + self.has_bool_dtype = dtype == np.bool_ + ops_config = [ + { + "op_type": "assign", + "op_inputs": {"X": ["input_data"]}, + "op_outputs": {"Out": ["assign_output_data0"]}, + "op_attrs": { + "dtype": convert_np_dtype_to_dtype_(dtype) + }, + }, + { + "op_type": "assign", + "op_inputs": {"X": ["assign_output_data0"]}, + "op_outputs": {"Out": ["assign_output_data1"]}, + "op_attrs": { + "dtype": convert_np_dtype_to_dtype_(dtype) + }, + }, + ] + + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input, dtype) + ) + }, + outputs=["assign_output_data1"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + if self.dims == 0: + self.dynamic_shape.min_input_shape = {"input_data": []} + self.dynamic_shape.max_input_shape = {"input_data": []} + self.dynamic_shape.opt_input_shape = {"input_data": []} + elif self.dims == 1: + self.dynamic_shape.min_input_shape = {"input_data": [1]} + self.dynamic_shape.max_input_shape = {"input_data": [1]} + self.dynamic_shape.opt_input_shape = {"input_data": [1]} + else: + self.dynamic_shape.min_input_shape = { + "input_data": [1, 3, 64, 64] + } + self.dynamic_shape.max_input_shape = { + "input_data": [1, 3, 64, 64] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 3, 64, 64] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + if not dynamic_shape and ( + self.has_bool_dtype or self.dims == 1 or self.dims == 0 + ): + return 0, 4 + return 1, 2 + 
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-2
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-2
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()

From ef4cc30fda83bf0365d9ddb31f3e411c4b940e45 Mon Sep 17 00:00:00 2001
From: minghaipeng
Date: Fri, 14 Jul 2023 07:08:41 +0000
Subject: [PATCH 2/2] [Paddle-TRT] add assign op

---
 .../inference/tensorrt/convert/assign_op.cc   | 28 -------------------
 test/ir/inference/test_trt_convert_assign.py  | 15 ++--------
 2 files changed, 2 insertions(+), 41 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/assign_op.cc b/paddle/fluid/inference/tensorrt/convert/assign_op.cc
index 2f3e30115fb9a..5f14d19ee132b 100644
--- a/paddle/fluid/inference/tensorrt/convert/assign_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/assign_op.cc
@@ -25,36 +25,8 @@ class AssignOpConverter : public OpConverter {
                   bool test_mode) override {
     VLOG(3) << "convert an assign op to tensorrt";
     framework::OpDesc op_desc(op, nullptr);
-
     auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
-    int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype"));
-
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input);
-
-    switch (dtype) {
-      case 0:  // BOOL = 0
-        layer->setOutputType(0, nvinfer1::DataType::kBOOL);
-        layer->getOutput(0)->setType(nvinfer1::DataType::kBOOL);
-        break;
-      case 2:  // INT32 = 2
-      case 3:  // INT64 = 3; tensorrt subgraphs have no int64, fall back to int32
-        layer->setOutputType(0, nvinfer1::DataType::kINT32);
-        layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
-        break;
-      case 4:  // FP16 = 4
-        layer->setOutputType(0, nvinfer1::DataType::kHALF);
-        layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
-        break;
-      case 5:  // FP32 = 5
-        layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
-        layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
-        break;
-      default:
-        LOG(ERROR) << "Unsupported fluid data type (" << dtype
-                   << "); it cannot be mapped to an nvinfer DataType";
-        break;
-    }
-
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "assign", {output_name}, test_mode);
   }
diff --git a/test/ir/inference/test_trt_convert_assign.py b/test/ir/inference/test_trt_convert_assign.py
index 386077c0f3f65..64dbf16064e94 100644
--- a/test/ir/inference/test_trt_convert_assign.py
+++ b/test/ir/inference/test_trt_convert_assign.py
@@ -21,16 +21,10 @@ from trt_layer_auto_scan_test import TrtLayerAutoScanTest
 
 import paddle.inference as paddle_infer
-from paddle.framework import convert_np_dtype_to_dtype_
 
 
 class TrtConvertAssignTest(TrtLayerAutoScanTest):
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        attrs = [
-            program_config.ops[i].attrs for i in range(len(program_config.ops))
-        ]
-        if attrs[0]['dtype'] not in [0, 1, 2, 3, 4, 5]:
-            return False
         compile_version = paddle_infer.get_trt_compile_version()
         runtime_version = paddle_infer.get_trt_runtime_version()
         if (
@@ -64,7 +58,6 @@ def generate_input(type):
             np.bool_,
             np.int32,
             np.float32,
-            np.float64,
             np.int64,
         ]:
             self.has_bool_dtype = dtype == np.bool_
@@ -73,17 +66,13 @@
                     "op_type": "assign",
                     "op_inputs": {"X": ["input_data"]},
                     "op_outputs": {"Out": ["assign_output_data0"]},
-                    "op_attrs": {
-                        "dtype": convert_np_dtype_to_dtype_(dtype)
-                    },
+                    "op_attrs": {},
                 },
                 {
                     "op_type": "assign",
                     "op_inputs": {"X": ["assign_output_data0"]},
                     "op_outputs": {"Out": ["assign_output_data1"]},
-                    "op_attrs": {
-                        "dtype": convert_np_dtype_to_dtype_(dtype)
-                    },
+                    "op_attrs": {},
                 },
             ]