Skip to content

Commit

Permalink
[Enhancement] TensorRT DCN support (#205)
Browse files Browse the repository at this point in the history
* add tensorrt dcn support

* fix lint
  • Loading branch information
q.yao authored Mar 21, 2022
1 parent 167400c commit ea0d0cf
Show file tree
Hide file tree
Showing 9 changed files with 747 additions and 5 deletions.
1 change: 1 addition & 0 deletions csrc/backend_ops/tensorrt/common/common_cuda_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <cublas_v2.h>
#include <cuda.h>
#include <stdio.h>

#include <algorithm>

Expand Down
10 changes: 5 additions & 5 deletions csrc/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
using mmdeploy::TensorDesc;

template <class scalar_t>
__global__ void copy_permute_kernel(scalar_t *dst, const scalar_t *src, int n,
TensorDesc ts_src_stride, TensorDesc ts_dst_stride,
__global__ void copy_permute_kernel(scalar_t *__restrict__ dst, const scalar_t *__restrict__ src,
int n, TensorDesc ts_src_stride, TensorDesc ts_dst_stride,
TensorDesc ts_permute) {
const int src_dim = ts_src_stride.dim;
int *src_stride = &(ts_src_stride.stride[0]);
int *dst_stride = &(ts_dst_stride.stride[0]);
int *permute = &(ts_permute.shape[0]);
const auto src_stride = ts_src_stride.stride;
const auto dst_stride = ts_dst_stride.stride;
const auto permute = ts_permute.shape;
CUDA_1D_KERNEL_LOOP(index, n) {
size_t dst_index = index;
size_t src_index = 0;
Expand Down
257 changes: 257 additions & 0 deletions csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_deform_conv.hpp"

#include <assert.h>

#include <chrono>

#include "trt_deform_conv_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"MMCVDeformConv2d"};
} // namespace

DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string &name,
const nvinfer1::Dims stride,
const nvinfer1::Dims padding,
const nvinfer1::Dims dilation,
const int deformableGroup, const int group)
: TRTPluginBase(name),
mStride(stride),
mPadding(padding),
mDilation(dilation),
mDeformableGroup(deformableGroup),
mGroup(group) {}

DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name, const void *data,
size_t length)
: TRTPluginBase(name) {
deserialize_value(&data, &length, &mStride);
deserialize_value(&data, &length, &mPadding);
deserialize_value(&data, &length, &mDilation);
deserialize_value(&data, &length, &mDeformableGroup);
deserialize_value(&data, &length, &mGroup);
}
DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {}

nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const TRT_NOEXCEPT {
DeformableConvPluginDynamic *plugin = new DeformableConvPluginDynamic(
mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup);
plugin->setPluginNamespace(getPluginNamespace());

return plugin;
}

nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
// input[0] == input
// input[1] == offset
// input[2] == weight
nvinfer1::DimsExprs ret;
ret.nbDims = 4;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[2].d[0];

ret.d[2] = inputs[1].d[2];
ret.d[3] = inputs[1].d[3];

return ret;
}

bool DeformableConvPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
if (pos == 0) {
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else {
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
}
}

void DeformableConvPluginDynamic::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs,
int nbOutputs) TRT_NOEXCEPT {}

size_t DeformableConvPluginDynamic::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT {
int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type);

int batch_size = inputs[0].dims.d[0];
int nInputPlane = inputs[0].dims.d[1];
int inputHeight = inputs[0].dims.d[2];
int inputWidth = inputs[0].dims.d[3];

int nOutputPlane = outputs[0].dims.d[1];
int outputHeight = outputs[0].dims.d[2];
int outputWidth = outputs[0].dims.d[3];

int kW = inputs[2].dims.d[2];
int kH = inputs[2].dims.d[3];
int im2col_step = std::min(32, batch_size);

size_t col_size = mmdeploy::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight *
outputWidth * sizeof_dtype);

size_t out_size = 0;
if (im2col_step != 1)
out_size = mmdeploy::getAlignedSize(batch_size * nOutputPlane * outputHeight * outputWidth *
sizeof_dtype);

return col_size + out_size;
}

int DeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs, void *const *outputs,
void *workSpace, cudaStream_t stream) TRT_NOEXCEPT {
int batch = inputDesc[0].dims.d[0];
int channels = inputDesc[0].dims.d[1];
int height = inputDesc[0].dims.d[2];
int width = inputDesc[0].dims.d[3];
int channels_out = outputDesc[0].dims.d[1];
int kernel_h = inputDesc[2].dims.d[2];
int kernel_w = inputDesc[2].dims.d[3];

const void *x = inputs[0];
const void *offset = inputs[1];
const void *weight = inputs[2];
void *output = outputs[0];
int im2col_step = std::min(batch, 32);

auto data_type = inputDesc[0].type;
switch (data_type) {
case nvinfer1::DataType::kFLOAT:
deform_conv<float>((float *)x, (float *)weight, (float *)offset, (float *)output, workSpace,
batch, channels, height, width, channels_out, kernel_w, kernel_h,
mStride.d[0], mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0],
mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle,
stream);
break;
default:
return 1;
break;
}

return 0;
}

nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
return inputTypes[0];
}

// IPluginV2 Methods
const char *DeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *DeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}

int DeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }

size_t DeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) +
serialized_size(mDeformableGroup) + serialized_size(mGroup);
}

void DeformableConvPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
serialize_value(&buffer, mStride);
serialize_value(&buffer, mPadding);
serialize_value(&buffer, mDilation);
serialize_value(&buffer, mDeformableGroup);
serialize_value(&buffer, mGroup);
}

void DeformableConvPluginDynamic::attachToContext(
cudnnContext *cudnnContext, cublasContext *cublasContext,
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {
m_cublas_handle = cublasContext;
}

void DeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {}

////////////////////// creator /////////////////////////////

DeformableConvPluginDynamicCreator::DeformableConvPluginDynamicCreator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("padding"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("groups"));
mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

const char *DeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
return PLUGIN_NAME;
}

const char *DeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT {
return PLUGIN_VERSION;
}

nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
nvinfer1::Dims stride{2, {1, 1}};
nvinfer1::Dims padding{2, {0, 0}};
nvinfer1::Dims dilation{2, {1, 1}};
int deformableGroup = 1;
int group = 1;

for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);

if (field_name.compare("deform_groups") == 0) {
deformableGroup = static_cast<const int *>(fc->fields[i].data)[0];
}

if (field_name.compare("groups") == 0) {
group = static_cast<const int *>(fc->fields[i].data)[0];
}

if (field_name.compare("stride") == 0) {
stride.nbDims = 2;
stride.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
stride.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}

if (field_name.compare("padding") == 0) {
padding.nbDims = 2;
padding.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
padding.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}

if (field_name.compare("dilation") == 0) {
dilation.nbDims = 2;
dilation.d[0] = static_cast<const int *>(fc->fields[i].data)[0];
dilation.d[1] = static_cast<const int *>(fc->fields[i].data)[1];
}
}

DeformableConvPluginDynamic *plugin =
new DeformableConvPluginDynamic(name, stride, padding, dilation, deformableGroup, group);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
auto plugin = new DeformableConvPluginDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}
REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator);
} // namespace mmdeploy
81 changes: 81 additions & 0 deletions csrc/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_DEFORM_CONV_HPP
#define TRT_DEFORM_CONV_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
class DeformableConvPluginDynamic : public TRTPluginBase {
public:
DeformableConvPluginDynamic(const std::string &name, const nvinfer1::Dims stride,
const nvinfer1::Dims padding, const nvinfer1::Dims dilation,
const int deformableGroup, const int group);

DeformableConvPluginDynamic(const std::string name, const void *data, size_t length);

DeformableConvPluginDynamic() = delete;

~DeformableConvPluginDynamic() TRT_NOEXCEPT override;

// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
int nbOutputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *out,
int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs,
int nbOutputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override;
void detachFromContext() TRT_NOEXCEPT override;

// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
int nbInputs) const TRT_NOEXCEPT override;

// IPluginV2 Methods
const char *getPluginType() const TRT_NOEXCEPT override;
const char *getPluginVersion() const TRT_NOEXCEPT override;
int getNbOutputs() const TRT_NOEXCEPT override;
size_t getSerializationSize() const TRT_NOEXCEPT override;
void serialize(void *buffer) const TRT_NOEXCEPT override;

private:
nvinfer1::Dims mStride;
nvinfer1::Dims mPadding;
nvinfer1::Dims mDilation;
int mDeformableGroup;
int mGroup;

cublasHandle_t m_cublas_handle;
};

class DeformableConvPluginDynamicCreator : public TRTPluginCreatorBase {
public:
DeformableConvPluginDynamicCreator();

const char *getPluginName() const TRT_NOEXCEPT override;

const char *getPluginVersion() const TRT_NOEXCEPT override;

nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
TRT_NOEXCEPT override;

nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_DEFORM_CONV_HPP
Loading

0 comments on commit ea0d0cf

Please sign in to comment.