Skip to content

Commit

Permalink
[Enhancement] TensorRT Anchor generator plugin (#646)
Browse files Browse the repository at this point in the history
* custom trt anchor generator

* add ut

* add docstring, update doc
  • Loading branch information
q.yao authored Jun 28, 2022
1 parent 4d9e209 commit dc5f9c3
Show file tree
Hide file tree
Showing 11 changed files with 573 additions and 3 deletions.
17 changes: 15 additions & 2 deletions csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
}
const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); }

// Default no-op: the base class accepts any configuration. Derived plugins that
// need shape- or type-dependent setup override this (TensorRT calls it with the
// descriptors of all input/output tensors before the engine runs).
virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                             const nvinfer1::DynamicPluginTensorDesc *out,
                             int nbOutputs) TRT_NOEXCEPT override {}

// Default: request no per-invocation GPU workspace. Derived plugins that need
// scratch memory in enqueue() override this and return the byte count.
virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                                const nvinfer1::PluginTensorDesc *outputs,
                                int nbOutputs) const TRT_NOEXCEPT override {
  return 0;
}

// Default no-op: plugins that need cuDNN/cuBLAS can grab the context handles here.
virtual void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext,
                             nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override {}

// Counterpart of attachToContext(); nothing to release by default.
virtual void detachFromContext() TRT_NOEXCEPT override {}

protected:
const std::string mLayerName;
std::string mNamespace;
Expand All @@ -34,10 +49,8 @@ class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt {
protected:
// To prevent compiler warnings.
using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::configurePlugin;
using nvinfer1::IPluginV2DynamicExt::enqueue;
using nvinfer1::IPluginV2DynamicExt::getOutputDimensions;
using nvinfer1::IPluginV2DynamicExt::getWorkspaceSize;
using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch;
using nvinfer1::IPluginV2DynamicExt::supportsFormat;
#endif
Expand Down
154 changes: 154 additions & 0 deletions csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "trt_grid_priors.hpp"

#include <assert.h>

#include <chrono>

#include "trt_grid_priors_kernel.hpp"
#include "trt_serialize.hpp"

using namespace nvinfer1;

namespace mmdeploy {
namespace {
static const char *PLUGIN_VERSION{"1"};
static const char *PLUGIN_NAME{"GridPriorsTRT"};
} // namespace

// Build-time constructor. `stride` carries the feature-map stride:
// d[0] = stride_w, d[1] = stride_h (see GridPriorsTRTCreator::createPlugin).
GridPriorsTRT::GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride)
    : TRTPluginBase(name), mStride(stride) {}

// Deserialization constructor: restores mStride from the engine byte stream
// written by serialize().
GridPriorsTRT::GridPriorsTRT(const std::string name, const void *data, size_t length)
    : TRTPluginBase(name) {
  deserialize_value(&data, &length, &mStride);
}
GridPriorsTRT::~GridPriorsTRT() {}

nvinfer1::IPluginV2DynamicExt *GridPriorsTRT::clone() const TRT_NOEXCEPT {
  // Duplicate this plugin with the same layer name and stride, keeping the namespace.
  auto *copy = new GridPriorsTRT(mLayerName, mStride);
  copy->setPluginNamespace(getPluginNamespace());
  return copy;
}

nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
  // inputs[0] == base_anchor, inputs[1] == empty_h, inputs[2] == empty_w.
  // Output shape: [feat_w * feat_h * num_base_anchors, 4].
  using nvinfer1::DimensionOperation;
  const auto *cells =
      exprBuilder.operation(DimensionOperation::kPROD, *inputs[2].d[0], *inputs[1].d[0]);

  nvinfer1::DimsExprs out;
  out.nbDims = 2;
  out.d[0] = exprBuilder.operation(DimensionOperation::kPROD, *cells, *inputs[0].d[0]);
  out.d[1] = exprBuilder.constant(4);
  return out;
}

bool GridPriorsTRT::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
                                              int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  // Base anchors (pos 0) must be linear fp32; the single output (pos == nbInputs)
  // must match input 0; the two shape-provider inputs accept any type/format.
  if (pos == 0) {
    return ioDesc[0].type == nvinfer1::DataType::kFLOAT &&
           ioDesc[0].format == nvinfer1::TensorFormat::kLINEAR;
  }
  if (pos == nbInputs) {
    return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
  }
  return true;
}

int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                           const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
                           void *const *outputs, void *workSpace,
                           cudaStream_t stream) TRT_NOEXCEPT {
  // inputs[0] is [num_base_anchors, 4]; inputs[1]/inputs[2] contribute only
  // their lengths (feat_h / feat_w) — their data is never read.
  const int num_base_anchors = inputDesc[0].dims.d[0];
  const int feat_h = inputDesc[1].dims.d[0];
  const int feat_w = inputDesc[2].dims.d[0];

  if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) {
    return 1;  // only fp32 is supported (see supportsFormatCombination)
  }

  trt_grid_priors_impl<float>((float *)inputs[0], (float *)outputs[0], num_base_anchors, feat_w,
                              feat_h, mStride.d[0], mStride.d[1], stream);
  return 0;
}

// The output anchors use the same data type as the base anchors (input 0).
nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                                    int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *GridPriorsTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

// The plugin produces a single output tensor of anchors.
int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT { return 1; }

// Only mStride needs to be persisted into the serialized engine.
size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT { return serialized_size(mStride); }

// Write mStride into `buffer` (size as reported by getSerializationSize).
void GridPriorsTRT::serialize(void *buffer) const TRT_NOEXCEPT {
  serialize_value(&buffer, mStride);
}

////////////////////// creator /////////////////////////////

GridPriorsTRTCreator::GridPriorsTRTCreator() {
  // Expose the two integer stride attributes so the parser can fill the
  // PluginFieldCollection consumed by createPlugin().
  mPluginAttributes = {nvinfer1::PluginField("stride_h"), nvinfer1::PluginField("stride_w")};
  mFC.nbFields = static_cast<int>(mPluginAttributes.size());
  mFC.fields = mPluginAttributes.data();
}

// Creator name/version must match the plugin's own getPluginType/getPluginVersion.
const char *GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

nvinfer1::IPluginV2 *GridPriorsTRTCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  // Defaults apply when an attribute is absent or carries no data.
  int stride_w = 1;
  int stride_h = 1;

  for (int i = 0; i < fc->nbFields; ++i) {
    const auto &field = fc->fields[i];
    if (field.data == nullptr) {
      continue;
    }
    const std::string key(field.name);
    if (key == "stride_w") {
      stride_w = *static_cast<const int *>(field.data);
    } else if (key == "stride_h") {
      stride_h = *static_cast<const int *>(field.data);
    }
  }

  // Dims layout consumed by enqueue(): d[0] = stride_w, d[1] = stride_h.
  const nvinfer1::Dims stride{2, {stride_w, stride_h}};

  auto *plugin = new GridPriorsTRT(name, stride);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *GridPriorsTRTCreator::deserializePlugin(const char *name,
                                                             const void *serialData,
                                                             size_t serialLength) TRT_NOEXCEPT {
  // Rebuild the plugin from the engine byte stream written by GridPriorsTRT::serialize.
  GridPriorsTRT *plugin = new GridPriorsTRT(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}
REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator);
} // namespace mmdeploy
66 changes: 66 additions & 0 deletions csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_GRID_PRIORS_HPP
#define TRT_GRID_PRIORS_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {
// TensorRT plugin that generates anchor priors over a feature-map grid.
// Inputs: [0] base anchors, a (num_base_anchors, 4) tensor; [1]/[2] 1-D tensors
// whose lengths supply feat_h / feat_w (their data is never read).
// Output: a (num_base_anchors * feat_h * feat_w, 4) tensor of anchors.
class GridPriorsTRT : public TRTPluginBase {
 public:
  // `stride` layout: d[0] = stride_w, d[1] = stride_h.
  GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride);

  // Deserialization constructor used when loading a built engine.
  GridPriorsTRT(const std::string name, const void *data, size_t length);

  GridPriorsTRT() = delete;

  ~GridPriorsTRT() TRT_NOEXCEPT override;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;
  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;
  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;
  const char *getPluginVersion() const TRT_NOEXCEPT override;
  int getNbOutputs() const TRT_NOEXCEPT override;
  size_t getSerializationSize() const TRT_NOEXCEPT override;
  void serialize(void *buffer) const TRT_NOEXCEPT override;

 private:
  // Feature-grid stride, serialized with the engine (see serialize/deserialize).
  nvinfer1::Dims mStride;
  // NOTE(review): removed unused `cublasHandle_t m_cublas_handle` member — it was
  // never initialized, used, or released anywhere in the implementation.
};

// Factory registered with TensorRT's plugin registry (see REGISTER_TENSORRT_PLUGIN)
// that builds and deserializes GridPriorsTRT instances.
class GridPriorsTRTCreator : public TRTPluginCreatorBase {
 public:
  GridPriorsTRTCreator();

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  // Build a plugin from the "stride_h"/"stride_w" integer fields supplied by the parser.
  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  // Rebuild a plugin from a serialized engine blob.
  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
} // namespace mmdeploy
#endif // TRT_GRID_PRIORS_HPP
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) OpenMMLab. All rights reserved
#include <cuda_fp16.h>

#include "common_cuda_helper.hpp"
#include "trt_grid_priors_kernel.hpp"
#include "trt_plugin_helper.hpp"

// Expand the base anchors over a (feat_h, feat_w) grid: the anchor for cell
// (x, y) is the base anchor translated by (x * stride_w, y * stride_h).
// Output layout: index = (y * feat_w + x) * num_base_anchors + anchor_id,
// each entry holding 4 coordinates.
template <typename scalar_t>
__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor, scalar_t* output,
                                       int num_base_anchors, int feat_w, int feat_h, int stride_w,
                                       int stride_h) {
  // load base anchor into shared memory.
  extern __shared__ scalar_t shared_base_anchor[];
  for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x) {
    shared_base_anchor[i] = base_anchor[i];
  }
  __syncthreads();

  CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h) {
    // anchor_id * 4 (each base anchor stores 4 coordinates)
    const int a_offset = (index % num_base_anchors) << 2;
    // x offset: column within the row, scaled by the horizontal stride
    const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w);
    // y offset: row index, scaled by the vertical stride
    const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h);

    auto out_start = output + index * 4;
    out_start[0] = shared_base_anchor[a_offset] + w;
    out_start[1] = shared_base_anchor[a_offset + 1] + h;
    out_start[2] = shared_base_anchor[a_offset + 2] + w;
    out_start[3] = shared_base_anchor[a_offset + 3] + h;
  }
}

// Launch helper: grid sized to cover num_base_anchors * feat_w * feat_h outputs.
// Dynamic shared memory holds the base anchors (num_base_anchors * 4 elements,
// rounded up to a multiple of 32 elements).
template <typename scalar_t>
void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors,
                          int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream) {
  trt_grid_priors_kernel<<<GET_BLOCKS(num_base_anchors * feat_w * feat_h), THREADS_PER_BLOCK,
                           DIVUP(num_base_anchors * 4, 32) * 32 * sizeof(scalar_t), stream>>>(
      base_anchor, output, (int)num_base_anchors, (int)feat_w, (int)feat_h, (int)stride_w,
      (int)stride_h);
}
// Only the float specialization is instantiated; GridPriorsTRT::enqueue rejects
// other data types before reaching this point.
template void trt_grid_priors_impl<float>(const float* base_anchor, float* output,
                                          int num_base_anchors, int feat_w, int feat_h,
                                          int stride_w, int stride_h, cudaStream_t stream);
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef TRT_GRID_PRIORS_KERNEL_HPP
#define TRT_GRID_PRIORS_KERNEL_HPP
#include <cuda_runtime.h>

// Fill `output` (a (num_base_anchors * feat_h * feat_w) x 4 buffer) with the base
// anchors shifted across the feature-map grid, scaled by (stride_w, stride_h).
// Runs asynchronously on `stream`.
template <typename scalar_t>
void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors,
                          int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream);

#endif
42 changes: 42 additions & 0 deletions docs/en/ops/tensorrt.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@
- [Inputs](#inputs-7)
- [Outputs](#outputs-7)
- [Type Constraints](#type-constraints-7)
- [GridPriorsTRT](#gridpriorstrt)
- [Description](#description-8)
- [Parameters](#parameters-8)
- [Inputs](#inputs-8)
- [Outputs](#outputs-8)
- [Type Constraints](#type-constraints-8)

<!-- TOC -->

Expand Down Expand Up @@ -363,3 +369,39 @@ Batched rotated NMS with a fixed number of output bounding boxes.
#### Type Constraints

- T:tensor(float32, Linear)

### GridPriorsTRT

#### Description

Generate the anchors for the object detection task.

#### Parameters

| Type | Parameter | Description |
| ----- | ---------- | --------------------------------- |
| `int` | `stride_w` | The stride of the feature width. |
| `int` | `stride_h` | The stride of the feature height. |

#### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>The base anchors; 2-D tensor with shape [num_base_anchor, 4].</dd>
<dt><tt>inputs[1]</tt>: TAny</dt>
<dd>height provider; 1-D tensor with shape [featmap_height]. The data is never used.</dd>
<dt><tt>inputs[2]</tt>: TAny</dt>
<dd>width provider; 1-D tensor with shape [featmap_width]. The data is never used.</dd>
</dl>

#### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>output anchors; 2-D tensor of shape (num_base_anchor*featmap_height*featmap_width, 4).</dd>
</dl>

#### Type Constraints

- T:tensor(float32, Linear)
- TAny: Any
1 change: 1 addition & 0 deletions mmdeploy/codebase/mmdet/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .anchor import * # noqa: F401,F403
from .bbox import * # noqa: F401,F403
from .ops import * # noqa: F401,F403
from .post_processing import * # noqa: F401,F403
Loading

0 comments on commit dc5f9c3

Please sign in to comment.