-
Notifications
You must be signed in to change notification settings - Fork 629
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhancement] TensorRT Anchor generator plugin (#646)
* custom trt anchor generator * add ut * add docstring, update doc
- Loading branch information
q.yao
authored
Jun 28, 2022
1 parent
4d9e209
commit dc5f9c3
Showing
11 changed files
with
573 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
154 changes: 154 additions & 0 deletions
154
csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
#include "trt_grid_priors.hpp" | ||
|
||
#include <assert.h> | ||
|
||
#include <chrono> | ||
|
||
#include "trt_grid_priors_kernel.hpp" | ||
#include "trt_serialize.hpp" | ||
|
||
using namespace nvinfer1; | ||
|
||
namespace mmdeploy { | ||
namespace { | ||
static const char *PLUGIN_VERSION{"1"}; | ||
static const char *PLUGIN_NAME{"GridPriorsTRT"}; | ||
} // namespace | ||
|
||
GridPriorsTRT::GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride) | ||
: TRTPluginBase(name), mStride(stride) {} | ||
|
||
GridPriorsTRT::GridPriorsTRT(const std::string name, const void *data, size_t length) | ||
: TRTPluginBase(name) { | ||
deserialize_value(&data, &length, &mStride); | ||
} | ||
GridPriorsTRT::~GridPriorsTRT() {} | ||
|
||
nvinfer1::IPluginV2DynamicExt *GridPriorsTRT::clone() const TRT_NOEXCEPT { | ||
GridPriorsTRT *plugin = new GridPriorsTRT(mLayerName, mStride); | ||
plugin->setPluginNamespace(getPluginNamespace()); | ||
|
||
return plugin; | ||
} | ||
|
||
nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions( | ||
int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, | ||
nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { | ||
// input[0] == base_anchor | ||
// input[1] == empty_h | ||
// input[2] == empty_w | ||
|
||
nvinfer1::DimsExprs ret; | ||
ret.nbDims = 2; | ||
auto area = | ||
exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[2].d[0], *inputs[1].d[0]); | ||
ret.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *area, *(inputs[0].d[0])); | ||
ret.d[1] = exprBuilder.constant(4); | ||
|
||
return ret; | ||
} | ||
|
||
bool GridPriorsTRT::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, | ||
int nbInputs, int nbOutputs) TRT_NOEXCEPT { | ||
if (pos == 0) { | ||
return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && | ||
ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); | ||
} else if (pos - nbInputs == 0) { | ||
return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; | ||
} else { | ||
return true; | ||
} | ||
} | ||
|
||
int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, | ||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, | ||
void *const *outputs, void *workSpace, | ||
cudaStream_t stream) TRT_NOEXCEPT { | ||
int num_base_anchors = inputDesc[0].dims.d[0]; | ||
int feat_h = inputDesc[1].dims.d[0]; | ||
int feat_w = inputDesc[2].dims.d[0]; | ||
|
||
const void *base_anchor = inputs[0]; | ||
void *output = outputs[0]; | ||
|
||
auto data_type = inputDesc[0].type; | ||
switch (data_type) { | ||
case nvinfer1::DataType::kFLOAT: | ||
trt_grid_priors_impl<float>((float *)base_anchor, (float *)output, num_base_anchors, feat_w, | ||
feat_h, mStride.d[0], mStride.d[1], stream); | ||
break; | ||
default: | ||
return 1; | ||
} | ||
|
||
return 0; | ||
} | ||
|
||
nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, | ||
int nbInputs) const TRT_NOEXCEPT { | ||
return inputTypes[0]; | ||
} | ||
|
||
// IPluginV2 Methods | ||
const char *GridPriorsTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } | ||
|
||
const char *GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } | ||
|
||
int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT { return 1; } | ||
|
||
size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT { return serialized_size(mStride); } | ||
|
||
void GridPriorsTRT::serialize(void *buffer) const TRT_NOEXCEPT { | ||
serialize_value(&buffer, mStride); | ||
; | ||
} | ||
|
||
////////////////////// creator ///////////////////////////// | ||
|
||
GridPriorsTRTCreator::GridPriorsTRTCreator() { | ||
mPluginAttributes.clear(); | ||
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h")); | ||
mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w")); | ||
mFC.nbFields = mPluginAttributes.size(); | ||
mFC.fields = mPluginAttributes.data(); | ||
} | ||
|
||
const char *GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } | ||
|
||
const char *GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } | ||
|
||
nvinfer1::IPluginV2 *GridPriorsTRTCreator::createPlugin( | ||
const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { | ||
int stride_w = 1; | ||
int stride_h = 1; | ||
|
||
for (int i = 0; i < fc->nbFields; i++) { | ||
if (fc->fields[i].data == nullptr) { | ||
continue; | ||
} | ||
std::string field_name(fc->fields[i].name); | ||
|
||
if (field_name.compare("stride_w") == 0) { | ||
stride_w = static_cast<const int *>(fc->fields[i].data)[0]; | ||
} | ||
if (field_name.compare("stride_h") == 0) { | ||
stride_h = static_cast<const int *>(fc->fields[i].data)[0]; | ||
} | ||
} | ||
nvinfer1::Dims stride{2, {stride_w, stride_h}}; | ||
|
||
GridPriorsTRT *plugin = new GridPriorsTRT(name, stride); | ||
plugin->setPluginNamespace(getPluginNamespace()); | ||
return plugin; | ||
} | ||
|
||
nvinfer1::IPluginV2 *GridPriorsTRTCreator::deserializePlugin(const char *name, | ||
const void *serialData, | ||
size_t serialLength) TRT_NOEXCEPT { | ||
auto plugin = new GridPriorsTRT(name, serialData, serialLength); | ||
plugin->setPluginNamespace(getPluginNamespace()); | ||
return plugin; | ||
} | ||
REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator); | ||
} // namespace mmdeploy |
66 changes: 66 additions & 0 deletions
66
csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved. | ||
#ifndef TRT_GRID_PRIORS_HPP | ||
#define TRT_GRID_PRIORS_HPP | ||
#include <cublas_v2.h> | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <vector> | ||
|
||
#include "trt_plugin_base.hpp" | ||
|
||
namespace mmdeploy { | ||
class GridPriorsTRT : public TRTPluginBase { | ||
public: | ||
GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride); | ||
|
||
GridPriorsTRT(const std::string name, const void *data, size_t length); | ||
|
||
GridPriorsTRT() = delete; | ||
|
||
~GridPriorsTRT() TRT_NOEXCEPT override; | ||
|
||
// IPluginV2DynamicExt Methods | ||
nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; | ||
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, | ||
int nbInputs, nvinfer1::IExprBuilder &exprBuilder) | ||
TRT_NOEXCEPT override; | ||
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, | ||
int nbOutputs) TRT_NOEXCEPT override; | ||
int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, | ||
const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, | ||
void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; | ||
|
||
// IPluginV2Ext Methods | ||
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, | ||
int nbInputs) const TRT_NOEXCEPT override; | ||
|
||
// IPluginV2 Methods | ||
const char *getPluginType() const TRT_NOEXCEPT override; | ||
const char *getPluginVersion() const TRT_NOEXCEPT override; | ||
int getNbOutputs() const TRT_NOEXCEPT override; | ||
size_t getSerializationSize() const TRT_NOEXCEPT override; | ||
void serialize(void *buffer) const TRT_NOEXCEPT override; | ||
|
||
private: | ||
nvinfer1::Dims mStride; | ||
|
||
cublasHandle_t m_cublas_handle; | ||
}; | ||
|
||
class GridPriorsTRTCreator : public TRTPluginCreatorBase { | ||
public: | ||
GridPriorsTRTCreator(); | ||
|
||
const char *getPluginName() const TRT_NOEXCEPT override; | ||
|
||
const char *getPluginVersion() const TRT_NOEXCEPT override; | ||
|
||
nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) | ||
TRT_NOEXCEPT override; | ||
|
||
nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, | ||
size_t serialLength) TRT_NOEXCEPT override; | ||
}; | ||
} // namespace mmdeploy | ||
#endif // TRT_GRID_PRIORS_HPP |
43 changes: 43 additions & 0 deletions
43
csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
#include <cuda_fp16.h> | ||
|
||
#include "common_cuda_helper.hpp" | ||
#include "trt_grid_priors_kernel.hpp" | ||
#include "trt_plugin_helper.hpp" | ||
|
||
template <typename scalar_t> | ||
__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor, scalar_t* output, | ||
int num_base_anchors, int feat_w, int feat_h, int stride_w, | ||
int stride_h) { | ||
// load base anchor into shared memory. | ||
extern __shared__ scalar_t shared_base_anchor[]; | ||
for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x) { | ||
shared_base_anchor[i] = base_anchor[i]; | ||
} | ||
__syncthreads(); | ||
|
||
CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h) { | ||
const int a_offset = (index % num_base_anchors) << 2; | ||
const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w); | ||
const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h); | ||
|
||
auto out_start = output + index * 4; | ||
out_start[0] = shared_base_anchor[a_offset] + w; | ||
out_start[1] = shared_base_anchor[a_offset + 1] + h; | ||
out_start[2] = shared_base_anchor[a_offset + 2] + w; | ||
out_start[3] = shared_base_anchor[a_offset + 3] + h; | ||
} | ||
} | ||
|
||
template <typename scalar_t> | ||
void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, | ||
int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream) { | ||
trt_grid_priors_kernel<<<GET_BLOCKS(num_base_anchors * feat_w * feat_h), THREADS_PER_BLOCK, | ||
DIVUP(num_base_anchors * 4, 32) * 32 * sizeof(scalar_t), stream>>>( | ||
base_anchor, output, (int)num_base_anchors, (int)feat_w, (int)feat_h, (int)stride_w, | ||
(int)stride_h); | ||
} | ||
template void trt_grid_priors_impl<float>(const float* base_anchor, float* output, | ||
int num_base_anchors, int feat_w, int feat_h, | ||
int stride_w, int stride_h, cudaStream_t stream); |
10 changes: 10 additions & 0 deletions
10
csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
// Copyright (c) OpenMMLab. All rights reserved | ||
#ifndef TRT_GRID_PRIORS_KERNEL_HPP | ||
#define TRT_GRID_PRIORS_KERNEL_HPP | ||
#include <cuda_runtime.h> | ||
|
||
template <typename scalar_t> | ||
void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, | ||
int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream); | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .anchor import * # noqa: F401,F403 | ||
from .bbox import * # noqa: F401,F403 | ||
from .ops import * # noqa: F401,F403 | ||
from .post_processing import * # noqa: F401,F403 |
Oops, something went wrong.