diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp new file mode 100644 index 0000000000..ac49c8b29e --- /dev/null +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp @@ -0,0 +1,216 @@ +#include "trt_corner_pool.hpp" + +#include + +#include "trt_serialize.hpp" + +void CornerPoolForwardLauncher_float(const float *input, float *output, + const int batch_size, const int channels, + const int height, const int width, + const int pool_type, cudaStream_t stream); + +namespace { +static const char *PLUGIN_VERSION{"1"}; +static const char *CORNER_POOL_PLUGIN_NAME{"MMCVCornerPool"}; +} // namespace + +CornerPoolPluginDynamic::CornerPoolPluginDynamic(const std::string &name, + TRT_CORNER_POOL_TYPE poolType) + : mLayerName(name), mPoolType(poolType) {} + +CornerPoolPluginDynamic::CornerPoolPluginDynamic(const std::string name, + const void *data, + size_t length) + : mLayerName(name) { + deserialize_value(&data, &length, &mPoolType); +} + +CornerPoolPluginDynamic::~CornerPoolPluginDynamic() {} + +nvinfer1::IPluginV2DynamicExt *CornerPoolPluginDynamic::clone() const { + CornerPoolPluginDynamic *plugin = + new CornerPoolPluginDynamic(mLayerName, mPoolType); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; +} + +nvinfer1::DimsExprs CornerPoolPluginDynamic::getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) { + return inputs[0]; +} + +bool CornerPoolPluginDynamic::supportsFormatCombination( + int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs, + int nbOutputs) { + switch (pos) { + // input[0] + case 0: + return inOut[pos].type == nvinfer1::DataType::kFLOAT && + inOut[pos].format == nvinfer1::TensorFormat::kLINEAR; + // output[0] + case 1: + return inOut[pos].type == inOut[0].type && + inOut[pos].format == inOut[0].format; + default: + return false; + } +} + +void CornerPoolPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {} + +size_t CornerPoolPluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc *inputs, int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const { + int sizeof_dtype = mmcv::getElementSize(outputs[0].type); +} + +int CornerPoolPluginDynamic::enqueue( + const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, + void *const *outputs, void *workSpace, cudaStream_t stream) { + const void *input = inputs[0]; + void *output_value = outputs[0]; + + const int batch_size = inputDesc[0].dims.d[0]; + const int channels = inputDesc[0].dims.d[1]; + const int height = inputDesc[0].dims.d[2]; + const int width = inputDesc[0].dims.d[3]; + + CornerPoolForwardLauncher_float((float *)input, (float *)output_value, + batch_size, channels, height, width, + int(mPoolType), stream); + + return 0; +} + +nvinfer1::DataType CornerPoolPluginDynamic::getOutputDataType( + int index, const nvinfer1::DataType *inputTypes, int nbInputs) const { + return inputTypes[0]; +} + +// IPluginV2 Methods +const char *CornerPoolPluginDynamic::getPluginType() const { + switch (mPoolType) { + case TRT_CORNER_POOL_TYPE::TRT_TOP_POOL: + case TRT_CORNER_POOL_TYPE::TRT_BOTTOM_POOL: + case TRT_CORNER_POOL_TYPE::TRT_LEFT_POOL: + case TRT_CORNER_POOL_TYPE::TRT_RIGHT_POOL: + return CORNER_POOL_PLUGIN_NAME; + + default: + return "UnknownpoolType"; + } +} + +const char *CornerPoolPluginDynamic::getPluginVersion() const { + return PLUGIN_VERSION; +} + +int CornerPoolPluginDynamic::getNbOutputs() const { return 1; } + +int CornerPoolPluginDynamic::initialize() { return 0; } + +void CornerPoolPluginDynamic::terminate() {} + +size_t CornerPoolPluginDynamic::getSerializationSize() const { + return sizeof(mPoolType); +} + +void CornerPoolPluginDynamic::serialize(void *buffer) const { + serialize_value(&buffer, mPoolType); +} + +void CornerPoolPluginDynamic::destroy() { + // This gets called when the network containing plugin is destroyed + delete this; +} + +void CornerPoolPluginDynamic::setPluginNamespace(const char *libNamespace) { + mNamespace = libNamespace; +} + +const char *CornerPoolPluginDynamic::getPluginNamespace() const { + return mNamespace.c_str(); +} + +CornerPoolPluginDynamicCreator::CornerPoolPluginDynamicCreator() { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("mode")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); +} + +const char *CornerPoolPluginDynamicCreator::getPluginName() const { + return CORNER_POOL_PLUGIN_NAME; +} + +const char *CornerPoolPluginDynamicCreator::getPluginVersion() const { + return PLUGIN_VERSION; +} + +const nvinfer1::PluginFieldCollection * +CornerPoolPluginDynamicCreator::getFieldNames() { + return &mFC; +} + +nvinfer1::IPluginV2 *CornerPoolPluginDynamicCreator::createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) { + TRT_CORNER_POOL_TYPE poolType; + int poolMode = -1; + + for (int i = 0; i < fc->nbFields; i++) { + if (fc->fields[i].data == nullptr) { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("mode") == 0) { + poolMode = static_cast(fc->fields[i].data)[0]; + } + } + + assert(poolMode >= 0 && poolMode <= 3); + switch (poolMode) { + case 0: + poolType = TRT_CORNER_POOL_TYPE::TRT_TOP_POOL; + break; + case 1: + poolType = TRT_CORNER_POOL_TYPE::TRT_BOTTOM_POOL; + break; + case 2: + poolType = TRT_CORNER_POOL_TYPE::TRT_LEFT_POOL; + break; + case 3: + poolType = TRT_CORNER_POOL_TYPE::TRT_RIGHT_POOL; + break; + + default: + break; + } + + CornerPoolPluginDynamic *plugin = new CornerPoolPluginDynamic(name, poolType); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +nvinfer1::IPluginV2 *CornerPoolPluginDynamicCreator::deserializePlugin( + const char *name, const void *serialData, size_t serialLength) { + // This object will be deleted when the network is destroyed, which will + // call FCPluginDynamic::destroy() + auto plugin = new CornerPoolPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; +} + +void CornerPoolPluginDynamicCreator::setPluginNamespace( + const char *libNamespace) { + mNamespace = libNamespace; +} + +const char *CornerPoolPluginDynamicCreator::getPluginNamespace() const { + return mNamespace.c_str(); +} diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu new file mode 100644 index 0000000000..0d7bf03f54 --- /dev/null +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu @@ -0,0 +1,109 @@ +#include "common_cuda_helper.hpp" +#include "trt_cuda_helper.cuh" +#include "trt_plugin_helper.hpp" + +template +__global__ void top_bottom_pool_kernel(const scalar_t *input, scalar_t *output, + const int batch_size, const int channels, + const int height, const int width, + const int pool_type) { + const int nthreads = batch_size * channels * width; + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int n_idx = index / (channels * width); // batch + int w_idx = index % width; // width + int c_idx = (index / width) % channels; // channels + int offset_n = n_idx * channels * width * height; + int offset_n_c = offset_n + c_idx * width * height; + int direction = -1; // in [-1, 1], default for TopPool + int index_start = height - 2; // default for TopPool + // pool_type in [0, 1] + if (pool_type == 0) { + // TopPool + // directly copy the most bottom value from input to output + output[offset_n_c + (height - 1) * width + w_idx] = + input[offset_n_c + (height - 1) * width + w_idx]; + } else { + // BottomPool + // directly copy the most top value from input to output + output[offset_n_c + w_idx] = input[offset_n_c + w_idx]; + index_start = 1; + direction = 1; + } + // do pool + for (int h = index_start; h >= 0 && h < height; h += direction) { + output[offset_n_c + h * width + w_idx] = + max(output[offset_n_c + (h - direction) * width + w_idx], + input[offset_n_c + h * width + w_idx]); + } + } +} + +template +__global__ void left_right_pool_kernel(const scalar_t *input, scalar_t *output, + const int batch_size, const int channels, + const int height, const int width, + const int pool_type) { + const int nthreads = batch_size * channels * height; + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int n_idx = index / (channels * height); // batch + int h_idx = index % height; // height + int c_idx = (index / height) % channels; // channels + int offset_n = n_idx * channels * width * height; + int offset_n_c = offset_n + c_idx * width * height; + int offset_n_c_h = offset_n_c + h_idx * width; + int direction = -1; // in [-1, 1], default for LeftPool + int index_start = width - 2; // default for LeftPool + // pool_type in [2, 3] + if (pool_type == 2) { + // LeftPool + // directly copy the most right value from input to output + output[offset_n_c_h + width - 1] = input[offset_n_c_h + width - 1]; + } else { + // RightPool + // directly copy the most left value from input to output + output[offset_n_c_h] = input[offset_n_c_h]; + index_start = 1; + direction = 1; + } + // do pool + for (int w = index_start; w >= 0 && w < width; w += direction) { + output[offset_n_c_h + w] = + max(output[offset_n_c_h + w - direction], input[offset_n_c_h + w]); + } + } +} + +template +void CornerPoolForwardLauncher(const scalar_t *input, scalar_t *output, + const int batch_size, const int channels, + const int height, const int width, + const int pool_type, cudaStream_t stream) { + int nthreads = -1, col_block = -1; + + switch (pool_type) { + case 0: + case 1: + nthreads = batch_size * channels * width; + col_block = DIVUP(nthreads, THREADS_PER_BLOCK); + top_bottom_pool_kernel + <<>>( + input, output, batch_size, channels, height, width, pool_type); + break; + case 2: + case 3: + nthreads = batch_size * channels * height; + col_block = DIVUP(nthreads, THREADS_PER_BLOCK); + left_right_pool_kernel + <<>>( + input, output, batch_size, channels, height, width, pool_type); + break; + } +} + +void CornerPoolForwardLauncher_float(const float *input, float *output, + const int batch_size, const int channels, + const int height, const int width, + const int pool_type, cudaStream_t stream) { + CornerPoolForwardLauncher(input, output, batch_size, channels, height, + width, pool_type, stream); +} diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp index c7b946b5dd..d5a1770529 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp @@ -1,5 +1,6 @@ #include "trt_plugin.hpp" +#include "trt_corner_pool.hpp" #include "trt_cummaxmin.hpp" #include "trt_deform_conv.hpp" #include "trt_grid_sampler.hpp" @@ -18,6 +19,7 @@ REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator); REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator); REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator); REGISTER_TENSORRT_PLUGIN(InstanceNormalizationDynamicCreator); +REGISTER_TENSORRT_PLUGIN(CornerPoolPluginDynamicCreator); extern "C" { bool initLibMMCVInferPlugins() { return true; } diff --git a/mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp b/mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp new file mode 100644 index 0000000000..f34e15b312 --- /dev/null +++ b/mmcv/ops/csrc/tensorrt/trt_corner_pool.hpp @@ -0,0 +1,111 @@ +#ifndef TRT_CORNER_POOL_HPP +#define TRT_CORNER_POOL_HPP +#include +#include + +#include "trt_plugin_helper.hpp" + +enum TRT_CORNER_POOL_TYPE { + TRT_TOP_POOL = 0, + TRT_BOTTOM_POOL = 1, + TRT_LEFT_POOL = 2, + TRT_RIGHT_POOL = 3 +}; + +// implement of CornerPool +class CornerPoolPluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + CornerPoolPluginDynamic(const std::string &name, + TRT_CORNER_POOL_TYPE poolType); + + CornerPoolPluginDynamic(const std::string name, const void *data, + size_t length); + + CornerPoolPluginDynamic() = delete; + + ~CornerPoolPluginDynamic(); + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt *clone() const override; + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, + nvinfer1::IExprBuilder &exprBuilder) override; + bool supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc *inOut, + int nbInputs, int nbOutputs) override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc *out, + int nbOutputs) override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc *outputs, + int nbOutputs) const override; + int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, + const nvinfer1::PluginTensorDesc *outputDesc, + const void *const *inputs, void *const *outputs, void *workspace, + cudaStream_t stream) override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, + const nvinfer1::DataType *inputTypes, + int nbInputs) const override; + + // IPluginV2 Methods + const char *getPluginType() const override; + const char *getPluginVersion() const override; + int getNbOutputs() const override; + int initialize() override; + void terminate() override; + size_t getSerializationSize() const override; + void serialize(void *buffer) const override; + void destroy() override; + void setPluginNamespace(const char *pluginNamespace) override; + const char *getPluginNamespace() const override; + + protected: + const std::string mLayerName; + std::string mNamespace; + + TRT_CORNER_POOL_TYPE mPoolType; + + protected: + // To prevent compiler warnings. + using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::configurePlugin; + using nvinfer1::IPluginV2DynamicExt::enqueue; + using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; + using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize; + using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::supportsFormat; +}; + +// CornerPool creator +class CornerPoolPluginDynamicCreator : public nvinfer1::IPluginCreator { + public: + CornerPoolPluginDynamicCreator(); + + const char *getPluginName() const override; + + const char *getPluginVersion() const override; + + const nvinfer1::PluginFieldCollection *getFieldNames() override; + + nvinfer1::IPluginV2 *createPlugin( + const char *name, const nvinfer1::PluginFieldCollection *fc) override; + + nvinfer1::IPluginV2 *deserializePlugin(const char *name, + const void *serialData, + size_t serialLength) override; + + void setPluginNamespace(const char *pluginNamespace) override; + + const char *getPluginNamespace() const override; + + protected: + nvinfer1::PluginFieldCollection mFC; + std::vector mPluginAttributes; + std::string mNamespace; +}; + +#endif TRT_CORNER_POOL_HPP // TRT_CORNER_POOL_HPP diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py index d65308ba8a..2f89e5818c 100644 --- a/tests/test_ops/test_tensorrt.py +++ b/tests/test_ops/test_tensorrt.py @@ -727,3 +727,81 @@ def test_instance_norm(dynamic_export, fp16_mode): if os.path.exists(trt_file): os.remove(trt_file) assert torch.allclose(pytorch_results, trt_results) + + +@pytest.mark.parametrize('mode', ['top', 'bottom', 'left', 'right']) +def test_corner_pool(mode): + try: + from mmcv.ops import CornerPool + except (ImportError, ModuleNotFoundError): + pytest.skip('test requires compilation') + + opset = 11 + # register custom op `mmcv::MMCVCornerPool` + from mmcv.onnx.symbolic import register_extra_symbolics + register_extra_symbolics(opset) + + # trt config + fp16_mode = False + max_workspace_size = 1 << 30 + + inputs = [ + # (n, c, h, w) + torch.rand((2, 3, 5, 5)), + torch.rand((1, 2, 4, 6)), + torch.rand((2, 1, 3, 2)), + ] + + class CornerPoolWrapper(CornerPool): + + def __init__(self, mode): + super(CornerPoolWrapper, self).__init__(mode) + + def forward(self, x): + # no use `torch.cummax`, instead `corner_pool` is used + # for various torch version + return self.corner_pool.apply(x) + + wrapped_model = CornerPoolWrapper(mode).cuda() + for input in inputs: + input = input.cuda() + + with torch.no_grad(): + torch.onnx.export( + wrapped_model, (input, ), + onnx_file, + export_params=True, + keep_initializers_as_inputs=True, + input_names=['input'], + output_names=['output'], + opset_version=opset) + onnx_model = onnx.load(onnx_file) + + # create trt engine and wraper + opt_shape_dict = { + 'input': [list(input.shape), + list(input.shape), + list(input.shape)], + } + trt_engine = onnx2trt( + onnx_model, + opt_shape_dict, + fp16_mode=fp16_mode, + max_workspace_size=max_workspace_size) + save_trt_engine(trt_engine, trt_file) + trt_model = TRTWrapper(trt_file, ['input'], ['output']) + + with torch.no_grad(): + trt_outputs = trt_model({'input': input}) + trt_pool_feat = trt_outputs['output'] + + # compute pytorch_output + with torch.no_grad(): + pytorch_pool_feat = wrapped_model(input) + + # allclose + if os.path.exists(onnx_file): + os.remove(onnx_file) + if os.path.exists(trt_file): + os.remove(trt_file) + assert torch.allclose(pytorch_pool_feat, trt_pool_feat, atol=1e-5)