From 9f8b751a2fd67b1e1b6d5fd3287334b346027c1b Mon Sep 17 00:00:00 2001 From: Yuanle Liu Date: Thu, 11 May 2023 20:36:07 +0800 Subject: [PATCH] [inference Zero-Dim]prelu trt converter support zero dim tensor (#53634) * prelu op trt converter support zero dim --- .../framework/ir/trt_support_nhwc_pass.cc | 2 + .../inference/tensorrt/convert/prelu_op.cc | 1 - paddle/fluid/inference/tensorrt/op_teller.cc | 28 +- paddle/phi/infermeta/binary.cc | 2 +- test/ir/inference/test_trt_convert_prelu.py | 288 ++++++++---------- 5 files changed, 148 insertions(+), 173 deletions(-) diff --git a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc index 86c7b7c9dbbae..0b12559e31deb 100644 --- a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc +++ b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc @@ -356,6 +356,8 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { } }; InsertTransposeOp(); + + AddStatis(transposed_ops.size()); } } // namespace ir diff --git a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc index d655a4ce04fd6..80a2ac46f44dc 100644 --- a/paddle/fluid/inference/tensorrt/convert/prelu_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/prelu_op.cc @@ -87,7 +87,6 @@ class PReluOpConverter : public OpConverter { if (hw_tensor != nullptr) { shape_tensor = Concat( std::vector{n_tensor, c_tensor, hw_tensor}); - } else { shape_tensor = Concat(std::vector{n_tensor, c_tensor}); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index f3b4778baf879..eb21cba81879c 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1845,28 +1845,28 @@ struct SimpleOpTypeSetTeller : public Teller { "the pass."; return false; } - auto* var_desc = block->FindVar(desc.Input("Alpha")[0]); - if (!var_desc) { + auto* alpha_var = block->FindVar(desc.Input("Alpha")[0]); + if (!alpha_var) { VLOG(3) << "Variable Alpha of prelu TRT converter not found."; return false; } - - auto x_var_name = desc.Input("X")[0]; - auto* x_var_desc = block->FindVar(x_var_name); - const auto x_shape = x_var_desc->GetShape(); - if (!with_dynamic_shape && x_shape.size() == 1) { - VLOG(3) << "prelu op does not support input's dim is 1 in tensorrt " - "with static shape."; + auto alpha_shape = alpha_var->GetShape(); + if (!with_dynamic_shape && alpha_shape.size() == 0) { + VLOG(3) << op_type + << " op does not support alpha's dim is 0 in tensorrt " + "static shape mode."; return false; } -#if IS_TRT_VERSION_LT(7000) - if (!with_dynamic_shape) { - // TODO(inference): fix trt6 static plugin error. - VLOG(3) << "prelu static plugin in trt6 has bug."; + auto x_var_name = desc.Input("X")[0]; + auto* x_var = block->FindVar(x_var_name); + const auto x_shape = x_var->GetShape(); + if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.size() == 0)) { + VLOG(3) << op_type + << " op does not support input's dim is 1 or 0 in tensorrt " + "with static shape."; return false; } -#endif } if (op_type == "mish") { diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index efa271e8e1582..a1273d5bc2a8d 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -2328,7 +2328,7 @@ void PReluInferMeta(const MetaTensor& x, 1, phi::errors::InvalidArgument( "For mode 'element', rank of input X must be " - "equal or larger than 2. But recevied X's " + "equal or larger than 1. 
But recevied X's " "rank: %d", x_rank)); PADDLE_ENFORCE_EQ( diff --git a/test/ir/inference/test_trt_convert_prelu.py b/test/ir/inference/test_trt_convert_prelu.py index 14046d45a3b2d..ae53f10ad5761 100644 --- a/test/ir/inference/test_trt_convert_prelu.py +++ b/test/ir/inference/test_trt_convert_prelu.py @@ -18,7 +18,7 @@ import numpy as np from program_config import ProgramConfig, TensorConfig -from trt_layer_auto_scan_test import SkipReasons, TrtLayerAutoScanTest +from trt_layer_auto_scan_test import TrtLayerAutoScanTest import paddle.inference as paddle_infer @@ -28,170 +28,165 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool: return True def sample_program_configs(self): - def generate_input(batch, dim1, dim2, dim3): - shape = [batch] - if dim1 != 0: - shape.append(dim1) - if dim2 != 0: - shape.append(dim2) - if dim3 != 0: - shape.append(dim3) - return np.random.random(shape).astype(np.float32) - - def generate_alpha(attrs: List[Dict[str, Any]], dim1, dim2, dim3): + def generate_input(attrs: List[Dict[str, Any]], batch): + if self.dims == 0: + return np.random.random([]).astype(np.float32) + elif self.dims == 1: + return np.random.random([16]).astype(np.float32) + elif self.dims == 2: + return np.random.random([1, 3]).astype(np.float32) + elif self.dims == 3: + if attrs[0]["data_format"] == "NCHW": + return np.random.random([batch, 3, 16]).astype(np.float32) + elif attrs[0]["data_format"] == "NHWC": + return np.random.random([batch, 16, 3]).astype(np.float32) + else: + raise AssertionError() + else: + if attrs[0]["data_format"] == "NCHW": + return np.random.random([batch, 3, 16, 32]).astype( + np.float32 + ) + else: + return np.random.random([batch, 16, 32, 3]).astype( + np.float32 + ) + + def generate_alpha(attrs: List[Dict[str, Any]]): + if self.dims == 0: + return np.random.random([]).astype(np.float32) if attrs[0]["mode"] == "all": - return np.random.random(size=(1)).astype(np.float32) - elif ( - attrs[0]["mode"] == "channel" - and attrs[0]["data_format"] == "NCHW" - ): - shape = [1] - if dim1 != 0: - shape.append(dim1) - if dim2 != 0: - shape.append(dim2) - if dim3 != 0: - shape.append(dim3) - return np.random.random(size=shape[1]).astype(np.float32) - elif ( - attrs[0]["mode"] == "channel" - and attrs[0]["data_format"] == "NHWC" - ): - shape = [1] - if dim1 != 0: - shape.append(dim1) - if dim2 != 0: - shape.append(dim2) - if dim3 != 0: - shape.append(dim3) - return np.random.random(size=shape[-1]).astype(np.float32) + return np.random.random([1]).astype(np.float32) + elif attrs[0]["mode"] == "channel": + return np.random.random([3]).astype(np.float32) elif attrs[0]["mode"] == "element": - shape = [1] - if dim1 != 0: - shape.append(dim1) - if dim2 != 0: - shape.append(dim2) - if dim3 != 0: - shape.append(dim3) - return np.random.random(size=shape).astype(np.float32) + if self.dims == 1: + return np.random.random([16]).astype(np.float32) + elif self.dims == 2: + return np.random.random([1, 3]).astype(np.float32) + elif self.dims == 3: + if attrs[0]["data_format"] == "NCHW": + return np.random.random([1, 3, 16]).astype(np.float32) + elif attrs[0]["data_format"] == "NHWC": + return np.random.random([1, 16, 3]).astype(np.float32) + else: + raise AssertionError() + else: + if attrs[0]["data_format"] == "NCHW": + return np.random.random([1, 3, 16, 32]).astype( + np.float32 + ) + elif attrs[0]["data_format"] == "NHWC": + return np.random.random([1, 16, 32, 3]).astype( + np.float32 + ) + else: + raise AssertionError() for batch in [1, 4]: - for dim1 in [0, 3]: - for 
dim2 in [0, 16]: - for dim3 in [0, 32]: - self.dim1 = dim1 - self.dim2 = dim2 - self.dim3 = dim3 - - if dim1 == 0 and dim2 != 0: + for dims in [0, 1, 2, 3, 4]: + for mode in ["all", "element", "channel"]: + for data_format in ["NCHW", "NHWC"]: + if (mode == "element" or mode == "all") and dims == 0: continue - if dim1 == 0 and dim2 == 0 and dim3 != 0: + if mode == "channel" and dims != 4: continue - - for mode in ["all", "channel", "element"]: - for data_format in ['NCHW', 'NHWC']: - if ( - mode == "channel" - and dim1 == 0 - and data_format == "NCHW" - ): - continue - if ( - mode == "channel" - and dim3 == 0 - and data_format == "NHWC" - ): - continue - dics = [ - {"mode": mode, "data_format": data_format} - ] - ops_config = [ - { - "op_type": "prelu", - "op_inputs": { - "X": ["input_data"], - "Alpha": ["alpha_weight"], - }, - "op_outputs": {"Out": ["output_data"]}, - "op_attrs": dics[0], - } - ] - ops = self.generate_op_config(ops_config) - - program_config = ProgramConfig( - ops=ops, - weights={ - "alpha_weight": TensorConfig( - data_gen=partial( - generate_alpha, - dics, - dim1, - dim2, - dim3, - ) - ) - }, - inputs={ - "input_data": TensorConfig( - data_gen=partial( - generate_input, - batch, - dim1, - dim2, - dim3, - ) - ), - }, - outputs=["output_data"], + self.dims = dims + dics = [{"mode": mode, "data_format": data_format}] + ops_config = [ + { + "op_type": "prelu", + "op_inputs": { + "X": ["input_data"], + "Alpha": ["alpha_weight"], + }, + "op_outputs": {"Out": ["output_data"]}, + "op_attrs": dics[0], + } + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={ + "alpha_weight": TensorConfig( + data_gen=partial(generate_alpha, dics) ) - - yield program_config + }, + inputs={ + "input_data": TensorConfig( + data_gen=partial( + generate_input, dics, batch + ) + ), + }, + outputs=["output_data"], + ) + + yield program_config def sample_predictor_configs( self, program_config ) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): - if self.dim1 == 0: - self.dynamic_shape.min_input_shape = { - "input_data": [1], - } - self.dynamic_shape.max_input_shape = { - "input_data": [4], - } - self.dynamic_shape.opt_input_shape = { - "input_data": [2], - } - else: - if self.dim2 == 0 and self.dim3 == 0: + if self.dims == 0: + self.dynamic_shape.min_input_shape = {"input_data": []} + self.dynamic_shape.max_input_shape = {"input_data": []} + self.dynamic_shape.opt_input_shape = {"input_data": []} + elif self.dims == 1: + self.dynamic_shape.min_input_shape = {"input_data": [16]} + self.dynamic_shape.max_input_shape = {"input_data": [16]} + self.dynamic_shape.opt_input_shape = {"input_data": [16]} + elif self.dims == 2: + self.dynamic_shape.min_input_shape = {"input_data": [1, 3]} + self.dynamic_shape.max_input_shape = {"input_data": [1, 3]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3]} + elif self.dims == 3: + if attrs[0]["data_format"] == "NCHW": self.dynamic_shape.min_input_shape = { - "input_data": [1, 1], + "input_data": [1, 3, 16] } self.dynamic_shape.max_input_shape = { - "input_data": [4, 32], + "input_data": [4, 3, 16] } self.dynamic_shape.opt_input_shape = { - "input_data": [2, 3], + "input_data": [1, 3, 16] } - elif self.dim2 != 0 and self.dim3 != 0: + elif attrs[0]["data_format"] == "NHWC": self.dynamic_shape.min_input_shape = { - "input_data": [1, 1, 1, 1], + "input_data": [1, 16, 3] } self.dynamic_shape.max_input_shape = { - "input_data": [4, 3, 16, 32], + "input_data": [4, 16, 3] } 
self.dynamic_shape.opt_input_shape = { - "input_data": [2, 3, 16, 32], + "input_data": [1, 16, 3] } - elif self.dim3 == 0: + else: + raise AssertionError() + else: + if attrs[0]["data_format"] == "NCHW": self.dynamic_shape.min_input_shape = { - "input_data": [1, 1, 1], + "input_data": [1, 3, 16, 32] } self.dynamic_shape.max_input_shape = { - "input_data": [4, 3, 32], + "input_data": [4, 3, 16, 32] } self.dynamic_shape.opt_input_shape = { - "input_data": [2, 3, 16], + "input_data": [1, 3, 16, 32] } + elif attrs[0]["data_format"] == "NHWC": + self.dynamic_shape.min_input_shape = { + "input_data": [1, 16, 32, 3] + } + self.dynamic_shape.max_input_shape = { + "input_data": [4, 16, 32, 3] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [1, 16, 32, 3] + } + else: + raise AssertionError() def clear_dynamic_shape(): self.dynamic_shape.max_input_shape = {} @@ -203,12 +198,7 @@ def clear_dynamic_shape(): ] def generate_trt_nodes_num(attrs, dynamic_shape): - if ( - not dynamic_shape - and self.dim1 == 0 - and self.dim2 == 0 - and self.dim3 == 0 - ): + if not dynamic_shape and (self.dims == 1 or self.dims == 0): return 0, 3 return 1, 2 @@ -234,23 +224,7 @@ def generate_trt_nodes_num(attrs, dynamic_shape): attrs, True ), (1e-3, 1e-3) - def add_skip_trt_case(self): - ver = paddle_infer.get_trt_compile_version() - if ver[0] * 1000 + ver[1] * 100 + ver[0] * 10 < 7000: - - def teller(program_config, predictor_config): - if not predictor_config.tensorrt_dynamic_shape_enabled(): - return True - return False - - self.add_skip_case( - teller, - SkipReasons.TRT_NOT_IMPLEMENTED, - "Need to repair the case: the output of GPU and tensorrt has diff in trt6, the prelu static plugin has bug.", - ) - def test(self): - self.add_skip_trt_case() self.run_test()
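
Editor's note (not part of the patch): the sketch below is a minimal illustration of how the zero-dim case exercised by the new `dims == 0` branch might be driven through Paddle's standalone inference API. The model files, the `input_data` tensor name, and the engine settings are hypothetical placeholders; only the empty-shape dynamic-shape entries mirror the `dims == 0` branch of `generate_dynamic_shape()` added by this patch.

    # Illustrative sketch only; paths and tensor names are placeholders.
    import numpy as np
    import paddle.inference as paddle_infer

    config = paddle_infer.Config(
        "./prelu_model/inference.pdmodel", "./prelu_model/inference.pdiparams"
    )
    config.enable_use_gpu(256, 0)
    config.enable_tensorrt_engine(
        workspace_size=1 << 28,
        max_batch_size=1,
        min_subgraph_size=1,
        precision_mode=paddle_infer.PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False,
    )
    # A zero-dim input is declared with an empty shape, as in the dims == 0
    # case of the updated test; this only works in TRT dynamic shape mode.
    config.set_trt_dynamic_shape_info(
        {"input_data": []}, {"input_data": []}, {"input_data": []}
    )
    predictor = paddle_infer.create_predictor(config)
    handle = predictor.get_input_handle(predictor.get_input_names()[0])
    handle.reshape([])
    handle.copy_from_cpu(np.array(0.5, dtype=np.float32))
    predictor.run()
    out = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()

Consistent with the op_teller change above, zero-dim (and one-dim) prelu inputs are rejected in static shape mode, so a configuration like the one sketched here must enable dynamic shape for the subgraph to be offloaded to TensorRT.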