1
- // #include <string>
2
- // #include <iostream>
3
- // #include <sstream>
4
- // #include <ATen/ATen.h>
5
- // #include <ATen/cuda/CUDAEvent.h>
6
- // #include <cuda_runtime_api.h>
7
- // #include <vector>
8
- // #include <cudnn.h>
9
-
10
- // #include "core/util/prelude.h"
11
- // #include "torch/torch.h"
12
- // #include "NvInfer.h"
13
-
14
1
#include " interpolate_plugin.h"
15
2
16
3
using namespace nvinfer1 ;
@@ -80,27 +67,30 @@ int InterpolatePlugin::getNbOutputs() const {
80
67
}
81
68
82
69
const char * InterpolatePlugin::getPluginType () const {
83
- return " Interpolate_TRTorch " ;
70
+ return " Interpolate " ;
84
71
}
85
72
86
73
const char * InterpolatePlugin::getPluginVersion () const {
87
74
return " 1" ;
88
75
}
89
76
90
77
const char * InterpolatePlugin::getPluginNamespace () const {
91
- return " trtorch " ;
78
+ return " " ;
92
79
}
93
80
94
- int InterpolatePlugin::getTensorRTVersion () const {
95
- return NV_TENSORRT_MAJOR;
96
- }
97
81
98
82
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone () const {
99
83
return new InterpolatePlugin (in_shape, out_shape, size, mode, align_corners);
100
84
}
101
85
102
86
nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions (int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) {
103
- return inputs[0 ];
87
+ nvinfer1::DimsExprs output (inputs[0 ]);
88
+
89
+ for (unsigned int i = 0 ; i < out_shape.size (); i++) {
90
+ output.d [i] = exprBuilder.constant (out_shape[i]);
91
+ }
92
+
93
+ return output;
104
94
}
105
95
106
96
nvinfer1::DataType InterpolatePlugin::getOutputDataType (int index, const nvinfer1::DataType* inputTypes, int nbInputs) const {
@@ -109,6 +99,8 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
109
99
110
100
int InterpolatePlugin::initialize () {
111
101
tensor_options = tensor_options.device (c10::kCUDA );
102
+
103
+ // c10::kFloat = FLOAT32
112
104
tensor_options = tensor_options.dtype (c10::kFloat );
113
105
114
106
return 0 ;
@@ -164,6 +156,7 @@ size_t InterpolatePlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inp
164
156
int InterpolatePlugin::enqueue (const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
165
157
void *const *outputs, void *workspace,
166
158
cudaStream_t stream) {
159
+
167
160
at::Tensor input = at::from_blob ((void *) inputs[0 ], in_shape, [](void *){}, tensor_options);
168
161
at::Tensor output = at::from_blob (outputs[0 ], out_shape, [](void *){}, tensor_options);
169
162
@@ -200,13 +193,11 @@ int InterpolatePlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, cons
200
193
* InterpolatePluginCreator class implementations
201
194
*/
202
195
const char * InterpolatePluginCreator::getPluginNamespace () const {
203
- return " trtorch " ;
196
+ return " " ;
204
197
}
205
198
206
- void InterpolatePluginCreator::setPluginNamespace (const char * libNamespace) {}
207
-
208
199
const char * InterpolatePluginCreator::getPluginName () const {
209
- return " interpolate " ;
200
+ return " Interpolate " ;
210
201
}
211
202
212
203
const char * InterpolatePluginCreator::getPluginVersion () const {
@@ -217,7 +208,9 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(const char* name, co
217
208
return nullptr ;
218
209
}
219
210
220
- nvinfer1::IPluginV2DynamicExt* InterpolatePluginCreator::createPlugin (const char * name, std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::vector<int64_t > size, std::string mode, bool align_corners) {
211
+ InterpolatePlugin* InterpolatePluginCreator::createPlugin (const char * name, std::vector<int64_t > in_shape, std::vector<int64_t > out_shape,
212
+ std::vector<int64_t > size,
213
+ std::string mode, bool align_corners) {
221
214
name = name;
222
215
return new InterpolatePlugin (in_shape, out_shape, size, mode, align_corners);
223
216
}
@@ -238,5 +231,4 @@ REGISTER_TENSORRT_PLUGIN(InterpolatePluginCreator);
238
231
} // namespace converters
239
232
} // namespace conversion
240
233
} // namespace core
241
- } // namespace trtorch
242
-
234
+ } // namespace trtorch
0 commit comments