
Commit 6f4aa40

feat(//core/plugins): Add adaptive_max_pool2d plugin, enable the plugins to run on GPU
Signed-off-by: Dheeraj Peri <peri.dheeraj@gmail.com>
1 parent 03a6ca4 commit 6f4aa40
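
For reference, a minimal LibTorch sketch (not part of this commit; assumes a CUDA-enabled LibTorch build) of the ATen op that the new "adaptive_max_pool2d" plugin mode dispatches to. The op returns both the pooled values and the argmax indices, which is why the plugin below reports two outputs for this mode.

#include <torch/torch.h>

#include <iostream>
#include <tuple>

int main() {
  // NCHW input on the GPU, pooled down to a fixed 7x7 spatial size.
  at::Tensor input = at::randn({1, 64, 28, 28}, {at::kCUDA});
  auto result = at::adaptive_max_pool2d(input, {7, 7});
  at::Tensor values = std::get<0>(result);   // pooled values, shape [1, 64, 7, 7]
  at::Tensor indices = std::get<1>(result);  // argmax indices, same shape
  std::cout << values.sizes() << " " << indices.sizes() << std::endl;
  return 0;
}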

File tree

6 files changed (+137, -177 lines)

core/conversion/converters/impl/pooling.cpp (+39, -59)

@@ -60,77 +60,53 @@ bool AdaptivePoolingConverter(
   auto in_shape = util::toVec(in->getDimensions());
   nvinfer1::ILayer* new_layer = nullptr;

-  if (ctx->input_is_dynamic) {
-#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
-    LOG_WARNING(
-        "Adaptive pooling layer will be run through ATen, via not TensorRT, performace will be lower than expected. Consider switching either to static input shape or moving to non adaptive pooling if this is an issue");
-#else
-    LOG_WARNING(
-        "Adaptive pooling layer will be run through ATen (on CPU), via not TensorRT, performace will suffer. Consider switching either to static input shape or moving to non adaptive pooling");
-#endif
+  /*======CONFIGURE PLUGIN PARAMETERS======*/
+  nvinfer1::PluginFieldCollection fc;
+  std::vector<nvinfer1::PluginField> f;

-    TRTORCH_CHECK(
-        pool_type == nvinfer1::PoolingType::kAVERAGE,
-        "Unable to create MAX pooling (interpolation) plugin from node" << *n);
-
-    nvinfer1::PluginFieldCollection fc;
-    std::vector<nvinfer1::PluginField> f;
-
-    auto out_shape = in_shape;
-    std::copy_n(out_size.d, out_size.nbDims, out_shape.begin() + (in_shape.size() - out_size.nbDims));
+  auto out_shape = in_shape;
+  auto out_size_vec = util::toVec(out_size);

-    std::vector<int32_t> in_shape_casted(in_shape.begin(), in_shape.end());
-    f.emplace_back(
-        nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size()));
+  std::copy(out_size_vec.begin(), out_size_vec.end(), out_shape.begin() + (in_shape.size() - out_size_vec.size()));

-    std::vector<int32_t> out_shape_casted(out_shape.begin(), out_shape.end());
-    f.emplace_back(nvinfer1::PluginField(
-        "out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size()));
+  std::vector<int32_t> in_shape_casted(in_shape.begin(), in_shape.end());
+  f.emplace_back(
+      nvinfer1::PluginField("in_shape", in_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, in_shape.size()));

-    auto out_size_vec = util::toVec(out_size);
-    std::vector<int32_t> out_size_casted(out_size_vec.begin(), out_size_vec.end());
-    f.emplace_back(nvinfer1::PluginField(
-        "out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size_vec.size()));
+  std::vector<int32_t> out_shape_casted(out_shape.begin(), out_shape.end());
+  f.emplace_back(nvinfer1::PluginField(
+      "out_shape", out_shape_casted.data(), nvinfer1::PluginFieldType::kINT32, out_shape.size()));

-    f.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT64, 0));
+  std::vector<int32_t> out_size_casted(out_size_vec.begin(), out_size_vec.end());
+  f.emplace_back(nvinfer1::PluginField(
+      "out_size", out_size_casted.data(), nvinfer1::PluginFieldType::kINT32, out_size_vec.size()));

-    std::string mode = "adaptive_pool2d";
-    f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1));
+  f.emplace_back(nvinfer1::PluginField("scales", nullptr, nvinfer1::PluginFieldType::kFLOAT64, 0));

-    int32_t align_corners_casted = 0;
-    f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1));
+  int32_t align_corners_casted = 0;
+  f.emplace_back(nvinfer1::PluginField("align_corners", &align_corners_casted, nvinfer1::PluginFieldType::kINT32, 1));

-    int32_t use_scales_casted = 0;
-    f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1));
+  int32_t use_scales_casted = 0;
+  f.emplace_back(nvinfer1::PluginField("use_scales", &use_scales_casted, nvinfer1::PluginFieldType::kINT32, 1));

-    fc.nbFields = f.size();
-    fc.fields = f.data();
-    auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
-    auto interpolate_plugin = creator->createPlugin("adaptive_pool2d", &fc);
-
-    new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
-    TRTORCH_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n);
+  std::string mode = "adaptive_avg_pool2d";
+  if (pool_type == nvinfer1::PoolingType::kMAX) {
+    mode = "adaptive_max_pool2d";
+  }
+  f.emplace_back(nvinfer1::PluginField("mode", &mode, nvinfer1::PluginFieldType::kCHAR, 1));

-  } else {
-    std::vector<int64_t> stride(out_size.nbDims);
-    for (int64_t i = 0; i < out_size.nbDims; i++) {
-      stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_size.d[(out_size.nbDims - 1) - i];
-    }
-    LOG_DEBUG("Stride: " << util::toDims(stride));
+  fc.nbFields = f.size();
+  fc.fields = f.data();
+  /*====== PLUGIN PARAMETERS CONFIGURATION COMPLETED ======*/

-    std::vector<int64_t> window(out_size.nbDims);
-    for (int64_t i = 0; i < out_size.nbDims; i++) {
-      window[window.size() - 1 - i] =
-          in_shape[in_shape.size() - 1 - i] - (out_size.d[out_size.nbDims - 1 - i] - 1) * stride[stride.size() - 1 - i];
-    }
+  LOG_WARNING(
+      "Adaptive pooling layer will be using Aten library kernels in pytorch for execution. TensorRT does not support adaptive pooling natively. Consider switching to non-adaptive pooling if this is an issue");

-    LOG_DEBUG("Window: " << util::toDims(window));
+  auto creator = getPluginRegistry()->getPluginCreator("Interpolate", "1", "trtorch");
+  auto interpolate_plugin = creator->createPlugin(mode.c_str(), &fc);

-    auto pooling_layer = ctx->net->addPoolingNd(*in, pool_type, util::toDims(window));
-    TRTORCH_CHECK(pooling_layer, "Unable to create average pooling layer from node: " << *n);
-    pooling_layer->setStrideNd(util::toDims(stride));
-    new_layer = pooling_layer;
-  }
+  new_layer = ctx->net->addPluginV2(reinterpret_cast<nvinfer1::ITensor* const*>(&in), 1, *interpolate_plugin);
+  TRTORCH_CHECK(new_layer, "Unable to create pooling (interpolation) plugin from node" << *n);

   new_layer->setName(util::node_info(n).c_str());
   auto layer_output = addUnpadding(ctx, n, new_layer->getOutput(0), orig_dims.nbDims, false, false);

@@ -156,7 +132,7 @@ bool PoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& args,
   auto padding = util::toDims(args[3].unwrapToIntList());
   auto stride = util::toDims(args[2].unwrapToIntList());
   if (stride.nbDims == 0) {
-    LOG_DEBUG("Stride not providied, using kernel_size as stride");
+    LOG_DEBUG("Stride not provided, using kernel_size as stride");
     stride = util::toDims(args[1].unwrapToIntList());
   }

@@ -265,6 +241,10 @@ auto pooling_registrations TRTORCH_UNUSED =
       .pattern({"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)",
                 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                   return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kAVERAGE);
+                }})
+      .pattern({"aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)",
+                [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                  return AdaptivePoolingConverter(ctx, n, args, nvinfer1::PoolingType::kMAX);
                 }});
 } // namespace
 } // namespace impl
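
For context, a self-contained sketch (made-up shapes, not taken from this diff) of the out_shape computation the converter now performs: the requested adaptive output size is copied over the trailing dimensions of the input shape.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // e.g. an NCHW input and a 2-D adaptive pooling target size
  std::vector<int64_t> in_shape = {1, 64, 28, 28};
  std::vector<int64_t> out_size_vec = {7, 7};

  // out_shape starts as a copy of in_shape; the trailing dims are then
  // overwritten with the requested output size, as in the converter above.
  std::vector<int64_t> out_shape = in_shape;
  std::copy(out_size_vec.begin(), out_size_vec.end(),
            out_shape.begin() + (in_shape.size() - out_size_vec.size()));

  for (auto d : out_shape) {
    std::cout << d << " ";  // prints: 1 64 7 7
  }
  std::cout << std::endl;
  return 0;
}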

core/plugins/impl/interpolate_plugin.cpp (+30, -69)

@@ -27,7 +27,7 @@ InterpolatePlugin::InterpolatePlugin(
       align_corners_(align_corners),
       use_scales_(use_scales) {
   if (use_scales) {
-    TRTORCH_ASSERT(mode_ != "adaptive_pool2d", "use_scales is not valid for adaptive_pool2d");
+    TRTORCH_ASSERT(mode_ != "adaptive_avg_pool2d", "use_scales is not valid for adaptive_avg_pool2d");
     TRTORCH_ASSERT(
         scales_.size() != 0, "Attempted to use interpolate plugin without providing scales while use_scales=true");
     at::Tensor input = at::randint(1, 10, in_shape, {at::kCUDA});

@@ -106,7 +106,11 @@ std::vector<int64_t> InterpolatePlugin::getOutputSize() {
 }

 int InterpolatePlugin::getNbOutputs() const {
-  return 1;
+  if (mode_ == "adaptive_max_pool2d") {
+    return 2;
+  } else {
+    return 1;
+  }
 }

 const char* InterpolatePlugin::getPluginType() const {

@@ -166,15 +170,6 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
 }

 int InterpolatePlugin::initialize() {
-#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
-  tensor_options_ = tensor_options_.device(c10::kCUDA);
-#else
-  tensor_options_ = tensor_options_.device(c10::kCPU);
-#endif
-
-  // c10::kFloat = FLOAT32
-  tensor_options_ = tensor_options_.dtype(c10::kFloat);
-
   return 0;
 }

@@ -211,9 +206,15 @@ bool InterpolatePlugin::supportsFormatCombination(
     const nvinfer1::PluginTensorDesc* inOut,
     int nbInputs,
     int nbOutputs) {
-  TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
   TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");
-  TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");
+
+  if (mode_ == "adaptive_max_pool2d") {
+    TRTORCH_ASSERT(nbOutputs == 2, "Expected 2 tensors as output to interpolate plugin");
+    TRTORCH_ASSERT(0 <= pos && pos <= 2, "There should be exactly 3 connections to the plugin - 1 input, 2 output");
+  } else {
+    TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");
+    TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
+  }

   const nvinfer1::PluginTensorDesc& in = inOut[0];

@@ -250,10 +251,10 @@ int InterpolatePlugin::enqueue(
     void* const* outputs,
     void* workspace,
     cudaStream_t stream) {
-#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
-  at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_);
-  at::Tensor output = at::from_blob(
-      outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_);
+  at::Tensor input =
+      at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);
+  at::Tensor output =
+      at::from_blob(outputs[0], util::toVec(outputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);

   at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
   at::cuda::CUDAStreamGuard torch_guard(torch_stream);

@@ -263,27 +264,30 @@ int InterpolatePlugin::enqueue(
   cudaEventRecord(event, stream);

   cudaStreamWaitEvent(torch_stream.stream(), event, 0);
-
+  at::Tensor out;
   if (use_scales_) {
     if (mode_ == "linear") {
-      at::upsample_linear1d_out(output, input, {}, align_corners_, scales_[0]);
+      out = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
     } else if (mode_ == "bilinear") {
-      at::upsample_bilinear2d_out(output, input, {}, align_corners_, scales_[0], scales_[1]);
+      out = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
     } else if (mode_ == "trilinear") {
-      at::upsample_trilinear3d_out(output, input, {}, align_corners_, scales_[0], scales_[1], scales_[2]);
+      out = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
     }
   } else {
     if (mode_ == "linear") {
-      at::upsample_linear1d_out(output, input, {size_[0]}, align_corners_);
+      out = at::upsample_linear1d(input, {size_[0]}, align_corners_);
     } else if (mode_ == "bilinear") {
-      at::upsample_bilinear2d_out(output, input, {size_[0], size_[1]}, align_corners_);
+      out = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
     } else if (mode_ == "trilinear") {
-      at::upsample_trilinear3d_out(output, input, {size_[0], size_[1], size_[2]}, align_corners_);
-    } else if (mode_ == "adaptive_pool2d") {
-      at::adaptive_avg_pool2d_out(output, input, {size_[0], size_[1]});
+      out = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
+    } else if (mode_ == "adaptive_avg_pool2d") {
+      out = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
+    } else if (mode_ == "adaptive_max_pool2d") {
+      out = std::get<0>(at::adaptive_max_pool2d(input, {size_[0], size_[1]}));
     }
   }

+  output.copy_(out);
   cudaEvent_t torch_event;
   cudaEventCreate(&torch_event);
   cudaEventRecord(torch_event, torch_stream.stream());

@@ -294,49 +298,6 @@ int InterpolatePlugin::enqueue(
   cudaEventDestroy(torch_event);

   return 0;
-#else
-  // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen
-  // kernels HACK: WAR because there is a segfault if you try to create a CUDA
-  // Tensor in the context of TensorRT execution
-  float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float));
-  cudaMemcpyAsync(
-      input_blob,
-      static_cast<const void*>(inputs[0]),
-      util::volume(inputDesc->dims) * sizeof(float),
-      cudaMemcpyDeviceToHost,
-      stream);
-  cudaStreamSynchronize(stream);
-
-  at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
-  at::Tensor output;
-  if (use_scales_) {
-    if (mode_ == "linear") {
-      output = at::upsample_linear1d(input, c10::nullopt, align_corners_, {scales_[0]});
-    } else if (mode_ == "bilinear") {
-      output = at::upsample_bilinear2d(input, c10::nullopt, align_corners_, scales_);
-    } else if (mode_ == "trilinear") {
-      output = at::upsample_trilinear3d(input, c10::nullopt, align_corners_, scales_);
-    }
-  } else {
-    if (mode_ == "linear") {
-      output = at::upsample_linear1d(input, {size_[0]}, align_corners_);
-    } else if (mode_ == "bilinear") {
-      output = at::upsample_bilinear2d(input, {size_[0], size_[1]}, align_corners_);
-    } else if (mode_ == "trilinear") {
-      output = at::upsample_trilinear3d(input, {size_[0], size_[1], size_[2]}, align_corners_);
-    } else if (mode_ == "adaptive_pool2d") {
-      output = at::adaptive_avg_pool2d(input, {size_[0], size_[1]});
-    }
-  }
-
-  cudaMemcpyAsync(
-      outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
-  cudaStreamSynchronize(stream);
-
-  free(input_blob);
-
-  return 0;
-#endif
 }

 /*
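
A condensed sketch of the GPU execution pattern the plugins now share in enqueue, assuming a CUDA-enabled LibTorch build. enqueue_via_aten is a hypothetical stand-in for the plugin's enqueue body, and adaptive_avg_pool2d stands in for whichever ATen op the mode string selects; the point is the event-based handoff between the TensorRT stream and a PyTorch stream so the ATen kernel can run directly on the GPU buffers.

#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>

#include <vector>

// Hypothetical helper, not part of the diff. Assumes 4-D (NCHW) dims.
int enqueue_via_aten(
    const void* in_ptr,
    void* out_ptr,
    const std::vector<int64_t>& in_dims,
    const std::vector<int64_t>& out_dims,
    cudaStream_t trt_stream) {
  auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
  // Wrap TensorRT's device buffers without taking ownership (no-op deleter).
  at::Tensor input = at::from_blob(const_cast<void*>(in_ptr), in_dims, [](void*) {}, opts);
  at::Tensor output = at::from_blob(out_ptr, out_dims, [](void*) {}, opts);

  // Borrow a PyTorch CUDA stream and make it current for the ATen call.
  c10::cuda::CUDAStream torch_stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard torch_guard(torch_stream);

  // The PyTorch stream waits until TensorRT's stream has produced the input.
  cudaEvent_t event;
  cudaEventCreate(&event);
  cudaEventRecord(event, trt_stream);
  cudaStreamWaitEvent(torch_stream.stream(), event, 0);

  // Run the ATen kernel on the GPU and copy into TensorRT's output buffer.
  at::Tensor result = at::adaptive_avg_pool2d(input, {out_dims[2], out_dims[3]});
  output.copy_(result);

  // TensorRT's stream waits until the ATen work is finished before continuing.
  cudaEvent_t torch_event;
  cudaEventCreate(&torch_event);
  cudaEventRecord(torch_event, torch_stream.stream());
  cudaStreamWaitEvent(trt_stream, torch_event, 0);

  cudaEventDestroy(event);
  cudaEventDestroy(torch_event);
  return 0;
}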

core/plugins/impl/interpolate_plugin.h (-1)

@@ -20,7 +20,6 @@ namespace impl {

 class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
  private:
-  at::TensorOptions tensor_options_;
   nvinfer1::DataType dtype_;

   std::vector<int64_t> in_shape_;

core/plugins/impl/normalize_plugin.cpp (+7, -38)

@@ -103,15 +103,6 @@ nvinfer1::DataType NormalizePlugin::getOutputDataType(int index, const nvinfer1:
 }

 int NormalizePlugin::initialize() {
-#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
-  tensor_options_ = tensor_options_.device(c10::kCUDA);
-#else
-  tensor_options_ = tensor_options_.device(c10::kCPU);
-#endif
-
-  // c10::kFloat = FLOAT32
-  tensor_options_ = tensor_options_.dtype(c10::kFloat);
-
   return 0;
 }

@@ -181,11 +172,10 @@ int NormalizePlugin::enqueue(
     void* const* outputs,
     void* workspace,
     cudaStream_t stream) {
-  // TRT <= 7.0
-#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1)
-  at::Tensor input = at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, tensor_options_);
-  at::Tensor output = at::from_blob(
-      outputs[0], util::volume(outputDesc->dims), [](void*) {}, tensor_options_);
+  at::Tensor input =
+      at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);
+  at::Tensor output =
+      at::from_blob(outputs[0], util::toVec(outputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);

   at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool();
   at::cuda::CUDAStreamGuard torch_guard(torch_stream);

@@ -195,7 +185,9 @@ int NormalizePlugin::enqueue(
   cudaEventRecord(event, stream);

   cudaStreamWaitEvent(torch_stream.stream(), event, 0);
-  at::Tensor result = at::norm(input, order_, axes_, keep_dims_);
+
+  std::vector<int64_t> axes_double(axes_.begin(), axes_.end());
+  at::Tensor result = at::norm(input, (int64_t)order_, axes_double, (bool)keep_dims_);
   output.copy_(result);
   cudaEvent_t torch_event;
   cudaEventCreate(&torch_event);

@@ -206,29 +198,6 @@ int NormalizePlugin::enqueue(
   cudaEventDestroy(event);
   cudaEventDestroy(torch_event);
   return 0;
-#else
-  // TODO: When PyTorch updates to cuDNN 8 try moving back to CUDA based ATen
-  // kernels HACK: WAR because there is a segfault if you try to create a CUDA
-  // Tensor in the context of TensorRT execution
-  float* input_blob = (float*)malloc(util::volume(inputDesc->dims) * sizeof(float));
-  cudaMemcpyAsync(
-      input_blob,
-      static_cast<const void*>(inputs[0]),
-      util::volume(inputDesc->dims) * sizeof(float),
-      cudaMemcpyDeviceToHost,
-      stream);
-  cudaStreamSynchronize(stream);
-
-  at::Tensor input = at::from_blob((void*)input_blob, util::toVec(inputDesc->dims), tensor_options_);
-  std::vector<int64_t> axes_new(axes_.begin(), axes_.end());
-  at::Tensor output = at::norm(input, (int64_t)order_, axes_new, (bool)keep_dims_);
-  cudaMemcpyAsync(
-      outputs[0], output.data_ptr(), util::volume(outputDesc->dims) * sizeof(float), cudaMemcpyHostToDevice, stream);
-  cudaStreamSynchronize(stream);
-
-  free(input_blob);
-  return 0;
-#endif
 }

 /*
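
A small illustrative sketch (values are made up) of the at::norm overload the normalize plugin now calls on the GPU-resident tensors: the 32-bit order and axes that arrive as plugin fields are widened to the Scalar and int64_t axis types the ATen signature expects, which is what the casts in the diff above do.

#include <ATen/ATen.h>

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  at::Tensor x = at::randn({2, 3, 4});

  // Plugin fields are serialized as 32-bit values; widen them before calling ATen.
  int32_t order = 2;
  std::vector<int32_t> axes = {1, 2};
  std::vector<int64_t> axes_int64(axes.begin(), axes.end());

  at::Tensor result = at::norm(x, (int64_t)order, axes_int64, /*keepdim=*/true);
  std::cout << result.sizes() << std::endl;  // -> [2, 1, 1]
  return 0;
}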

core/plugins/impl/normalize_plugin.h (-1)

@@ -22,7 +22,6 @@ namespace impl {

 class NormalizePlugin : public nvinfer1::IPluginV2DynamicExt {
  private:
-  at::TensorOptions tensor_options_;
   nvinfer1::DataType dtype_;
   int32_t order_;
   std::vector<int32_t> axes_;
