diff --git a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc
index 3e56200dcaa52..86c7b7c9dbbae 100644
--- a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc
+++ b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc
@@ -157,8 +157,8 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const {
                                                    "nearest_interp_v2"};
   // Ops must run under the original layout even though it has
   // data_format/data_layout attribute, otherwise it will be very troublesome!
-  std::unordered_set<std::string> must_original_layout_ops{"affine_channel",
-                                                           "softmax"};
+  std::unordered_set<std::string> must_original_layout_ops{
+      "affine_channel", "softmax", "temporal_shift"};
   // OPs unrelated to layout are consistent according to the layout of input
   // var!
   std::unordered_set<std::string> any_layout_ops{"relu"};
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 104a20e07f01d..ccda587530bfd 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2546,6 +2546,7 @@ USE_TRT_CONVERTER(grid_sampler)
 #endif
 #if IS_TRT_VERSION_GE(8200)
 USE_TRT_CONVERTER(set_value)
+USE_TRT_CONVERTER(temporal_shift)
 #endif
 #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
 USE_TRT_CONVERTER(sparse_fc)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index adfc14f731160..487e8c9a78a04 100755
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -101,7 +101,8 @@ list(
   elementwiseadd_transpose_op.cc
   skip_groupnorm_act_op.cc
   preln_groupnorm_act_op.cc
-  expand_v2_op.cc)
+  expand_v2_op.cc
+  temporal_shift_op.cc)
 
 if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
   list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc
diff --git a/paddle/fluid/inference/tensorrt/convert/temporal_shift_op.cc b/paddle/fluid/inference/tensorrt/convert/temporal_shift_op.cc
new file mode 100644
index 0000000000000..03983ff393033
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/temporal_shift_op.cc
@@ -0,0 +1,201 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * TemporalShiftOp.
+ *
+ * Lowers the Paddle temporal_shift op to TensorRT Shuffle, Slice and
+ * Concatenation layers (only built when TensorRT >= 8.2):
+ *   1. view the [N*T,C,H,W] input as [N,T,C,H,W];
+ *   2. pad the T axis by one frame on each side (Slice in kFILL mode);
+ *   3. take channels [0,c1) from frame t-1, [c1,c2) from frame t+1 and
+ *      [c2,C) from frame t, with c1 = int(C * shift_ratio) and
+ *      c2 = int(C * shift_ratio * 2);
+ *   4. concatenate along C and reshape back to the input dimensions.
+ * An NHWC input is transposed to NCHW first and transposed back at the end.
+ */
+class TemporalShiftOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+#if IS_TRT_VERSION_GE(8200)
+
+    VLOG(3) << "convert a temporal shift op to tensorrt temporal layer";
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+
+    const float shift_ratio =
+        PADDLE_GET_CONST(float, op_desc.GetAttr("shift_ratio"));
+    const int T = PADDLE_GET_CONST(int, op_desc.GetAttr("seg_num"));
+
+    std::string data_format = "NCHW";
+    if (op_desc.HasAttr("data_format")) {
+      data_format =
+          PADDLE_GET_CONST(std::string, op_desc.GetAttr("data_format"));
+    }
+    const bool is_nhwc = (data_format == "NHWC");
+
+    if (is_nhwc) {
+      // transpose input to [N,C,H,W]
+      auto transpose_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+      nvinfer1::Permutation perm{0, 3, 1, 2};
+      transpose_layer->setFirstTranspose(perm);
+      input = transpose_layer->getOutput(0);
+    }
+
+    auto input_dims = input->getDimensions();
+
+    // NOTE(review): C, H and W are used as build-time constants below. If
+    // any of them is dynamic (-1) the reshape would carry several -1 entries
+    // and the channel split would be computed from -1 -- confirm that only
+    // the batch dimension can be dynamic here.
+    const int C = input_dims.d[1];
+    const int H = input_dims.d[2];
+    const int W = input_dims.d[3];
+
+    // Reshape input to [N,T,C,H,W]
+    auto reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+    nvinfer1::Dims reshape_dims{5, {-1, T, C, H, W}};
+    reshape_layer->setReshapeDimensions(reshape_dims);
+    input = reshape_layer->getOutput(0);
+    nvinfer1::ITensor* input_shape = Shape(input);
+
+    constexpr int kRank = 5;
+    nvinfer1::Dims stride;
+    stride.nbDims = kRank;
+    std::fill_n(stride.d, kRank, 1);
+
+    // Adds a unit-stride Slice whose start and size are supplied as tensors
+    // through setInput(1/2); the static start/size passed at creation are
+    // only placeholders.
+    auto add_slice = [&](nvinfer1::ITensor* data,
+                         nvinfer1::ITensor* start,
+                         nvinfer1::ITensor* size) {
+      auto* layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Slice, *data, stride, stride, stride);
+      layer->setInput(1, *start);
+      layer->setInput(2, *size);
+      return layer;
+    };
+
+    // Pad input to [N,T+2,C,H,W]: start one frame before the data and read
+    // two extra frames along T; kFILL mode fills the out-of-range frames.
+    nvinfer1::ITensor* pad_start =
+        Add1DConstantLayer(std::vector<int>{0, -1, 0, 0, 0});
+    nvinfer1::ITensor* pad_total =
+        Add1DConstantLayer(std::vector<int>{0, 2, 0, 0, 0});
+    nvinfer1::ITensor* padded_size =
+        TRT_ENGINE_ADD_LAYER(engine_,
+                             ElementWise,
+                             *input_shape,
+                             *pad_total,
+                             nvinfer1::ElementWiseOperation::kSUM)
+            ->getOutput(0);
+    auto* pad_layer = add_slice(input, pad_start, padded_size);
+#if IS_TRT_VERSION_GE(8500)
+    pad_layer->setMode(nvinfer1::SampleMode::kFILL);
+#else
+    pad_layer->setMode(nvinfer1::SliceMode::kFILL);
+#endif
+    nvinfer1::ITensor* padded = pad_layer->getOutput(0);
+
+    // Slice Padded Tensor
+    const int slice_c = static_cast<int>(C * shift_ratio);
+    const int slice_c2 = static_cast<int>(C * shift_ratio * 2);
+
+    nvinfer1::ITensor* slice_start1 =
+        Add1DConstantLayer(std::vector<int>(kRank, 0));
+    nvinfer1::ITensor* slice_start2 =
+        Add1DConstantLayer(std::vector<int>{0, 2, slice_c, 0, 0});
+    nvinfer1::ITensor* slice_start3 =
+        Add1DConstantLayer(std::vector<int>{0, 1, slice_c2, 0, 0});
+
+    // Returns the [N,T,C,H,W] shape tensor with `removed` subtracted from C.
+    auto size_without_channels = [&](int removed) {
+      nvinfer1::ITensor* delta =
+          Add1DConstantLayer(std::vector<int>{0, 0, removed, 0, 0});
+      return TRT_ENGINE_ADD_LAYER(engine_,
+                                  ElementWise,
+                                  *input_shape,
+                                  *delta,
+                                  nvinfer1::ElementWiseOperation::kSUB)
+          ->getOutput(0);
+    };
+    // [N, T, C, H, W] - [0, 0, C - slice_c, 0, 0] = [N, T, slice_c, H, W]
+    nvinfer1::ITensor* slice_size1 = size_without_channels(C - slice_c);
+    nvinfer1::ITensor* slice_size2 =
+        size_without_channels(C + slice_c - slice_c2);
+    nvinfer1::ITensor* slice_size3 = size_without_channels(slice_c2);
+
+    // slice1: frame t-1, channels [0, slice_c)
+    // slice2: frame t+1, channels [slice_c, slice_c2)
+    // slice3: frame t,   channels [slice_c2, C)
+    auto* slice1_layer = add_slice(padded, slice_start1, slice_size1);
+    auto* slice2_layer = add_slice(padded, slice_start2, slice_size2);
+    auto* slice3_layer = add_slice(padded, slice_start3, slice_size3);
+
+    // Concatenate slices along the third dimension (C). slice1 is empty when
+    // slice_c == 0, so it is left out in that case.
+    std::vector<nvinfer1::ITensor*> concat_inputs;
+    if (slice_c != 0) {
+      concat_inputs.push_back(slice1_layer->getOutput(0));
+    }
+    concat_inputs.push_back(slice2_layer->getOutput(0));
+    concat_inputs.push_back(slice3_layer->getOutput(0));
+    auto* concat_layer =
+        TRT_ENGINE_ADD_LAYER(engine_,
+                             Concatenation,
+                             concat_inputs.data(),
+                             static_cast<int>(concat_inputs.size()));
+    concat_layer->setAxis(2);
+
+    // Reshape output to [N*T,C,H,W]
+    auto* reshape_layer3 =
+        TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *concat_layer->getOutput(0));
+    reshape_layer3->setReshapeDimensions(input_dims);
+
+    // Set output
+    auto output_name = op_desc.Output("Out")[0];
+
+    if (is_nhwc) {
+      // Transpose output to [N*T,C,H,W] -> [N*T,H,W,C]
+      auto transpose_layer2 =
+          TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *reshape_layer3->getOutput(0));
+      nvinfer1::Permutation permute_order{0, 2, 3, 1};
+      transpose_layer2->setFirstTranspose(permute_order);
+      RreplenishLayerAndOutput(
+          transpose_layer2, "temporal_shift", {output_name}, test_mode);
+    } else {
+      RreplenishLayerAndOutput(
+          reshape_layer3, "temporal_shift", {output_name}, test_mode);
+    }
+#else
+    VLOG(3) << "Temporal shift is not supported when TensorRT < 8.2";
+#endif
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(temporal_shift, TemporalShiftOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 725e4fd75e9ea..887b4de910491 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -2584,6 +2584,47 @@ struct SimpleOpTypeSetTeller : public Teller {
 #endif
     }
 
+    // temporal_shift is only offloaded with TensorRT >= 8.2, dynamic shape,
+    // both shift_ratio/seg_num attributes present and a rank-4 input X.
+    if (op_type == "temporal_shift") {
+#if !IS_TRT_VERSION_GE(8200)
+      VLOG(3) << "temporal_shift is not supported when TensorRT < 8.2";
+      return false;
+#endif
+
+      if (!with_dynamic_shape) {
+        VLOG(3) << "the temporal shift does not support "
+                   "static shape yet";
+        return false;
+      }
+
+      if (!desc.HasAttr("shift_ratio") || !desc.HasAttr("seg_num")) {
+        VLOG(3) << "temporal shift need attributes : shift_ratio and seg_num";
+        return false;
+      }
+
+      auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
+
+      auto input_name = desc.Input("X")[0];
+      auto* input_desc = block->FindVar(input_name);
+      if (input_desc == nullptr) {
+        VLOG(3) << "temporal_shift: the var desc of input X is not found";
+        return false;
+      }
+      const auto input_shape = input_desc->GetShape();
+
+      if (input_shape.size() != 4) {
+        VLOG(3) << "The input of temporal_shift must be a tensor of rank 4.";
+        return false;
+      }
+    }
+
     if (use_no_calib_int8) {
       return int8_teller_set.count(op_type);
     } else {
@@ -2745,6 +2786,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "fuse_eleadd_transpose",
       "skip_groupnorm_act",
       "preln_groupnorm_act",
+      "temporal_shift",
       "grid_sampler"};
 
   std::unordered_set<std::string> teller_set{
@@ -2899,6 +2941,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "fuse_eleadd_transpose",
       "skip_groupnorm_act",
       "preln_groupnorm_act",
+      "temporal_shift",
       "grid_sampler"};
 };
 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_temporal_shift.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_temporal_shift.py
new file mode 100755
index 0000000000000..b0b2ce5106213
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_temporal_shift.py
@@ -0,0 +1,138 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+from typing import List
+
+import numpy as np
+from program_config import ProgramConfig, TensorConfig
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+
+import paddle.inference as paddle_infer
+
+
+class TrtConvertTemporalShiftTest(TrtLayerAutoScanTest):
+    """Auto-scan test of the temporal_shift TensorRT converter."""
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        def generate_input1(attrs):
+            T = attrs[0]["seg_num"]
+            shape = [2 * T, 10, 64, 64]
+            return np.random.uniform(low=0.1, high=1.0, size=shape).astype(
+                np.float32
+            )
+
+        for shift_value in [0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.49]:
+            for T in range(2, 5):
+                for data_format in ["NCHW", "NHWC"]:
+                    dics = [
+                        {
+                            "shift_ratio": shift_value,
+                            "seg_num": T,
+                            "data_format": data_format,
+                        },
+                        {},
+                    ]
+                    ops_config = [
+                        {
+                            "op_type": "temporal_shift",
+                            "op_inputs": {"X": ["input_data"]},
+                            "op_outputs": {"Out": ["output_data"]},
+                            "op_attrs": dics[0],
+                        }
+                    ]
+
+                    ops = self.generate_op_config(ops_config)
+                    for _ in range(10):
+                        program_config = ProgramConfig(
+                            ops=ops,
+                            weights={},
+                            inputs={
+                                "input_data": TensorConfig(
+                                    data_gen=partial(generate_input1, dics)
+                                ),
+                            },
+                            outputs=["output_data"],
+                        )
+
+                        yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            t = attrs[0]['seg_num']
+            self.dynamic_shape.min_input_shape = {
+                "input_data": [2 * t, 10, 64, 64]
+            }
+            self.dynamic_shape.max_input_shape = {
+                "input_data": [5 * t, 10, 64, 64]
+            }
+            self.dynamic_shape.opt_input_shape = {
+                "input_data": [3 * t, 10, 64, 64]
+            }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, is_dynamic_shape):
+            # Expected node counts: (1, 2) only with TRT >= 8.2 and dynamic
+            # shape, otherwise (0, 3). Presumably (TRT engines, other ops).
+            valid_version = (8, 2, 0)
+            compile_version = paddle_infer.get_trt_compile_version()
+            runtime_version = paddle_infer.get_trt_runtime_version()
+            self.assertEqual(compile_version, runtime_version)
+            if compile_version < valid_version:
+                return 0, 3
+            if is_dynamic_shape:
+                return 1, 2
+            return 0, 3
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-3
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-3
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()