Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tensorrt custom plugin MMCVCornerPool #1179

Merged
merged 1 commit into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 216 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
#include "trt_corner_pool.hpp"

#include <assert.h>

#include "trt_serialize.hpp"

void CornerPoolForwardLauncher_float(const float *input, float *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type, cudaStream_t stream);

namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *CORNER_POOL_PLUGIN_NAME{"MMCVCornerPool"};
} // namespace

CornerPoolPluginDynamic::CornerPoolPluginDynamic(const std::string &name,
TRT_CORNER_POOL_TYPE poolType)
: mLayerName(name), mPoolType(poolType) {}

CornerPoolPluginDynamic::CornerPoolPluginDynamic(const std::string name,
const void *data,
size_t length)
: mLayerName(name) {
deserialize_value(&data, &length, &mPoolType);
}

CornerPoolPluginDynamic::~CornerPoolPluginDynamic() {}

nvinfer1::IPluginV2DynamicExt *CornerPoolPluginDynamic::clone() const {
CornerPoolPluginDynamic *plugin =
new CornerPoolPluginDynamic(mLayerName, mPoolType);
plugin->setPluginNamespace(getPluginNamespace());

return plugin;
}

nvinfer1::DimsExprs CornerPoolPluginDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) {
return inputs[0];
}

bool CornerPoolPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) {
switch (pos) {
// input[0]
case 0:
return inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
// output[0]
case 1:
return inOut[pos].type == inOut[0].type &&
inOut[pos].format == inOut[0].format;
default:
return false;
}
}

void CornerPoolPluginDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}

size_t CornerPoolPluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
int sizeof_dtype = mmcv::getElementSize(outputs[0].type);
}

int CornerPoolPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
void *const *outputs, void *workSpace, cudaStream_t stream) {
const void *input = inputs[0];
void *output_value = outputs[0];

const int batch_size = inputDesc[0].dims.d[0];
const int channels = inputDesc[0].dims.d[1];
const int height = inputDesc[0].dims.d[2];
const int width = inputDesc[0].dims.d[3];

CornerPoolForwardLauncher_float((float *)input, (float *)output_value,
batch_size, channels, height, width,
int(mPoolType), stream);

return 0;
}

nvinfer1::DataType CornerPoolPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
return inputTypes[0];
}

// IPluginV2 Methods
const char *CornerPoolPluginDynamic::getPluginType() const {
switch (mPoolType) {
case TRT_CORNER_POOL_TYPE::TRT_TOP_POOL:
case TRT_CORNER_POOL_TYPE::TRT_BOTTOM_POOL:
case TRT_CORNER_POOL_TYPE::TRT_LEFT_POOL:
case TRT_CORNER_POOL_TYPE::TRT_RIGHT_POOL:
return CORNER_POOL_PLUGIN_NAME;

default:
return "UnknownpoolType";
}
}

const char *CornerPoolPluginDynamic::getPluginVersion() const {
return PLUGIN_VERSION;
}

int CornerPoolPluginDynamic::getNbOutputs() const { return 1; }

int CornerPoolPluginDynamic::initialize() { return 0; }

void CornerPoolPluginDynamic::terminate() {}

size_t CornerPoolPluginDynamic::getSerializationSize() const {
return sizeof(mPoolType);
}

void CornerPoolPluginDynamic::serialize(void *buffer) const {
serialize_value(&buffer, mPoolType);
}

void CornerPoolPluginDynamic::destroy() {
// This gets called when the network containing plugin is destroyed
delete this;
}

void CornerPoolPluginDynamic::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}

const char *CornerPoolPluginDynamic::getPluginNamespace() const {
return mNamespace.c_str();
}

CornerPoolPluginDynamicCreator::CornerPoolPluginDynamicCreator() {
mPluginAttributes.clear();
mPluginAttributes.emplace_back(nvinfer1::PluginField("mode"));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

const char *CornerPoolPluginDynamicCreator::getPluginName() const {
return CORNER_POOL_PLUGIN_NAME;
}

const char *CornerPoolPluginDynamicCreator::getPluginVersion() const {
return PLUGIN_VERSION;
}

const nvinfer1::PluginFieldCollection *
CornerPoolPluginDynamicCreator::getFieldNames() {
return &mFC;
}

nvinfer1::IPluginV2 *CornerPoolPluginDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) {
TRT_CORNER_POOL_TYPE poolType;
int poolMode = -1;

for (int i = 0; i < fc->nbFields; i++) {
if (fc->fields[i].data == nullptr) {
continue;
}
std::string field_name(fc->fields[i].name);

if (field_name.compare("mode") == 0) {
poolMode = static_cast<const int *>(fc->fields[i].data)[0];
}
}

assert(poolMode >= 0 && poolMode <= 3);
switch (poolMode) {
case 0:
poolType = TRT_CORNER_POOL_TYPE::TRT_TOP_POOL;
break;
case 1:
poolType = TRT_CORNER_POOL_TYPE::TRT_BOTTOM_POOL;
break;
case 2:
poolType = TRT_CORNER_POOL_TYPE::TRT_LEFT_POOL;
break;
case 3:
poolType = TRT_CORNER_POOL_TYPE::TRT_RIGHT_POOL;
break;

default:
break;
}

CornerPoolPluginDynamic *plugin = new CornerPoolPluginDynamic(name, poolType);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

nvinfer1::IPluginV2 *CornerPoolPluginDynamicCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) {
// This object will be deleted when the network is destroyed, which will
// call FCPluginDynamic::destroy()
auto plugin = new CornerPoolPluginDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

void CornerPoolPluginDynamicCreator::setPluginNamespace(
const char *libNamespace) {
mNamespace = libNamespace;
}

const char *CornerPoolPluginDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}
109 changes: 109 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_corner_pool_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"

template <typename scalar_t>
__global__ void top_bottom_pool_kernel(const scalar_t *input, scalar_t *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type) {
const int nthreads = batch_size * channels * width;
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n_idx = index / (channels * width); // batch
int w_idx = index % width; // width
int c_idx = (index / width) % channels; // channels
int offset_n = n_idx * channels * width * height;
int offset_n_c = offset_n + c_idx * width * height;
int direction = -1; // in [-1, 1], default for TopPool
int index_start = height - 2; // default for TopPool
// pool_type in [0, 1]
if (pool_type == 0) {
// TopPool
// directly copy the most bottom value from input to output
output[offset_n_c + (height - 1) * width + w_idx] =
input[offset_n_c + (height - 1) * width + w_idx];
} else {
// BottomPool
// directly copy the most top value from input to output
output[offset_n_c + w_idx] = input[offset_n_c + w_idx];
index_start = 1;
direction = 1;
}
// do pool
for (int h = index_start; h >= 0 && h < height; h += direction) {
output[offset_n_c + h * width + w_idx] =
max(output[offset_n_c + (h - direction) * width + w_idx],
input[offset_n_c + h * width + w_idx]);
}
}
}

template <typename scalar_t>
__global__ void left_right_pool_kernel(const scalar_t *input, scalar_t *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type) {
const int nthreads = batch_size * channels * height;
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n_idx = index / (channels * height); // batch
int h_idx = index % height; // height
int c_idx = (index / height) % channels; // channels
int offset_n = n_idx * channels * width * height;
int offset_n_c = offset_n + c_idx * width * height;
int offset_n_c_h = offset_n_c + h_idx * width;
int direction = -1; // in [-1, 1], default for LeftPool
int index_start = width - 2; // default for LeftPool
// pool_type in [2, 3]
if (pool_type == 2) {
// LeftPool
// directly copy the most right value from input to output
output[offset_n_c_h + width - 1] = input[offset_n_c_h + width - 1];
} else {
// RightPool
// directly copy the most left value from input to output
output[offset_n_c_h] = input[offset_n_c_h];
index_start = 1;
direction = 1;
}
// do pool
for (int w = index_start; w >= 0 && w < width; w += direction) {
output[offset_n_c_h + w] =
max(output[offset_n_c_h + w - direction], input[offset_n_c_h + w]);
}
}
}

template <typename scalar_t>
void CornerPoolForwardLauncher(const scalar_t *input, scalar_t *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type, cudaStream_t stream) {
int nthreads = -1, col_block = -1;

switch (pool_type) {
case 0:
case 1:
nthreads = batch_size * channels * width;
col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
top_bottom_pool_kernel<scalar_t>
<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
input, output, batch_size, channels, height, width, pool_type);
break;
case 2:
case 3:
nthreads = batch_size * channels * height;
col_block = DIVUP(nthreads, THREADS_PER_BLOCK);
left_right_pool_kernel<scalar_t>
<<<col_block, THREADS_PER_BLOCK, 0, stream>>>(
input, output, batch_size, channels, height, width, pool_type);
break;
}
}

void CornerPoolForwardLauncher_float(const float *input, float *output,
const int batch_size, const int channels,
const int height, const int width,
const int pool_type, cudaStream_t stream) {
CornerPoolForwardLauncher<float>(input, output, batch_size, channels, height,
width, pool_type, stream);
}
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "trt_plugin.hpp"

#include "trt_corner_pool.hpp"
#include "trt_cummaxmin.hpp"
#include "trt_deform_conv.hpp"
#include "trt_grid_sampler.hpp"
Expand All @@ -18,6 +19,7 @@ REGISTER_TENSORRT_PLUGIN(NonMaxSuppressionDynamicCreator);
REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);
REGISTER_TENSORRT_PLUGIN(InstanceNormalizationDynamicCreator);
REGISTER_TENSORRT_PLUGIN(CornerPoolPluginDynamicCreator);

extern "C" {
bool initLibMMCVInferPlugins() { return true; }
Expand Down
Loading