Skip to content

Commit 205ab99

Browse files
committed
feat(//core/conversion/converters/impl/plugins): Created interpolate plugin, works for mode='linear'
Signed-off-by: Abhiram Iyer <abhirami@nvidia.com> Signed-off-by: Abhiram Iyer <abhi.iyer.ai@gmail.com>
1 parent a0848b1 commit 205ab99

File tree

2 files changed

+21
-31
lines changed

2 files changed

+21
-31
lines changed

core/conversion/converters/impl/plugins/interpolate_plugin.cpp

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,3 @@
1-
// #include <string>
2-
// #include <iostream>
3-
// #include <sstream>
4-
// #include <ATen/ATen.h>
5-
// #include <ATen/cuda/CUDAEvent.h>
6-
// #include <cuda_runtime_api.h>
7-
// #include <vector>
8-
// #include <cudnn.h>
9-
10-
// #include "core/util/prelude.h"
11-
// #include "torch/torch.h"
12-
// #include "NvInfer.h"
13-
141
#include "interpolate_plugin.h"
152

163
using namespace nvinfer1;
@@ -80,27 +67,30 @@ int InterpolatePlugin::getNbOutputs() const {
8067
}
8168

8269
const char* InterpolatePlugin::getPluginType() const {
83-
return "Interpolate_TRTorch";
70+
return "Interpolate";
8471
}
8572

8673
const char* InterpolatePlugin::getPluginVersion() const{
8774
return "1";
8875
}
8976

9077
const char* InterpolatePlugin::getPluginNamespace() const {
91-
return "trtorch";
78+
return "";
9279
}
9380

94-
int InterpolatePlugin::getTensorRTVersion() const {
95-
return NV_TENSORRT_MAJOR;
96-
}
9781

9882
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
9983
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
10084
}
10185

10286
nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) {
103-
return inputs[0];
87+
nvinfer1::DimsExprs output(inputs[0]);
88+
89+
for (unsigned int i = 0; i < out_shape.size(); i++) {
90+
output.d[i] = exprBuilder.constant(out_shape[i]);
91+
}
92+
93+
return output;
10494
}
10595

10696
nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
@@ -109,6 +99,8 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
10999

110100
int InterpolatePlugin::initialize() {
111101
tensor_options = tensor_options.device(c10::kCUDA);
102+
103+
// c10::kFloat = FLOAT32
112104
tensor_options = tensor_options.dtype(c10::kFloat);
113105

114106
return 0;
@@ -164,6 +156,7 @@ size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inp
164156
int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
165157
void *const *outputs, void *workspace,
166158
cudaStream_t stream) {
159+
167160
at::Tensor input = at::from_blob((void*) inputs[0], in_shape, [](void*){}, tensor_options);
168161
at::Tensor output = at::from_blob(outputs[0], out_shape, [](void*){}, tensor_options);
169162

@@ -200,13 +193,11 @@ int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, cons
200193
* InterpolatePluginCreator class implementations
201194
*/
202195
const char* InterpolatePluginCreator::getPluginNamespace() const {
203-
return "trtorch";
196+
return "";
204197
}
205198

206-
void InterpolatePluginCreator::setPluginNamespace(const char* libNamespace) {}
207-
208199
const char* InterpolatePluginCreator::getPluginName() const {
209-
return "interpolate";
200+
return "Interpolate";
210201
}
211202

212203
const char* InterpolatePluginCreator::getPluginVersion() const {
@@ -217,7 +208,9 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, co
217208
return nullptr;
218209
}
219210

220-
nvinfer1::IPluginV2DynamicExt* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners) {
211+
InterpolatePlugin* InterpolatePluginCreator::createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape,
212+
std::vector<int64_t> size,
213+
std::string mode, bool align_corners) {
221214
name = name;
222215
return new InterpolatePlugin(in_shape, out_shape, size, mode, align_corners);
223216
}
@@ -238,5 +231,4 @@ REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator);
238231
} // namespace converters
239232
} // namespace conversion
240233
} // namespace core
241-
} // namespace trtorch
242-
234+
} // namespace trtorch

core/conversion/converters/impl/plugins/interpolate_plugin.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,7 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
6565

6666
const char* getPluginNamespace() const override;
6767

68-
void setPluginNamespace(const char* pluginNamespace) {}
69-
70-
int getTensorRTVersion() const override;
68+
void setPluginNamespace(const char* pluginNamespace) override {};
7169

7270
nvinfer1::IPluginV2DynamicExt* clone() const override;
7371

@@ -107,15 +105,15 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
107105

108106
const char* getPluginNamespace() const override;
109107

110-
void setPluginNamespace(const char* libNamespace) override;
108+
void setPluginNamespace(const char* libNamespace) override {};
111109

112110
const char* getPluginName() const override;
113111

114112
const char* getPluginVersion() const override;
115113

116114
nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection *fc) override;
117115

118-
nvinfer1::IPluginV2DynamicExt* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners);
116+
InterpolatePlugin* createPlugin(const char* name, std::vector<int64_t> in_shape, std::vector<int64_t> out_shape, std::vector<int64_t> size, std::string mode, bool align_corners);
119117

120118
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void *serialData, size_t serialLength) override;
121119

0 commit comments

Comments
 (0)