diff --git a/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc b/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc index 7ef79e547d09a..fb1fb4a6a7b39 100644 --- a/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc @@ -64,9 +64,21 @@ class HardSwishOpConverter : public OpConverter { nvinfer1::ElementWiseOperation::kPROD); layer = eltwise_layer; } else { - plugin::HardSwishPlugin* plugin = - new plugin::HardSwishPlugin(threshold, scale, offset); - layer = engine_->AddPlugin(&input, input_num, plugin); + if (engine_->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) + plugin::HardSwishPluginDynamic* plugin = + new plugin::HardSwishPluginDynamic(threshold, scale, offset); + layer = engine_->AddDynamicPlugin(&input, input_num, plugin); +#else + PADDLE_THROW(platform::errors::Fatal( + "You are running the TRT Dynamic Shape mode, need to confirm that " + "your TRT version is no less than 6.0")); +#endif + } else { + plugin::HardSwishPlugin* plugin = + new plugin::HardSwishPlugin(threshold, scale, offset); + layer = engine_->AddPlugin(&input, input_num, plugin); + } } auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode); diff --git a/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu index 28060bd2facbe..9872b1ff8d957 100644 --- a/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.cu @@ -22,10 +22,10 @@ namespace tensorrt { namespace plugin { nvinfer1::Dims HardSwishPlugin::getOutputDimensions( - int index, const nvinfer1::Dims* in_dims, int nb_inputs) TRT_NOEXCEPT { + int index, const nvinfer1::Dims *in_dims, int nb_inputs) TRT_NOEXCEPT { assert(nb_inputs == 1); assert(index < this->getNbOutputs()); - nvinfer1::Dims const& input_dims = in_dims[0]; + nvinfer1::Dims const &input_dims = in_dims[0]; nvinfer1::Dims output_dims = input_dims; return output_dims; } @@ -42,7 +42,7 @@ __device__ T kMin(T a, T b) { template __global__ void hard_swish_kernel(float threshold, float scale, float offset, - int n, const T* input, T* output) { + int n, const T *input, T *output) { const int idx = blockIdx.x * TPB + threadIdx.x; if (idx < n) { const T in = input[idx]; @@ -50,14 +50,14 @@ __global__ void hard_swish_kernel(float threshold, float scale, float offset, } } -int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs, +int HardSwishPlugin::enqueue(int batch_size, const void *const *inputs, #if IS_TRT_VERSION_LT(8000) - void** outputs, void*, cudaStream_t stream) { + void **outputs, void *, cudaStream_t stream) { #else - void* const* outputs, void*, + void *const *outputs, void *, cudaStream_t stream) TRT_NOEXCEPT { #endif - const auto& input_dims = this->getInputDims(0); + const auto &input_dims = this->getInputDims(0); int num = batch_size; for (int i = 0; i < input_dims.nbDims; i++) { num *= input_dims.d[i]; @@ -69,14 +69,79 @@ int HardSwishPlugin::enqueue(int batch_size, const void* const* inputs, const int block_size = 256; const int grid_size = (num + block_size - 1) / block_size; - const float* input = static_cast(inputs[0]); - float* output = static_cast(outputs[0]); + const float *input = static_cast(inputs[0]); + float *output = static_cast(outputs[0]); hard_swish_kernel<<>>( threshold, scale, offset, num, input, output); return cudaGetLastError() != cudaSuccess; } +#if IS_TRT_VERSION_GE(6000) + +nvinfer1::DimsExprs HardSwishPluginDynamic::getOutputDimensions( + int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, + nvinfer1::IExprBuilder &expr_builder) TRT_NOEXCEPT { + return inputs[0]; +} + +int HardSwishPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *input_desc, + const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs, + void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT { + auto input_dims = input_desc[0].dims; + int num = 1; + for (int i = 0; i < input_dims.nbDims; i++) { + num *= input_dims.d[i]; + } + float threshold = threshold_; + float scale = scale_; + float offset = offset_; + const int block_size = 256; + const int grid_size = (num + block_size - 1) / block_size; + const float *input = static_cast(inputs[0]); + float *output = static_cast(outputs[0]); + hard_swish_kernel<<>>( + threshold, scale, offset, num, input, output); + + return cudaGetLastError() != cudaSuccess; +} + +nvinfer1::DataType HardSwishPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType *input_types, + int nb_inputs) const TRT_NOEXCEPT { + PADDLE_ENFORCE_EQ(index, 0, + platform::errors::InvalidArgument( + "The Elementwise Plugin only has one input, so the " + "index value should be 0, but get %d.", + index)); + return input_types[0]; +} + +bool HardSwishPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs, + int nb_outputs) TRT_NOEXCEPT { + PADDLE_ENFORCE_NOT_NULL( + in_out, platform::errors::InvalidArgument( + "The input of swish plugin shoule not be nullptr.")); + + PADDLE_ENFORCE_LT( + pos, nb_inputs + nb_outputs, + platform::errors::InvalidArgument("The pos(%d) should be less than the " + "num(%d) of the input and the output.", + pos, nb_inputs + nb_outputs)); + (in_out && pos < (nb_inputs + nb_outputs)); + + const nvinfer1::PluginTensorDesc &in = in_out[pos]; + if (pos == 0) { + return (in.type == nvinfer1::DataType::kFLOAT) && + (in.format == nvinfer1::TensorFormat::kLINEAR); + } + const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1]; + // output + return in.type == prev.type && in.format == prev.format; +} +#endif } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h index 5dfa00ef1c204..c0ee608c39dab 100644 --- a/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/hard_swish_op_plugin.h @@ -94,6 +94,113 @@ class HardSwishPluginCreator : public TensorRTPluginCreator { }; REGISTER_TRT_PLUGIN_V2(HardSwishPluginCreator); +#if IS_TRT_VERSION_GE(6000) +class HardSwishPluginDynamic : public DynamicPluginTensorRT { + public: + HardSwishPluginDynamic(const float threshold, const float scale, + const float offset) + : threshold_(threshold), scale_(scale), offset_(offset) {} + + // It was used for tensorrt deserialization. + // It should not be called by users. + HardSwishPluginDynamic(void const* serialData, size_t serialLength) { + DeserializeValue(&serialData, &serialLength, &threshold_); + DeserializeValue(&serialData, &serialLength, &scale_); + DeserializeValue(&serialData, &serialLength, &offset_); + } + ~HardSwishPluginDynamic() {} + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { + return new HardSwishPluginDynamic(threshold_, scale_, offset_); + } + const char* getPluginType() const TRT_NOEXCEPT override { + return "hard_swish_plugin_dynamic"; + } + int getNbOutputs() const TRT_NOEXCEPT override { return 1; } + int initialize() TRT_NOEXCEPT override { return 0; } + nvinfer1::DimsExprs getOutputDimensions( + int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, + nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, void* const* outputs, void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + size_t getSerializationSize() const TRT_NOEXCEPT override { + return SerializedSize(threshold_) + SerializedSize(scale_) + + SerializedSize(offset_); + } + + // TRT will call this func to serialize the configuration of TRT + // It should not be called by users. + void serialize(void* buffer) const TRT_NOEXCEPT override { + SerializeValue(&buffer, threshold_); + SerializeValue(&buffer, scale_); + SerializeValue(&buffer, offset_); + } + nvinfer1::DataType getOutputDataType( + int index, const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* inOut, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT override; + + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT override {} + void destroy() TRT_NOEXCEPT override { delete this; } + + protected: + float threshold_; + float scale_; + float offset_; +}; + +class HardSwishPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + HardSwishPluginDynamicCreator() {} + const char* getPluginName() const TRT_NOEXCEPT override { + return "hardswish_plugin_dynamic"; + } + + const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin(const char* name, + const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin( + const char* name, const void* serial_data, + size_t serial_length) TRT_NOEXCEPT override { + auto plugin = new HardSwishPluginDynamic(serial_data, serial_length); + return plugin; + } + + void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const TRT_NOEXCEPT override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; + std::vector plugin_attributes_; +}; +REGISTER_TRT_PLUGIN_V2(HardSwishPluginDynamicCreator); + +#endif + } // namespace plugin } // namespace tensorrt } // namespace inference diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_hard_swish.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_hard_swish.py new file mode 100644 index 0000000000000..283a19ec00574 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_hard_swish.py @@ -0,0 +1,117 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + + +class TrtConvertHardSwishTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + weights = program_config.weights + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + if attrs[0]['threshold'] <= 0 or attrs[0]['scale'] <= 0: + return False + + return True + + def sample_program_configs(self): + def generate_input1(attrs: List[Dict[str, Any]]): + return np.ones([1, 3, 64, 64]).astype(np.float32) + + for threshold in [6.0, 7.0, 100.0, 0.0, -1.0]: + for scale in [5.0, 6.0, 7.0, -1.0, 0.0, 100.0]: + for offset in [3.0, 4.0, 5.0, -1.0, 0.0, 100.0]: + dics = [{ + "threshold": threshold, + "scale": scale, + "offset": offset + }] + + ops_config = [{ + "op_type": "hard_swish", + "op_inputs": { + "X": ["input_data"] + }, + "op_outputs": { + "Out": ["hard_swish_output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig(data_gen=partial( + generate_input1, dics)) + }, + outputs=["hard_swish_output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), (1e-5, 1e-5) + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), (1e-5, 1e-5) + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main()