11#include " torch/torch.h"
22#include " core/util/prelude.h"
33#include " core/conversion/converters/converters.h"
4+ #include " NvInfer.h"
5+ #include " plugins/interpolate_plugin.h"
46
57#include < csignal>
68
@@ -108,7 +110,7 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
108110 auto in = args[0 ].ITensor ();
109111 auto in_shape = util::toVec (in->getDimensions ());
110112
111- bool align_corners = args[2 ].IValue ()-> to < bool > ();
113+ bool align_corners = args[2 ].unwrapToBool ();
112114
113115 // Case 1: user uses output size and not scales
114116 if (!args[1 ].IValue ()->isNone () && args[3 ].IValue ()->isNone ()) {
@@ -119,16 +121,29 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
119121 auto out_shape = in_shape;
120122 std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
121123
122- auto resize_layer = ctx->net ->addResize (*in);
123- TRTORCH_CHECK (resize_layer, " Unable to create interpolation (resizing) layer from node" << *n);
124+ if (!align_corners) {
125+ // auto creator = getPluginRegistry()->getPluginCreator("interpolate", "1");
126+ // auto* plugin = creator->createPlugin(util::node_info(n).c_str(), in_shape, out_shape, out_size, std::string("linear"), align_corners);
127+ auto creator = new plugins::InterpolatePluginCreator ();
124128
125- resize_layer->setOutputDimensions (util::toDims (out_shape));
126- resize_layer->setResizeMode (nvinfer1::ResizeMode::kLINEAR );
127- resize_layer->setAlignCorners (align_corners);
128- resize_layer->setName (util::node_info (n).c_str ());
129+ auto plugin = creator->createPlugin (util::node_info (n).c_str (), in_shape, out_shape, out_size, std::string (" linear" ), align_corners);
129130
130- auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
131- LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
131+ auto resize_layer = ctx->net ->addPluginV2 (reinterpret_cast <nvinfer1::ITensor* const *>(in), 1 , *plugin);
132+
133+ auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
134+ LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
135+ } else {
136+ auto resize_layer = ctx->net ->addResize (*in);
137+ TRTORCH_CHECK (resize_layer, " Unable to create interpolation (resizing) layer from node" << *n);
138+
139+ resize_layer->setOutputDimensions (util::toDims (out_shape));
140+ resize_layer->setResizeMode (nvinfer1::ResizeMode::kLINEAR );
141+ resize_layer->setAlignCorners (align_corners);
142+ resize_layer->setName (util::node_info (n).c_str ());
143+
144+ auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
145+ LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
146+ }
132147 } else {
133148 TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_linear1d not supported yet." );
134149 }
0 commit comments