[Feature]: Add ScatterND TensorRT Plugin #786

Merged · 7 commits · Jan 20, 2021

Changes from 5 commits
2 changes: 2 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_plugin.cpp
@@ -1,8 +1,10 @@
#include "trt_plugin.hpp"

#include "trt_roi_align.hpp"
#include "trt_scatternd.hpp"

REGISTER_TENSORRT_PLUGIN(RoIAlignPluginDynamicCreator);
REGISTER_TENSORRT_PLUGIN(ONNXScatterNDDynamicCreator);

extern "C" {
bool initLibMMCVInferPlugins() { return true; }
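For context, a minimal sketch (not part of this diff) of how the creator registered above can be looked up once the plugin library is loaded; the name and version strings match PLUGIN_NAME and PLUGIN_VERSION defined in trt_scatternd.cpp:

#include <NvInferRuntime.h>

nvinfer1::IPluginCreator *getScatterNDCreator() {
  // REGISTER_TENSORRT_PLUGIN adds the creator to the global registry at
  // library load time, so it can be retrieved by plugin name and version.
  return getPluginRegistry()->getPluginCreator("ScatterND", "1");
}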
206 changes: 206 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_scatternd.cpp
@@ -0,0 +1,206 @@
#include "trt_scatternd.hpp"

#include <assert.h>
#include <stdio.h>

#include <chrono>

#include "trt_serialize.hpp"

extern void TRTONNXScatterNDKernelLauncher_float(
const float *data, const int *indices, const float *update, const int *dims,
int nbDims, const int *indices_dims, int indice_nbDims, float *output,
cudaStream_t stream);

extern void TRTONNXScatterNDKernelLauncher_int32(
const int *data, const int *indices, const int *update, const int *dims,
int nbDims, const int *indices_dims, int indice_nbDims, int *output,
cudaStream_t stream);

namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"ScatterND"};
} // namespace

nvinfer1::PluginFieldCollection ONNXScatterNDDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
ONNXScatterNDDynamicCreator::mPluginAttributes;

ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string &name)
: mLayerName(name) {}

ONNXScatterNDDynamic::ONNXScatterNDDynamic(const std::string name,
const void *data, size_t length)
: mLayerName(name) {}

nvinfer1::IPluginV2DynamicExt *ONNXScatterNDDynamic::clone() const {
ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(mLayerName);
plugin->setPluginNamespace(getPluginNamespace());

return plugin;
}

// ScatterND writes `updates` into a copy of `data`, so the output shape is
// exactly the shape of input 0.
nvinfer1::DimsExprs ONNXScatterNDDynamic::getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
nvinfer1::IExprBuilder &exprBuilder) {
return inputs[0];
}

bool ONNXScatterNDDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs,
int nbOutputs) {
if (pos < nbInputs) {
switch (pos) {
case 0:
// data
return (inOut[pos].type == nvinfer1::DataType::kFLOAT &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR) ||
(inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR);
case 1:
// indices
return inOut[pos].type == nvinfer1::DataType::kINT32 &&
inOut[pos].format == nvinfer1::TensorFormat::kLINEAR;
case 2:
// updates
return inOut[pos].type == inOut[0].type &&
inOut[pos].format == inOut[0].format;
default:
return true;
}
} else {
switch (pos - nbInputs) {
case 0:
// output
return inOut[pos].type == inOut[0].type &&
inOut[pos].format == inOut[0].format;
default:
return true;
}
}
return true;
}

void ONNXScatterNDDynamic::configurePlugin(
const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) {}

size_t ONNXScatterNDDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const {
return 0;
}

int ONNXScatterNDDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
const nvinfer1::PluginTensorDesc *outputDesc,
const void *const *inputs,
void *const *outputs, void *workSpace,
cudaStream_t stream) {
const int *dims = &(inputDesc[0].dims.d[0]);
const int *indices_dims = &(inputDesc[1].dims.d[0]);
int nbDims = inputDesc[0].dims.nbDims;
int indice_nbDims = inputDesc[1].dims.nbDims;

const void *data = inputs[0];
const void *indices = inputs[1];
const void *update = inputs[2];
void *output = outputs[0];

auto data_type = inputDesc[0].type;

switch (data_type) {
case nvinfer1::DataType::kFLOAT:
TRTONNXScatterNDKernelLauncher_float(
(float *)data, (int *)indices, (float *)update, dims, nbDims,
indices_dims, indice_nbDims, (float *)output, stream);
break;

case nvinfer1::DataType::kINT32:
TRTONNXScatterNDKernelLauncher_int32(
(int *)data, (int *)indices, (int *)update, dims, nbDims,
indices_dims, indice_nbDims, (int *)output, stream);
break;
default:
break;
}

return 0;
}

nvinfer1::DataType ONNXScatterNDDynamic::getOutputDataType(
int index, const nvinfer1::DataType *inputTypes, int nbInputs) const {
return inputTypes[0];
}

// IPluginV2 Methods
const char *ONNXScatterNDDynamic::getPluginType() const { return PLUGIN_NAME; }

const char *ONNXScatterNDDynamic::getPluginVersion() const {
return PLUGIN_VERSION;
}

int ONNXScatterNDDynamic::getNbOutputs() const { return 1; }

int ONNXScatterNDDynamic::initialize() { return 0; }

void ONNXScatterNDDynamic::terminate() {}

size_t ONNXScatterNDDynamic::getSerializationSize() const { return 0; }

void ONNXScatterNDDynamic::serialize(void *buffer) const {}

void ONNXScatterNDDynamic::destroy() {
// This gets called when the network containing plugin is destroyed
delete this;
}

void ONNXScatterNDDynamic::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}

const char *ONNXScatterNDDynamic::getPluginNamespace() const {
return mNamespace.c_str();
}

////////////////////// creator /////////////////////////////

ONNXScatterNDDynamicCreator::ONNXScatterNDDynamicCreator() {
mPluginAttributes.clear();
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

const char *ONNXScatterNDDynamicCreator::getPluginName() const {
return PLUGIN_NAME;
}

const char *ONNXScatterNDDynamicCreator::getPluginVersion() const {
return PLUGIN_VERSION;
}

const nvinfer1::PluginFieldCollection *
ONNXScatterNDDynamicCreator::getFieldNames() {
return &mFC;
}

nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::createPlugin(
const char *name, const nvinfer1::PluginFieldCollection *fc) {
ONNXScatterNDDynamic *plugin = new ONNXScatterNDDynamic(name);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

nvinfer1::IPluginV2 *ONNXScatterNDDynamicCreator::deserializePlugin(
const char *name, const void *serialData, size_t serialLength) {
auto plugin = new ONNXScatterNDDynamic(name, serialData, serialLength);
plugin->setPluginNamespace(getPluginNamespace());
return plugin;
}

void ONNXScatterNDDynamicCreator::setPluginNamespace(const char *libNamespace) {
mNamespace = libNamespace;
}

const char *ONNXScatterNDDynamicCreator::getPluginNamespace() const {
return mNamespace.c_str();
}
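For reference, a host-side sketch of the semantics the plugin implements, mirroring the CUDA kernel below; the function name scatternd_reference and its parameter list are illustrative, not part of this PR:

#include <cstring>
#include <vector>

// output = copy(data); each of the `num_updates` index tuples overwrites one
// contiguous slice of `stride[indice_cols - 1]` elements taken from `update`.
void scatternd_reference(const float *data, const int *indices,
                         const float *update, const std::vector<int> &stride,
                         int data_size, int num_updates, int indice_cols,
                         float *output) {
  std::memcpy(output, data, data_size * sizeof(float));
  const int copy_stride = stride[indice_cols - 1];
  for (int i = 0; i < num_updates; ++i) {
    int offset = 0;
    for (int j = 0; j < indice_cols; ++j)
      offset += stride[j] * indices[i * indice_cols + j];
    std::memcpy(output + offset, update + i * copy_stride,
                copy_stride * sizeof(float));
  }
}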
93 changes: 93 additions & 0 deletions mmcv/ops/csrc/tensorrt/plugins/trt_scatternd_kernel.cu
@@ -0,0 +1,93 @@
#include <stdio.h>

#include <vector>

#include "common_cuda_helper.hpp"
#include "trt_cuda_helper.cuh"
#include "trt_plugin_helper.hpp"

static int const threadsPerBlock = sizeof(unsigned long long int) * 8;  // = 64

using mmcv::TensorDesc;

template <typename T>
__global__ void onnx_scatternd_kernel(const int n, const int* indices,
const T* update, T* output,
TensorDesc tensor_desc,
TensorDesc indice_desc) {
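// One loop iteration handles one index tuple: build the flat output offset
// from the tuple, then copy one contiguous copy_stride-element slice of
// `update` into place.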
const int indice_cols = indice_desc.shape[indice_desc.dim - 1];
const int copy_stride = tensor_desc.stride[indice_cols - 1];
const int* shape = &(tensor_desc.shape[0]);
const int* stride = &(tensor_desc.stride[0]);
CUDA_1D_KERNEL_LOOP(index, n) {
int output_offset = 0;
const int* indices_current = indices + index * indice_cols;
for (int i = 0; i < indice_cols; ++i) {
output_offset += stride[i] * indices_current[i];
}
memcpy(output + output_offset, update + index * copy_stride,
copy_stride * sizeof(T));
}
}

template <typename T>
void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices,
const T* update, const int* dims,
int nbDims, const int* indices_dims,
int indice_nbDims, T* output,
cudaStream_t stream) {
// fill the tensor descriptor: copy the shape and derive row-major strides
TensorDesc tensor_desc;
memset((void*)&tensor_desc, 0, sizeof(TensorDesc));
tensor_desc.dim = nbDims;
tensor_desc.shape[nbDims - 1] = dims[nbDims - 1];
tensor_desc.stride[nbDims - 1] = 1;
for (int i = nbDims - 2; i >= 0; --i) {
tensor_desc.shape[i] = dims[i];
tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1];
}
const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0];

TensorDesc indice_desc;
memset((void*)&indice_desc, 0, sizeof(TensorDesc));
indice_desc.dim = indice_nbDims;
indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1];
indice_desc.stride[indice_nbDims - 1] = 1;
for (int i = indice_nbDims - 2; i >= 0; --i) {
indice_desc.shape[i] = indices_dims[i];
indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1];
}

// output = np.copy(data)
cudaMemcpyAsync(output, data, data_size * sizeof(T), cudaMemcpyDeviceToDevice,
stream);

int num_update_indice = 1;
for (int i = 0; i < indice_nbDims - 1; ++i) {
num_update_indice *= indice_desc.shape[i];
}
// scatter
const int col_block = DIVUP(num_update_indice, threadsPerBlock);
onnx_scatternd_kernel<<<col_block, threadsPerBlock, 0, stream>>>(
num_update_indice, indices, update, output, tensor_desc, indice_desc);
}

void TRTONNXScatterNDKernelLauncher_float(const float* data, const int* indices,
const float* update, const int* dims,
int nbDims, const int* indices_dims,
int indice_nbDims, float* output,
cudaStream_t stream) {
TRTONNXScatterNDKernelLauncher<float>(data, indices, update, dims, nbDims,
indices_dims, indice_nbDims, output,
stream);
}

void TRTONNXScatterNDKernelLauncher_int32(const int* data, const int* indices,
const int* update, const int* dims,
int nbDims, const int* indices_dims,
int indice_nbDims, int* output,
cudaStream_t stream) {
TRTONNXScatterNDKernelLauncher<int>(data, indices, update, dims, nbDims,
indices_dims, indice_nbDims, output,
stream);
}
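A worked example of the descriptor and offset math above, as standalone host code with assumed shapes (illustration only):

#include <cstdio>

int main() {
  // data shape [2, 3, 4] -> row-major strides [12, 4, 1], data_size = 24
  const int stride[3] = {12, 4, 1};
  // indices shape [N, 2]: each tuple [i, j] addresses a slice of
  // copy_stride = stride[1] = 4 elements, so update has shape [N, 4]
  const int idx[2] = {1, 2};
  const int offset = stride[0] * idx[0] + stride[1] * idx[1];
  std::printf("offset = %d, slice length = %d\n", offset, stride[1]);  // 20, 4
  return 0;
}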
16 changes: 16 additions & 0 deletions mmcv/ops/csrc/tensorrt/trt_cuda_helper.cuh
@@ -0,0 +1,16 @@
#ifndef TRT_CUDA_HELPER_HPP
#define TRT_CUDA_HELPER_HPP

#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

#define cudaCheckError() \
{ \
cudaError_t e = cudaGetLastError(); \
if (e != cudaSuccess) { \
printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, \
cudaGetErrorString(e)); \
exit(1); \
} \
}

#endif // TRT_CUDA_HELPER_HPP
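Typical usage of these helpers (a sketch; my_kernel and its launch arguments are hypothetical):

const int n = 1000;
const int blocks = DIVUP(n, threadsPerBlock);  // 16 blocks of 64 threads
my_kernel<<<blocks, threadsPerBlock, 0, stream>>>(/* ... */);
cudaCheckError();  // print and exit if the launch recorded an error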
8 changes: 8 additions & 0 deletions mmcv/ops/csrc/tensorrt/trt_plugin_helper.hpp
@@ -6,6 +6,14 @@

namespace mmcv {

const int MAXTENSORDIMS = 10;

struct TensorDesc {
int shape[MAXTENSORDIMS];
int stride[MAXTENSORDIMS];
int dim;
};
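// Note: TensorDesc is passed to the CUDA kernels by value, so the shape and
// stride metadata travel in the kernel argument buffer and need no separate
// device allocation; MAXTENSORDIMS caps the supported rank at 10.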

inline unsigned int getElementSize(nvinfer1::DataType t) {
switch (t) {
case nvinfer1::DataType::kINT32: