@@ -163,13 +163,45 @@ auto interpolate_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
163163 auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
164164 LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
165165 } else {
166- TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_linear1d not supported yet." );
166+ TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_bilinear2d not supported yet." );
167167 }
168168
169169 return true ;
170170 }
171- });
171+ }).pattern({
172+ " aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)" ,
173+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
174+ auto in = args[0 ].ITensor ();
175+ auto in_shape = util::toVec (in->getDimensions ());
176+
177+ bool align_corners = args[2 ].IValue ()->to <bool >();
172178
179+ // Case 1: user uses output size and not scales_d, scales_h, scales_w
180+ if (!args[1 ].IValue ()->isNone () && args[3 ].IValue ()->isNone () && args[4 ].IValue ()->isNone () && args[5 ].IValue ()->isNone ()) {
181+ auto out_size = util::toVec (util::toDims (args[1 ].unwrapToIntList ()));
182+
183+ TRTORCH_ASSERT (out_size.size () == 3 , " aten::upsample_trilinear3d input Tensor and output size dimension mismatch" );
184+
185+ auto out_shape = in_shape;
186+ std::copy (out_size.begin (), out_size.end (), out_shape.begin () + (in_shape.size () - out_size.size ()));
187+
188+ auto resize_layer = ctx->net ->addResize (*in);
189+ TRTORCH_CHECK (resize_layer, " Unable to create interpolation (resizing) layer from node" << *n);
190+
191+ resize_layer->setOutputDimensions (util::toDims (out_shape));
192+ resize_layer->setResizeMode (nvinfer1::ResizeMode::kLINEAR );
193+ resize_layer->setAlignCorners (align_corners);
194+ resize_layer->setName (util::node_info (n).c_str ());
195+
196+ auto layer_output = ctx->AssociateValueAndTensor (n->outputs ()[0 ], resize_layer->getOutput (0 ));
197+ LOG_DEBUG (" Output tensor shape: " << layer_output->getDimensions ());
198+ } else {
199+ TRTORCH_THROW_ERROR (" Unable to convert node: " << util::node_info (n) << " \n Scale factor parameter for upsample_trilinear3d not supported yet." );
200+ }
201+
202+ return true ;
203+ }
204+ });
173205
174206} // namespace
175207} // namespace impl
0 commit comments