From 4d64059ecfe9d7acef6c79e117d4fccfb906cb3d Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Tue, 18 May 2021 18:47:22 +0800
Subject: [PATCH 1/5] add cummax/cummin tensorrt plugin

---
 .../csrc/tensorrt/plugins/trt_cummaxmin.cpp   | 242 ++++++++++++++++++
 .../tensorrt/plugins/trt_cummaxmin_kernel.cu  |  89 +++++++
 mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp |   5 +-
 mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp      | 124 +++++++++
 mmcv/tensorrt/tensorrt_utils.py               |   7 +-
 tests/test_ops/test_tensorrt.py               |  97 +++++++
 6 files changed, 562 insertions(+), 2 deletions(-)
 create mode 100644 mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
 create mode 100644 mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
 create mode 100644 mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp

diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
new file mode 100644
index 0000000000..662426dfba
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
@@ -0,0 +1,242 @@
+#include "trt_cummaxmin.hpp"
+
+#include <assert.h>
+
+#include "trt_serialize.hpp"
+
+void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
+                                    int *output_index, const int *dims,
+                                    int nbDims, int cum_dim, int cum_type,
+                                    cudaStream_t stream);
+
+void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
+                                    int *output_index, const int *dims,
+                                    int nbDims, int cum_dim, int cum_type,
+                                    cudaStream_t stream);
+
+namespace {
+static const char *PLUGIN_VERSION{"1"};
+static const char *CUMMAXMIN_PLUGIN_NAME{"cummaxmin"};
+static const char *CUMMAX_PLUGIN_NAME{"cummax"};
+static const char *CUMMIN_PLUGIN_NAME{"cummin"};
+} // namespace
+
+CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string &name, int dim,
+                                               TRT_CUMCMPTYPE cumType)
+    : mLayerName(name), mDim(dim), mCumType(cumType) {}
+
+CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string name,
+                                               const void *data, size_t length)
+    : mLayerName(name) {
+  deserialize_value(&data, &length, &mDim);
+  deserialize_value(&data, &length, &mCumType);
+}
+
+CumMaxMinPluginDynamic::~CumMaxMinPluginDynamic() {}
+
+nvinfer1::IPluginV2DynamicExt *CumMaxMinPluginDynamic::clone() const {
+  CumMaxMinPluginDynamic *plugin =
+      new CumMaxMinPluginDynamic(mLayerName, mDim, mCumType);
+  plugin->setPluginNamespace(getPluginNamespace());
+
+  return plugin;
+}
+
+nvinfer1::DimsExprs CumMaxMinPluginDynamic::getOutputDimensions(
+    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+    nvinfer1::IExprBuilder &exprBuilder) {
+  return inputs[0];
+}
+
+bool CumMaxMinPluginDynamic::supportsFormatCombination(
+    int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
+    int nbOutputs) {
+  switch (pos) {
+  // input[0]
+  case 0:
+    return (inOut[pos].type == nvinfer1::DataType::kFLOAT ||
+            inOut[pos].type == nvinfer1::DataType::kINT32) &&
+           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
+  // output[0]
+  case 1:
+    return inOut[pos].type == inOut[0].type &&
+           inOut[pos].format == inOut[0].format;
+  // output[1]
+  case 2:
+    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
+           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
+  default:
+    return false;
+  }
+}
+
+void CumMaxMinPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
+    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}
+
+size_t CumMaxMinPluginDynamic::getWorkspaceSize(
+    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
+    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
+  return 0;
+}
+
+int CumMaxMinPluginDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc *inputDesc,
+    const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
+    void *const *outputs, void *workSpace, cudaStream_t stream) {
+
+  const void *input = inputs[0];
+  void *output_value = outputs[0];
+  int *output_index = (int *)outputs[1];
+
+  const int *dims = &(inputDesc[0].dims.d[0]);
+  int nbDims = inputDesc[0].dims.nbDims;
+
+  switch (inputDesc[0].type) {
+  case nvinfer1::DataType::kFLOAT:
+    CumMaxMinForwardLauncher_float((float *)input, (float *)output_value,
+                                   output_index, dims, nbDims, mDim,
+                                   int(mCumType), stream);
+    break;
+  case nvinfer1::DataType::kINT32:
+    CumMaxMinForwardLauncher_int32((int *)input, (int *)output_value,
+                                   output_index, dims, nbDims, mDim,
+                                   int(mCumType), stream);
+    break;
+  default:
+    break;
+  }
+
+  return 0;
+}
+
+nvinfer1::DataType CumMaxMinPluginDynamic::getOutputDataType(
+    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
+  switch (index) {
+  case 0:
+    return inputTypes[0];
+  case 1:
+    return nvinfer1::DataType::kINT32;
+  default:
+    return inputTypes[0];
+  }
+}
+
+// IPluginV2 Methods
+const char *CumMaxMinPluginDynamic::getPluginType() const {
+  switch (mCumType) {
+  case TRT_CUMCMPTYPE::TRT_CUMMAX:
+    return CUMMAX_PLUGIN_NAME;
+  case TRT_CUMCMPTYPE::TRT_CUMMIN:
+    return CUMMIN_PLUGIN_NAME;
+  default:
+    return "UnknownCumType";
+  }
+}
+
+const char *CumMaxMinPluginDynamic::getPluginVersion() const {
+  return PLUGIN_VERSION;
+}
+
+int CumMaxMinPluginDynamic::getNbOutputs() const { return 2; }
+
+int CumMaxMinPluginDynamic::initialize() { return 0; }
+
+void CumMaxMinPluginDynamic::terminate() {}
+
+size_t CumMaxMinPluginDynamic::getSerializationSize() const {
+  return sizeof(mDim) + sizeof(mCumType);
+}
+
+void CumMaxMinPluginDynamic::serialize(void *buffer) const {
+  serialize_value(&buffer, mDim);
+  serialize_value(&buffer, mCumType);
+}
+
+void CumMaxMinPluginDynamic::destroy() {
+  // This gets called when the network containing plugin is destroyed
+  delete this;
+}
+
+void CumMaxMinPluginDynamic::setPluginNamespace(const char *libNamespace) {
+  mNamespace = libNamespace;
+}
+
+const char *CumMaxMinPluginDynamic::getPluginNamespace() const {
+  return mNamespace.c_str();
+}
+
+CumMaxMinPluginDynamicCreator::CumMaxMinPluginDynamicCreator(
+    TRT_CUMCMPTYPE cumType)
+    : mCumType(cumType) {
+  mPluginAttributes.clear();
+  mPluginAttributes.emplace_back(nvinfer1::PluginField("dim"));
+  mFC.nbFields = mPluginAttributes.size();
+  mFC.fields = mPluginAttributes.data();
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginName() const {
+  return CUMMAXMIN_PLUGIN_NAME;
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginVersion() const {
+  return PLUGIN_VERSION;
+}
+
+const nvinfer1::PluginFieldCollection *
+CumMaxMinPluginDynamicCreator::getFieldNames() {
+  return &mFC;
+}
+
+nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::createPlugin(
+    const char *name, const nvinfer1::PluginFieldCollection *fc) {
+  int dim = 0;
+
+  for (int i = 0; i < fc->nbFields; i++) {
+    if (fc->fields[i].data == nullptr) {
+      continue;
+    }
+    std::string field_name(fc->fields[i].name);
+
+    if (field_name.compare("dim") == 0) {
+      dim = static_cast<const int *>(fc->fields[i].data)[0];
+    }
+  }
+
+  CumMaxMinPluginDynamic *plugin =
+      new CumMaxMinPluginDynamic(name, dim, mCumType);
+  plugin->setPluginNamespace(getPluginNamespace());
+  return plugin;
+}
+
+nvinfer1::IPluginV2 *CumMaxMinPluginDynamicCreator::deserializePlugin(
+    const char *name, const void *serialData, size_t serialLength) {
+  // This object will be deleted when the network is destroyed, which will
+  // call CumMaxMinPluginDynamic::destroy()
+  auto plugin = new CumMaxMinPluginDynamic(name, serialData, serialLength);
+  plugin->setPluginNamespace(getPluginNamespace());
+  return plugin;
+}
+
+void CumMaxMinPluginDynamicCreator::setPluginNamespace(
+    const char *libNamespace) {
+  mNamespace = libNamespace;
+}
+
+const char *CumMaxMinPluginDynamicCreator::getPluginNamespace() const {
+  return mNamespace.c_str();
+}
+
+CumMaxPluginDynamicCreator::CumMaxPluginDynamicCreator()
+    : CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMAX) {}
+
+const char *CumMaxPluginDynamicCreator::getPluginName() const {
+  return CUMMAX_PLUGIN_NAME;
+}
+
+CumMinPluginDynamicCreator::CumMinPluginDynamicCreator()
+    : CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE::TRT_CUMMIN) {}
+
+const char *CumMinPluginDynamicCreator::getPluginName() const {
+  return CUMMIN_PLUGIN_NAME;
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
new file mode 100644
index 0000000000..15780c1a83
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
@@ -0,0 +1,89 @@
+
+#include "common_cuda_helper.hpp"
+#include "trt_cuda_helper.cuh"
+#include "trt_plugin_helper.hpp"
+
+using mmcv::TensorDesc;
+
+template <typename scalar_t>
+__global__ void cummaxmin_kernel(const scalar_t *input, scalar_t *output_value,
+                                 int *output_index, TensorDesc tensor_desc,
+                                 int cum_dim, int cum_type) {
+  const size_t cum_size = tensor_desc.shape[cum_dim];
+  const size_t cum_stride = tensor_desc.stride[cum_dim];
+  const size_t data_size =
+      tensor_desc.stride[0] * tensor_desc.shape[0] / cum_size;
+  CUDA_1D_KERNEL_LOOP(index, data_size) {
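+    // `index` enumerates the slices that run along `cum_dim`: each group of
+    // `cum_stride` consecutive indices maps into one block of
+    // `cum_size * cum_stride` contiguous elements, which recovers the offset
+    // of the first element of the slice below.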
+    size_t cum_offset =
+        index / cum_stride * (cum_size * cum_stride) + index % cum_stride;
+    int cum_index = 0;
+    auto cum_value = input[cum_offset];
+    output_value[cum_offset] = cum_value;
+    output_index[cum_offset] = cum_index;
+
+    for (size_t cum_index_current = 1; cum_index_current < cum_size;
+         ++cum_index_current) {
+      cum_offset += cum_stride;
+      const auto cum_value_current = input[cum_offset];
+      switch (cum_type) {
+      case 0: // max
+        if (cum_value_current > cum_value) {
+          cum_value = cum_value_current;
+          cum_index = cum_index_current;
+        }
+        break;
+      case 1: // min
+        if (cum_value_current < cum_value) {
+          cum_value = cum_value_current;
+          cum_index = cum_index_current;
+        }
+        break;
+      }
+      output_value[cum_offset] = cum_value;
+      output_index[cum_offset] = cum_index;
+    }
+  }
+}
+
+template <typename scalar_t>
+void CumMaxMinForwardLauncher(const scalar_t *input, scalar_t *output_value,
+                              int *output_index, const int *dims, int nbDims,
+                              int cum_dim, int cum_type, cudaStream_t stream) {
+  // fill the tensor descriptor (shape and stride of each dimension)
+  TensorDesc tensor_desc;
+  memset((void *)&tensor_desc, 0, sizeof(TensorDesc));
+  tensor_desc.dim = nbDims;
+  tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
+  tensor_desc.stride[nbDims - 1] = 1;
+  for (int i = nbDims - 2; i >= 0; --i) {
+    tensor_desc.shape[i] = dims[i];
+    tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
+  }
+
+  // normalize a negative cum_dim into the valid range [0, nbDims)
+  cum_dim = cum_dim >= 0 ? cum_dim : (nbDims + cum_dim);
+
+  const int data_size =
+      tensor_desc.stride[0] * tensor_desc.shape[0] / tensor_desc.shape[cum_dim];
+
+  const int col_block = DIVUP(data_size, THREADS_PER_BLOCK);
+
+  cummaxmin_kernel<scalar_t><<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
+      input, output_value, output_index, tensor_desc, cum_dim, cum_type);
+}
+
+void CumMaxMinForwardLauncher_float(const float *input, float *output_value,
+                                    int *output_index, const int *dims,
+                                    int nbDims, int cum_dim, int cum_type,
+                                    cudaStream_t stream) {
+  CumMaxMinForwardLauncher<float>(input, output_value, output_index, dims,
+                                  nbDims, cum_dim, cum_type, stream);
+}
+
+void CumMaxMinForwardLauncher_int32(const int *input, int *output_value,
+                                    int *output_index, const int *dims,
+                                    int nbDims, int cum_dim, int cum_type,
+                                    cudaStream_t stream) {
+  CumMaxMinForwardLauncher<int>(input, output_value, output_index, dims, nbDims,
+                                cum_dim, cum_type, stream);
+}
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
index 06d034c365..3dbf315cfb 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
@@ -1,11 +1,14 @@
 #include "trt_plugin.hpp"
 
+#include "trt_cummaxmin.hpp"
 #include "trt_deform_conv.hpp"
 #include "trt_grid_sampler.hpp"
 #include "trt_nms.hpp"
 #include "trt_roi_align.hpp"
 #include "trt_scatternd.hpp"
 
+REGISTER_TENSORRT_PLUGIN(CumMaxPluginDynamicCreator);
+REGISTER_TENSORRT_PLUGIN(CumMinPluginDynamicCreator);
 REGISTER_TENSORRT_PLUGIN(GridSamplerDynamicCreator);
 REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
 REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
@@ -14,4 +17,4 @@ REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
 
 extern "C" {
 bool initLibMMCVInferPlugins() { return true; }
-}  // extern "C"
+} // extern "C"
diff --git a/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp b/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
new file mode 100644
index 0000000000..5665ba919f
--- /dev/null
+++ b/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
@@ -0,0 +1,124 @@
+#ifndef TRT_CUMMAXMIN_HPP
+#define TRT_CUMMAXMIN_HPP
+#include <string>
+#include <vector>
+
+#include "trt_plugin_helper.hpp"
+
+enum TRT_CUMCMPTYPE { TRT_CUMMAX = 0, TRT_CUMMIN = 1 };
+
+// implementation of cummax and cummin
+class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
+public:
+  CumMaxMinPluginDynamic(const std::string &name, int dim,
+                         TRT_CUMCMPTYPE cumType);
+
+  CumMaxMinPluginDynamic(const std::string name, const void *data,
+                         size_t length);
+
+  CumMaxMinPluginDynamic() = delete;
+
+  ~CumMaxMinPluginDynamic();
+
+  // IPluginV2DynamicExt Methods
+  nvinfer1::IPluginV2DynamicExt *clone() const override;
+  nvinfer1::DimsExprs
+  getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
+                      int nbInputs,
+                      nvinfer1::IExprBuilder &exprBuilder) override;
+  bool supportsFormatCombination(int pos,
+                                 const nvinfer1::PluginTensorDesc *inOut,
+                                 int nbInputs, int nbOutputs) override;
+  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in,
+                       int nbInputs,
+                       const nvinfer1::DynamicPluginTensorDesc *out,
+                       int nbOutputs) override;
+  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
+                          int nbInputs,
+                          const nvinfer1::PluginTensorDesc *outputs,
+                          int nbOutputs) const override;
+  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
+              const nvinfer1::PluginTensorDesc *outputDesc,
+              const void *const *inputs, void *const *outputs, void *workspace,
+              cudaStream_t stream) override;
+
+  // IPluginV2Ext Methods
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType *inputTypes,
+                                       int nbInputs) const override;
+
+  // IPluginV2 Methods
+  const char *getPluginType() const override;
+  const char *getPluginVersion() const override;
+  int getNbOutputs() const override;
+  int initialize() override;
+  void terminate() override;
+  size_t getSerializationSize() const override;
+  void serialize(void *buffer) const override;
+  void destroy() override;
+  void setPluginNamespace(const char *pluginNamespace) override;
+  const char *getPluginNamespace() const override;
+
+protected:
+  const std::string mLayerName;
+  std::string mNamespace;
+
+  int mDim;
+  TRT_CUMCMPTYPE mCumType;
+
+protected:
+  // To prevent compiler warnings.
+  using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
+  using nvinfer1::IPluginV2DynamicExt::configurePlugin;
+  using nvinfer1::IPluginV2DynamicExt::enqueue;
+  using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
+  using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
+  using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
+  using nvinfer1::IPluginV2DynamicExt::supportsFormat;
+};
+
+// cummax and cummin creator
+class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
+public:
+  CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE cumType);
+
+  const char *getPluginName() const override;
+
+  const char *getPluginVersion() const override;
+
+  const nvinfer1::PluginFieldCollection *getFieldNames() override;
+
+  nvinfer1::IPluginV2 *
+  createPlugin(const char *name,
+               const nvinfer1::PluginFieldCollection *fc) override;
+
+  nvinfer1::IPluginV2 *deserializePlugin(const char *name,
+                                         const void *serialData,
+                                         size_t serialLength) override;
+
+  void setPluginNamespace(const char *pluginNamespace) override;
+
+  const char *getPluginNamespace() const override;
+
+protected:
+  TRT_CUMCMPTYPE mCumType;
+  nvinfer1::PluginFieldCollection mFC;
+  std::vector<nvinfer1::PluginField> mPluginAttributes;
+  std::string mNamespace;
+};
+
+// cummax creator
+class CumMaxPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
+public:
+  CumMaxPluginDynamicCreator();
+  const char *getPluginName() const override;
+};
+
+// cummin creator
+class CumMinPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
+public:
+  CumMinPluginDynamicCreator();
+  const char *getPluginName() const override;
+};
+
+#endif // TRT_CUMMAXMIN_HPP
diff --git a/mmcv/tensorrt/tensorrt_utils.py b/mmcv/tensorrt/tensorrt_utils.py
index 5966881df2..22e1e8860a 100644
--- a/mmcv/tensorrt/tensorrt_utils.py
+++ b/mmcv/tensorrt/tensorrt_utils.py
@@ -238,7 +238,7 @@ class TRTWraper(torch.nn.Module):
         output_names should be the same as onnx model.
     """
 
-    def __init__(self, engine, input_names, output_names):
+    def __init__(self, engine, input_names=None, output_names=None):
         super(TRTWraper, self).__init__()
         self.engine = engine
         if isinstance(self.engine, str):
@@ -250,6 +250,11 @@ def __init__(self, engine, input_names, output_names):
         self._register_state_dict_hook(TRTWraper._on_state_dict)
         self.context = self.engine.create_execution_context()
 
+        # get input and output names from engine
+        if input_names is None or output_names is None:
+            names = [_ for _ in self.engine]
+            input_names = list(filter(self.engine.binding_is_input, names))
+            output_names = list(set(names) - set(input_names))
         self.input_names = input_names
        self.output_names = output_names
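The `TRTWraper` change above makes both name lists optional: when omitted, the
binding names are read back from the engine itself, inputs via
`binding_is_input` and the remaining bindings treated as outputs. A minimal
usage sketch (the engine path `model.trt` and the binding name `input` are
illustrative placeholders):

```python
import torch

from mmcv.tensorrt import TRTWraper

# names are recovered from the engine bindings when not given
trt_model = TRTWraper('model.trt')

with torch.no_grad():
    outputs = trt_model({'input': torch.rand(1, 3, 224, 224).cuda()})
```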
""" - def __init__(self, engine, input_names, output_names): + def __init__(self, engine, input_names=None, output_names=None): super(TRTWraper, self).__init__() self.engine = engine if isinstance(self.engine, str): @@ -250,6 +250,11 @@ def __init__(self, engine, input_names, output_names): self._register_state_dict_hook(TRTWraper._on_state_dict) self.context = self.engine.create_execution_context() + # get input and output names from engine + if input_names is None or output_names is None: + names = [_ for _ in self.engine] + input_names = list(filter(self.engine.binding_is_input, names)) + output_names = list(set(names) - set(input_names)) self.input_names = input_names self.output_names = output_names diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py index 3f8fe473c8..d76e3b228f 100644 --- a/tests/test_ops/test_tensorrt.py +++ b/tests/test_ops/test_tensorrt.py @@ -6,6 +6,7 @@ import pytest import torch import torch.nn as nn +from typing import Callable try: from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt, @@ -478,3 +479,99 @@ def func(input, grid): if os.path.exists(trt_file): os.remove(trt_file) assert torch.allclose(pytorch_results, trt_results) + + +@pytest.mark.parametrize('func', [torch.cummax, torch.cummin]) +def test_cummin_cummax(func: Callable): + # Note generally `cummax` or `cummin` is exportable to ONNX + # as long as the pytorch version >= 1.5.0, since `torch.cummax` + # is only supported with torch >= 1.5.0. + # But when `cummax` or `cummin` serves as an intermediate component + # whose outputs is used as inputs for another modules, it's expected + # that pytorch version must be >= 1.7.0. Otherwise error appears like: + # `RuntimeError: tuple appears in op that does not forward tuples, + # unsupported 'kind: prim::PythonOp`. + from packaging import version + if version.parse(torch.__version__) < version.parse('1.7.0'): + pytest.skip('test_cummax_cummin should be ran with pytorch >= 1.7.0') + + opset = 11 + # register custom op `mmcv::cummax` and `mmcv::cummin` + from mmcv.onnx.symbolic import register_extra_symbolics + register_extra_symbolics(opset) + + input_list = [ + # arbitrary shape, e.g. 1-D, 2-D, 3-D, ... 
+ torch.rand((2, 3, 4, 1, 5)).cuda(), + torch.rand((1)).cuda() + ] + + input_names = ['input'] + output_names = ['output', 'indices'] + + for input in input_list: + ndims = input.dim() + # valid dim range is [-ndims, ndims-1] + # test for all `dim` value which is valid + for dim in range(-ndims, ndims): + cummax_func = partial(func, dim=dim) + wrapped_model = WrapFunction(cummax_func).eval().cuda() + + with torch.no_grad(): + torch.onnx.export( + wrapped_model, + input, + onnx_file, + export_params=True, + keep_initializers_as_inputs=False, + input_names=input_names, + output_names=output_names, + opset_version=opset) + + onnx_model = onnx.load(onnx_file) + + # create trt engine and wraper + opt_shape_dict = { + 'input': + [list(input.shape), + list(input.shape), + list(input.shape)] + } + # trt config + fp16_mode = False + max_workspace_size = 1 << 30 + + trt_engine = onnx2trt( + onnx_model, + opt_shape_dict, + fp16_mode=fp16_mode, + max_workspace_size=max_workspace_size) + + # remove ONNX model after conversion + if os.path.exists(onnx_file): + os.remove(onnx_file) + + # save TensorRT model + save_trt_engine(trt_engine, trt_file) + + # load and wrap TensorRT model + trt_model = TRTWraper(trt_file) + + # remove trt model after loading + if os.path.exists(trt_file): + os.remove(trt_file) + + # compute trt output + with torch.no_grad(): + trt_results = trt_model({'input': input.contiguous().clone()}) + trt_output = trt_results['output'] + trt_indices = trt_results['indices'] + + # compute pytorch output + with torch.no_grad(): + pytorch_results = wrapped_model(input.clone()) + pytorch_output = pytorch_results[0] + pytorch_indices = pytorch_results[1] + + torch.testing.assert_allclose(trt_output, pytorch_output) + torch.testing.assert_allclose(trt_indices, pytorch_indices) From 1f8dd073bad77987f8394fbf61d35540438a35c5 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 18 May 2021 19:54:10 +0800 Subject: [PATCH 2/5] fix isort --- tests/test_ops/test_tensorrt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py index d76e3b228f..ddfa68165a 100644 --- a/tests/test_ops/test_tensorrt.py +++ b/tests/test_ops/test_tensorrt.py @@ -1,12 +1,12 @@ import os from functools import partial +from typing import Callable import numpy as np import onnx import pytest import torch import torch.nn as nn -from typing import Callable try: from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt, From 8d67401917e5a4df1735b9e97d55ad35104c7bf1 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Wed, 19 May 2021 15:46:11 +0800 Subject: [PATCH 3/5] fix with clang-format --- .../csrc/tensorrt/plugins/trt_cummaxmin.cpp | 81 +++++++++---------- .../tensorrt/plugins/trt_cummaxmin_kernel.cu | 24 +++--- mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp | 28 +++---- 3 files changed, 65 insertions(+), 68 deletions(-) diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp index 662426dfba..2e920cfed0 100644 --- a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp +++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp @@ -19,7 +19,7 @@ static const char *PLUGIN_VERSION{"1"}; static const char *CUMMAXMIN_PLUGIN_NAME{"cummaxmin"}; static const char *CUMMAX_PLUGIN_NAME{"cummax"}; static const char *CUMMIN_PLUGIN_NAME{"cummin"}; -} // namespace +} // namespace CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string &name, int dim, TRT_CUMCMPTYPE cumType) @@ -52,21 +52,21 @@ bool 
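Taken together, patch 1 boils down to the round trip below. This is a
condensed sketch of what the new test exercises, not a drop-in script: it
assumes a CUDA device, a working TensorRT setup with the compiled mmcv plugin
library loaded, and the file names are placeholders.

```python
import onnx
import torch

from mmcv.onnx.symbolic import register_extra_symbolics
from mmcv.tensorrt import TRTWraper, onnx2trt, save_trt_engine

register_extra_symbolics(11)  # maps torch.cummax/cummin to mmcv::cummax/cummin


class CumMax(torch.nn.Module):

    def forward(self, x):
        return torch.cummax(x, dim=1)


x = torch.rand(2, 3, 4).cuda()
torch.onnx.export(
    CumMax().eval().cuda(),
    x,
    'tmp.onnx',
    input_names=['input'],
    output_names=['output', 'indices'],
    opset_version=11)

# static shapes: min == opt == max
opt_shape_dict = {'input': [list(x.shape)] * 3}
trt_engine = onnx2trt(
    onnx.load('tmp.onnx'),
    opt_shape_dict,
    fp16_mode=False,
    max_workspace_size=1 << 30)
save_trt_engine(trt_engine, 'tmp.trt')

trt_model = TRTWraper('tmp.trt')  # names auto-detected from the engine
with torch.no_grad():
    results = trt_model({'input': x.contiguous()})
values, indices = results['output'], results['indices']
```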
From 1f8dd073bad77987f8394fbf61d35540438a35c5 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Tue, 18 May 2021 19:54:10 +0800
Subject: [PATCH 2/5] fix isort

---
 tests/test_ops/test_tensorrt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_ops/test_tensorrt.py b/tests/test_ops/test_tensorrt.py
index d76e3b228f..ddfa68165a 100644
--- a/tests/test_ops/test_tensorrt.py
+++ b/tests/test_ops/test_tensorrt.py
@@ -1,12 +1,12 @@
 import os
 from functools import partial
+from typing import Callable
 
 import numpy as np
 import onnx
 import pytest
 import torch
 import torch.nn as nn
-from typing import Callable
 
 try:
     from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,

From 8d67401917e5a4df1735b9e97d55ad35104c7bf1 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Wed, 19 May 2021 15:46:11 +0800
Subject: [PATCH 3/5] fix with clang-format

---
 .../csrc/tensorrt/plugins/trt_cummaxmin.cpp   | 81 +++++++++----------
 .../tensorrt/plugins/trt_cummaxmin_kernel.cu  | 24 +++---
 mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp      | 28 +++----
 3 files changed, 65 insertions(+), 68 deletions(-)

diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
index 662426dfba..2e920cfed0 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin.cpp
@@ -19,7 +19,7 @@ static const char *PLUGIN_VERSION{"1"};
 static const char *CUMMAXMIN_PLUGIN_NAME{"cummaxmin"};
 static const char *CUMMAX_PLUGIN_NAME{"cummax"};
 static const char *CUMMIN_PLUGIN_NAME{"cummin"};
-} // namespace
+}  // namespace
 
 CumMaxMinPluginDynamic::CumMaxMinPluginDynamic(const std::string &name, int dim,
                                                TRT_CUMCMPTYPE cumType)
@@ -52,21 +52,21 @@ bool CumMaxMinPluginDynamic::supportsFormatCombination(
     int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
     int nbOutputs) {
   switch (pos) {
-  // input[0]
-  case 0:
-    return (inOut[pos].type == nvinfer1::DataType::kFLOAT ||
-            inOut[pos].type == nvinfer1::DataType::kINT32) &&
-           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
-  // output[0]
-  case 1:
-    return inOut[pos].type == inOut[0].type &&
-           inOut[pos].format == inOut[0].format;
-  // output[1]
-  case 2:
-    return inOut[pos].type == nvinfer1::DataType::kINT32 &&
-           inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
-  default:
-    return false;
+    // input[0]
+    case 0:
+      return (inOut[pos].type == nvinfer1::DataType::kFLOAT ||
+              inOut[pos].type == nvinfer1::DataType::kINT32) &&
+             inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
+    // output[0]
+    case 1:
+      return inOut[pos].type == inOut[0].type &&
+             inOut[pos].format == inOut[0].format;
+    // output[1]
+    case 2:
+      return inOut[pos].type == nvinfer1::DataType::kINT32 &&
+             inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
+    default:
+      return false;
   }
 }
@@ -84,7 +84,6 @@ int CumMaxMinPluginDynamic::enqueue(
     const nvinfer1::PluginTensorDesc *inputDesc,
     const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
     void *const *outputs, void *workSpace, cudaStream_t stream) {
-
   const void *input = inputs[0];
   void *output_value = outputs[0];
   int *output_index = (int *)outputs[1];
@@ -93,18 +92,18 @@ int CumMaxMinPluginDynamic::enqueue(
   int nbDims = inputDesc[0].dims.nbDims;
 
   switch (inputDesc[0].type) {
-  case nvinfer1::DataType::kFLOAT:
-    CumMaxMinForwardLauncher_float((float *)input, (float *)output_value,
-                                   output_index, dims, nbDims, mDim,
-                                   int(mCumType), stream);
-    break;
-  case nvinfer1::DataType::kINT32:
-    CumMaxMinForwardLauncher_int32((int *)input, (int *)output_value,
-                                   output_index, dims, nbDims, mDim,
-                                   int(mCumType), stream);
-    break;
-  default:
-    break;
+    case nvinfer1::DataType::kFLOAT:
+      CumMaxMinForwardLauncher_float((float *)input, (float *)output_value,
+                                     output_index, dims, nbDims, mDim,
+                                     int(mCumType), stream);
+      break;
+    case nvinfer1::DataType::kINT32:
+      CumMaxMinForwardLauncher_int32((int *)input, (int *)output_value,
+                                     output_index, dims, nbDims, mDim,
+                                     int(mCumType), stream);
+      break;
+    default:
+      break;
   }
 
   return 0;
@@ -113,24 +112,24 @@ int CumMaxMinPluginDynamic::enqueue(
 nvinfer1::DataType CumMaxMinPluginDynamic::getOutputDataType(
     int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
   switch (index) {
-  case 0:
-    return inputTypes[0];
-  case 1:
-    return nvinfer1::DataType::kINT32;
-  default:
-    return inputTypes[0];
+    case 0:
+      return inputTypes[0];
+    case 1:
+      return nvinfer1::DataType::kINT32;
+    default:
+      return inputTypes[0];
   }
 }
 
 // IPluginV2 Methods
 const char *CumMaxMinPluginDynamic::getPluginType() const {
   switch (mCumType) {
-  case TRT_CUMCMPTYPE::TRT_CUMMAX:
-    return CUMMAX_PLUGIN_NAME;
-  case TRT_CUMCMPTYPE::TRT_CUMMIN:
-    return CUMMIN_PLUGIN_NAME;
-  default:
-    return "UnknownCumType";
+    case TRT_CUMCMPTYPE::TRT_CUMMAX:
+      return CUMMAX_PLUGIN_NAME;
+    case TRT_CUMCMPTYPE::TRT_CUMMIN:
+      return CUMMIN_PLUGIN_NAME;
+    default:
+      return "UnknownCumType";
   }
 }
diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
index 15780c1a83..753104071f 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_cummaxmin_kernel.cu
@@ -26,18 +26,18 @@ __global__ void cummaxmin_kernel(const scalar_t *input, scalar_t *output_value,
       cum_offset += cum_stride;
       const auto cum_value_current = input[cum_offset];
       switch (cum_type) {
-      case 0: // max
-        if (cum_value_current > cum_value) {
-          cum_value = cum_value_current;
-          cum_index = cum_index_current;
-        }
-        break;
-      case 1: // min
-        if (cum_value_current < cum_value) {
-          cum_value = cum_value_current;
-          cum_index = cum_index_current;
-        }
-        break;
+        case 0:  // max
+          if (cum_value_current > cum_value) {
+            cum_value = cum_value_current;
+            cum_index = cum_index_current;
+          }
+          break;
+        case 1:  // min
+          if (cum_value_current < cum_value) {
+            cum_value = cum_value_current;
+            cum_index = cum_index_current;
+          }
+          break;
       }
       output_value[cum_offset] = cum_value;
       output_index[cum_offset] = cum_index;
diff --git a/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp b/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
index 5665ba919f..5b856b02fb 100644
--- a/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
+++ b/mmcv/ops/csrc/tensorrt/trt_cummaxmin.hpp
@@ -9,7 +9,7 @@ enum TRT_CUMCMPTYPE { TRT_CUMMAX = 0, TRT_CUMMIN = 1 };
 
 // implementation of cummax and cummin
 class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
-public:
+ public:
   CumMaxMinPluginDynamic(const std::string &name, int dim,
                          TRT_CUMCMPTYPE cumType);
 
@@ -22,10 +22,9 @@ class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
 
   // IPluginV2DynamicExt Methods
   nvinfer1::IPluginV2DynamicExt *clone() const override;
-  nvinfer1::DimsExprs
-  getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
-                      int nbInputs,
-                      nvinfer1::IExprBuilder &exprBuilder) override;
+  nvinfer1::DimsExprs getOutputDimensions(
+      int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
+      nvinfer1::IExprBuilder &exprBuilder) override;
   bool supportsFormatCombination(int pos,
                                  const nvinfer1::PluginTensorDesc *inOut,
                                  int nbInputs, int nbOutputs) override;
@@ -59,14 +58,14 @@ class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
   void setPluginNamespace(const char *pluginNamespace) override;
   const char *getPluginNamespace() const override;
 
-protected:
+ protected:
   const std::string mLayerName;
   std::string mNamespace;
 
   int mDim;
   TRT_CUMCMPTYPE mCumType;
 
-protected:
+ protected:
   // To prevent compiler warnings.
   using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
   using nvinfer1::IPluginV2DynamicExt::configurePlugin;
@@ -79,7 +78,7 @@ class CumMaxMinPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
 
 // cummax and cummin creator
 class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
-public:
+ public:
   CumMaxMinPluginDynamicCreator(TRT_CUMCMPTYPE cumType);
 
   const char *getPluginName() const override;
@@ -88,9 +87,8 @@ class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
 
   const nvinfer1::PluginFieldCollection *getFieldNames() override;
 
-  nvinfer1::IPluginV2 *
-  createPlugin(const char *name,
-               const nvinfer1::PluginFieldCollection *fc) override;
+  nvinfer1::IPluginV2 *createPlugin(
+      const char *name, const nvinfer1::PluginFieldCollection *fc) override;
 
   nvinfer1::IPluginV2 *deserializePlugin(const char *name,
                                          const void *serialData,
@@ -100,7 +98,7 @@ class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
 
   const char *getPluginNamespace() const override;
 
-protected:
+ protected:
   TRT_CUMCMPTYPE mCumType;
   nvinfer1::PluginFieldCollection mFC;
   std::vector<nvinfer1::PluginField> mPluginAttributes;
@@ -109,16 +107,16 @@ class CumMaxMinPluginDynamicCreator : public nvinfer1::IPluginCreator {
 
 // cummax creator
 class CumMaxPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
-public:
+ public:
   CumMaxPluginDynamicCreator();
   const char *getPluginName() const override;
 };
 
 // cummin creator
 class CumMinPluginDynamicCreator : public CumMaxMinPluginDynamicCreator {
-public:
+ public:
   CumMinPluginDynamicCreator();
   const char *getPluginName() const override;
 };
 
-#endif // TRT_CUMMAXMIN_HPP
+#endif  // TRT_CUMMAXMIN_HPP

From fead485d22345c3efbd73b12563629927cb84558 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Wed, 19 May 2021 15:52:37 +0800
Subject: [PATCH 4/5] fix with clang-format again

---
 mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
index 3dbf315cfb..ab4ee11e81 100644
--- a/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
+++ b/mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
@@ -17,4 +17,4 @@ REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
 
 extern "C" {
 bool initLibMMCVInferPlugins() { return true; }
-} // extern "C"
+}  // extern "C"

From 824d29d3f50ce4c5fe33d577f13675ea534a72aa Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Thu, 20 May 2021 16:56:52 +0800
Subject: [PATCH 5/5] add document

---
 docs/tensorrt_custom_ops.md | 76 +++++++++++++++++++++++++++++++++++++
 docs/tensorrt_plugin.md     |  4 +-
 2 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/docs/tensorrt_custom_ops.md b/docs/tensorrt_custom_ops.md
index da696f03e9..7bf369cfb7 100644
--- a/docs/tensorrt_custom_ops.md
+++ b/docs/tensorrt_custom_ops.md
@@ -33,6 +33,18 @@
     - [Inputs](#inputs-4)
     - [Outputs](#outputs-4)
     - [Type Constraints](#type-constraints-4)
+  - [cummax](#cummax)
+    - [Description](#description-5)
+    - [Parameters](#parameters-5)
+    - [Inputs](#inputs-5)
+    - [Outputs](#outputs-5)
+    - [Type Constraints](#type-constraints-5)
+  - [cummin](#cummin)
+    - [Description](#description-6)
+    - [Parameters](#parameters-6)
+    - [Inputs](#inputs-6)
+    - [Outputs](#outputs-6)
+    - [Type Constraints](#type-constraints-6)
@@ -227,3 +239,67 @@ Perform sample from `input` with pixel locations from `grid`.
 ### Type Constraints
 
 - T:tensor(float32, Linear)
+
+## cummax
+
+### Description
+
+Returns a namedtuple (`values`, `indices`) where `values` is the cumulative maximum of elements of `input` in the dimension `dim`, and `indices` is the index location of each maximum value found in the dimension `dim`.
+
+### Parameters
+
+| Type  | Parameter | Description                             |
+| ----- | --------- | --------------------------------------- |
+| `int` | `dim`     | The dimension to do the operation over. |
+
+### Inputs
+
+<dl>
+<dt><tt>inputs[0]</tt>: T</dt>
+<dd>The input tensor.</dd>
+</dl>
+
+### Outputs
+
+<dl>
+<dt><tt>outputs[0]</tt>: T</dt>
+<dd>Output values.</dd>
+<dt><tt>outputs[1]</tt>: (int32, Linear)</dt>
+<dd>Output indices.</dd>
+</dl>
+
+### Type Constraints
+
+- T:tensor(float32, Linear)
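+
+### Example
+
+For reference, the two outputs mirror `torch.cummax`; a quick PyTorch sketch
+(the shape is arbitrary):
+
+```python
+import torch
+
+x = torch.rand(2, 3, 4)
+# outputs[0] corresponds to `values`, outputs[1] to `indices`
+values, indices = torch.cummax(x, dim=1)
+assert values.shape == x.shape and indices.shape == x.shape
+```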
+
+## cummin
+
+### Description
+
+Returns a namedtuple (`values`, `indices`) where `values` is the cumulative minimum of elements of `input` in the dimension `dim`, and `indices` is the index location of each minimum value found in the dimension `dim`.
+
+### Parameters
+
+| Type  | Parameter | Description                             |
+| ----- | --------- | --------------------------------------- |
+| `int` | `dim`     | The dimension to do the operation over. |
+
+### Inputs
+
+<dl>
+<dt><tt>inputs[0]</tt>: T</dt>
+<dd>The input tensor.</dd>
+</dl>
+
+### Outputs
+
+<dl>
+<dt><tt>outputs[0]</tt>: T</dt>
+<dd>Output values.</dd>
+<dt><tt>outputs[1]</tt>: (int32, Linear)</dt>
+<dd>Output indices.</dd>
+</dl>
+
+### Type Constraints
+
+- T:tensor(float32, Linear)
diff --git a/docs/tensorrt_plugin.md b/docs/tensorrt_plugin.md
index 5ed62d1ba3..63b5300000 100644
--- a/docs/tensorrt_plugin.md
+++ b/docs/tensorrt_plugin.md
@@ -30,7 +30,9 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
 | ScatterND         | [ScatterND](./tensorrt_custom_ops.md#scatternd)                 | 1.2.6               |
 | NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0               |
 | MMCVDeformConv2d  | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d)   | 1.3.0               |
-| grid_sampler      | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler)           | master              |
+| grid_sampler      | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler)           | 1.3.1               |
+| cummax            | [cummax](./tensorrt_custom_ops.md#cummax)                       | master              |
+| cummin            | [cummin](./tensorrt_custom_ops.md#cummin)                       | master              |
 
 Notes