1- #include < torch/extension.h>
2- #include < torch/script.h>
31#include < string>
42#include < iostream>
53#include < sstream>
6- #include < NvInfer.h>
74#include < ATen/ATen.h>
85#include < ATen/cuda/CUDAEvent.h>
9- #include < torch/torch.h>
106#include < cuda_runtime_api.h>
11- #include " NvInferVersion.h"
127#include < vector>
138#include < cudnn.h>
14- #include < NVInferRuntime.h>
15- #include < NVInferRuntimeCommon.h>
9+
10+ #include " core/util/prelude.h"
11+ #include " torch/torch.h"
12+ #include " NvInfer.h"
13+
14+ using namespace nvinfer1 ;
1615
1716namespace trtorch {
1817namespace core {
@@ -25,24 +24,58 @@ namespace {
2524class InterpolatePlugin : public nvinfer1 ::IPluginV2DynamicExt {
2625private:
2726 at::TensorOptions tensor_options;
28- std::vector<int64_t > input_sizes;
29- std::vector<int64_t > output_sizes;
3027 DataType dtype;
3128
29+ std::vector<int64_t > in_shape;
30+ std::vector<int64_t > out_shape;
3231 std::vector<int64_t > size;
3332 std::string mode;
3433 bool align_corners;
3534
3635public:
37- InterpolatePlugin (const char * name, std::vector<int64_t > in_shape,
38- std::vector<int64_t > out_shape,
39- std::string mode,
40- bool align_corners) : name(name), in_shape(in_shape), out_shape(out_shape), mode(mode), align_corners(align_corners) {}
41-
36+ InterpolatePlugin (std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::vector<int64_t > size, std::string mode, bool align_corners) :
37+ in_shape (in_shape), out_shape(out_shape), size(size), mode(mode), align_corners(align_corners)
38+ {}
39+
40+ InterpolatePlugin (const char *data, size_t length) {
41+ std::istringstream data_stream (std::string (data, length));
42+
43+ torch::serialize::InputArchive input_archive;
44+ input_archive.load_from (data_stream);
45+
46+ {
47+ torch::IValue value;
48+ input_archive.read (" in_shape" , value);
49+ in_shape = value.toIntVector ();
50+ }
51+ {
52+ torch::IValue value;
53+ input_archive.read (" out_shape" , value);
54+ out_shape = value.toIntVector ();
55+ }
56+ {
57+ torch::IValue value;
58+ input_archive.read (" size" , value);
59+ size = value.toIntVector ();
60+ }
61+ {
62+ torch::IValue value;
63+ input_archive.read (" mode" , value);
64+ mode = value.toStringRef ();
65+ }
66+ {
67+ torch::IValue value;
68+ input_archive.read (" align_corners" , value);
69+ align_corners = value.toBool ();
70+ }
71+ }
4272
73+ int getNbOutputs () const override {
74+ return 1 ;
75+ }
4376
4477 const char * getPluginType () const override {
45- return " Interpolate " ;
78+ return " Interpolate_TRTorch " ;
4679 }
4780
4881 const char * getPluginVersion () const override {
@@ -60,79 +93,125 @@ class InterpolatePlugin : public nvinfer1::IPluginV2DynamicExt {
6093 }
6194
6295 nvinfer1::IPluginV2DynamicExt* clone () const override {
63- auto * plugin = new InterpolatePlugin (*this );
64- return plugin;
96+ return new InterpolatePlugin (in_shape, out_shape, size, mode, align_corners);
6597 }
6698
67- nvinfer::DimsExprs getOutputDimensions (int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) const override {
68-
69- }
99+ nvinfer1::DimsExprs getOutputDimensions (int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, nvinfer1::IExprBuilder &exprBuilder) override {
100+ // nvinfer1::DimsExprs output(inputs[0]);
70101
71- nvinfer1::DataType getOutputDataType (int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override {
102+ // output.nbDims = out_shape.size();
103+
104+ // for (int i = 0; i < out_shape.size(); i++) {
105+ // output.d[i] = exprBuilder.getConstantValue(out_shape[i]);
106+ // }
72107
108+ // return output;
109+ nvinfer1::DimsExprs empty;
110+ return empty;
73111 }
74112
75- int getNbOutputs ( ) const override {
76- return 1 ;
113+ nvinfer1::DataType getOutputDataType ( int index, const nvinfer1::DataType* inputTypes, int nbInputs ) const override {
114+ return DataType:: kFLOAT ;
77115 }
78116
79117 int initialize () override {
118+ tensor_options = tensor_options.device (c10::kCUDA );
119+ tensor_options = tensor_options.dtype (c10::kFloat );
80120
121+ return 0 ;
81122 }
82123
83- void terminate () override {
84-
85- }
124+ void terminate () override {}
86125
87126 void serialize (void * buffer) const override {
127+ std::string data = serializeToString ();
128+ size_t size = getSerializationSize ();
88129
130+ data.copy ((char *) buffer, size);
89131 }
90132
91- void size_t getSerializationSize () const override {
133+ std::string serializeToString () const {
134+ torch::serialize::OutputArchive output_archive;
92135
93- }
136+ output_archive.write (" in_shape" , torch::IValue (in_shape));
137+ output_archive.write (" out_shape" , torch::IValue (out_shape));
138+ output_archive.write (" size" , torch::IValue (size));
139+ output_archive.write (" mode" , torch::IValue (mode));
140+ output_archive.write (" align_corners" , torch::IValue (align_corners));
94141
95- void destroy () override {
142+ std::ostringstream data_str;
143+ output_archive.save_to (data_str);
96144
145+ return data_str.str ();
97146 }
98147
99- bool supportsFormatCombination ( int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override {
100-
148+ size_t getSerializationSize () const override {
149+ return serializeToString (). size ();
101150 }
102151
103- void configurePlugin ( const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs ) override {
152+ void destroy ( ) override {}
104153
105- }
154+ bool supportsFormatCombination (int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) override {
155+ if (inOut->format != nvinfer1::TensorFormat::kLINEAR ) {
156+ return false ;
157+ }
106158
107- size_t getWorkspaceSize (const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const override {
159+ if (inOut->type == DataType::kINT32 || inOut->type == DataType::kINT8 ) {
160+ return false ;
161+ }
108162
163+ return true ;
109164 }
110165
111- void attachToContext (nvinfer1::cudnnContext*, nvinfer1::cublasContext*, nvinfer1::IGpuAllocator*) override {}
166+ void configurePlugin (const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) override {
167+ dtype = DataType::kFLOAT ;
168+ }
112169
113- void detachFromContext () override {}
170+ size_t getWorkspaceSize (const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const override {
171+ return 0 ;
172+ }
114173
115174 int enqueue (const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc, const void *const *inputs,
116175 void *const *outputs, void *workspace,
117176 cudaStream_t stream) override {
118-
119- }
177+ at::Tensor input = at::from_blob (( void *) inputs[ 0 ], in_shape, []( void *){}, tensor_options);
178+ at::Tensor output = at::from_blob (outputs[ 0 ], out_shape, []( void *){}, tensor_options);
120179
180+ at::cuda::CUDAStream torch_stream = at::cuda::getStreamFromPool ();
181+ at::cuda::CUDAStreamGuard torch_guard (torch_stream);
121182
183+ cudaEvent_t event;
184+ cudaEventCreate (&event);
185+ cudaEventRecord (event, stream);
122186
187+ cudaStreamWaitEvent (torch_stream.stream (), event, 0 );
123188
124- private:
125- std::string name;
126- std::vector<int64_t > in_shape;
127- std::vector<int64_t > out_shape;
128- std::string mode;
129- bool align_corners;
189+ if (mode == " linear" ) {
190+ at::upsample_linear1d_out (output, input, {size[0 ]}, align_corners);
191+ } else if (mode == " bilinear" ) {
192+ at::upsample_bilinear2d_out (output, input, {size[0 ], size[1 ]}, align_corners);
193+ } else if (mode == " trilinear" ) {
194+ at::upsample_trilinear3d_out (output, input, {size[0 ], size[1 ], size[2 ]}, align_corners);
195+ }
196+
197+ cudaEvent_t torch_event;
198+ cudaEventCreate (&torch_event);
199+ cudaEventRecord (torch_event, torch_stream.stream ());
200+
201+ cudaStreamWaitEvent (stream, torch_event, 0 );
202+
203+ cudaEventDestroy (event);
204+ cudaEventDestroy (torch_event);
130205
131- nvinfer1::DataType dtype;
132- }
206+ return 0 ;
207+ }
208+ };
133209
134210
135211class InterpolatePluginCreator : public nvinfer1 ::IPluginCreator {
212+ private:
213+ std::string name;
214+
136215public:
137216 InterpolatePluginCreator () {}
138217
@@ -158,18 +237,20 @@ class InterpolatePluginCreator : public nvinfer1::IPluginCreator {
158237 return nullptr ;
159238 }
160239
161- nvinfer1::IPluginV2* createPlugin (const char * name, std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::string mode, bool align_corners) {
162- return new InterpolatePlugin (name, in_shape, out_shape, mode, align_corners);
240+ nvinfer1::IPluginV2* createPlugin (const char * name, std::vector<int64_t > in_shape, std::vector<int64_t > out_shape, std::vector<int64_t > size, std::string mode, bool align_corners) {
241+ name = name;
242+ return new InterpolatePlugin (in_shape, out_shape, size, mode, align_corners);
163243 }
164244
165245 nvinfer1::IPluginV2* deserializePlugin (const char * name, const void *serialData, size_t serialLength) override {
166- return nullptr ;
246+ name = name;
247+ return new InterpolatePlugin ((const char *) serialData, serialLength);
167248 }
168249
169250 const nvinfer1::PluginFieldCollection* getFieldNames () override {
170251 return nullptr ;
171252 }
172- }
253+ };
173254
174255REGISTER_TENSORRT_PLUGIN (InterpolatePluginCreator);
175256
0 commit comments