@@ -23,23 +23,23 @@ bool MaxPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& ar
2323 }
2424
2525
26- auto kernel_size = util::toDimsHW (args[1 ].unwrapToIntList ());
26+ auto kernel_size = util::toDims (args[1 ].unwrapToIntList ());
2727 LOG_DEBUG (" kernel_size: " << kernel_size);
28- auto padding = util::toDimsHW (args[3 ].unwrapToIntList ());
28+ auto padding = util::toDims (args[3 ].unwrapToIntList ());
2929 LOG_DEBUG (" padding: " << padding);
3030 auto stride = util::toDims (args[2 ].unwrapToIntList ());
3131 LOG_DEBUG (" stride: " << stride);
3232
3333 auto dilation = util::toDims (args[4 ].unwrapToIntList ());
3434
35- TRTORCH_ASSERT (dilation == util::toDims (std::vector<int64_t >({ 1 , 1 } )), " Pooling dilation is not supported in TensorRT" );
35+ TRTORCH_ASSERT (dilation == util::toDims (std::vector<int64_t >(dilation. nbDims , 1 )), " Pooling dilation is not supported in TensorRT" );
3636
3737 LOG_DEBUG (" dilation: " << dilation);
3838 LOG_WARNING (" Dilation not used in max pooling converter" );
3939 bool ceil_mode = args[5 ].unwrapToBool ();
4040
4141 auto new_layer = ctx->net ->addPoolingNd (*in, nvinfer1::PoolingType::kMAX , kernel_size);
42- TRTORCH_CHECK (new_layer, " Unable to create Max Pool 2D layer from node: " << *n);
42+ TRTORCH_CHECK (new_layer, " Unable to create Max Pooling layer from node: " << *n);
4343
4444 new_layer->setName (util::node_info (n).c_str ());
4545 new_layer->setPaddingNd (padding);
@@ -77,9 +77,9 @@ bool AvgPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& ar
7777 }
7878
7979
80- auto kernel_size = util::toDimsHW (args[1 ].unwrapToIntList ());
80+ auto kernel_size = util::toDims (args[1 ].unwrapToIntList ());
8181 LOG_DEBUG (" kernel_size: " << kernel_size);
82- auto padding = util::toDimsHW (args[3 ].unwrapToIntList ());
82+ auto padding = util::toDims (args[3 ].unwrapToIntList ());
8383 LOG_DEBUG (" padding: " << padding);
8484 auto stride = util::toDims (args[2 ].unwrapToIntList ());
8585 LOG_DEBUG (" stride: " << stride);
@@ -88,7 +88,7 @@ bool AvgPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& ar
8888 bool count_inlcude_pad = args[5 ].unwrapToBool ();
8989
9090 auto new_layer = ctx->net ->addPoolingNd (*in, nvinfer1::PoolingType::kAVERAGE , kernel_size);
91- TRTORCH_CHECK (new_layer, " Unable to create Avg Pool 2D layer from node: " << *n);
91+ TRTORCH_CHECK (new_layer, " Unable to create Avg Pooling layer from node: " << *n);
9292
9393 new_layer->setName (util::node_info (n).c_str ());
9494 new_layer->setPaddingNd (padding);
@@ -118,12 +118,67 @@ bool AvgPoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& ar
118118
119119auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
120120 .pattern({
121- " aten::max_pool1d(Tensor self, int[2 ] kernel_size, int[2 ] stride=[], int[2 ] padding=[0, 0 ], int[2 ] dilation=[1, 1 ], bool ceil_mode=False) -> (Tensor)" ,
121+ " aten::max_pool1d(Tensor self, int[1 ] kernel_size, int[1 ] stride=[], int[1 ] padding=[], int[1 ] dilation=[], bool ceil_mode=False) -> (Tensor)" ,
122122 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
123- return MaxPoolingConverter (ctx, n, args);
123+ auto in = args[0 ].ITensor ();
124+ auto shape = util::toVec (in->getDimensions ());
125+
126+ // Max Pool needs at least 4D input
127+ if (shape.size () < 4 ) {
128+ auto new_shape = util::toDimsPad (shape, 4 );
129+ LOG_DEBUG (" Input shape is less than 4D got: " << util::toDims (shape) << " , inserting shuffle layer to reshape to 4D tensor shape: " << new_shape);
130+ auto shuffle = ctx->net ->addShuffle (*in);
131+ shuffle->setReshapeDimensions (new_shape);
132+ shuffle->setName ((util::node_info (n) + " [Reshape to " + util::toStr (new_shape) + ' ]' ).c_str ());
133+ in = shuffle->getOutput (0 );
134+ }
135+
136+ auto kernel_vec = args[1 ].unwrapToIntList ().vec ();
137+ kernel_vec.insert (kernel_vec.begin (), 1 );
138+ auto kernel_size = util::toDims (kernel_vec);
139+ LOG_DEBUG (" kernel_size: " << kernel_size);
140+ auto stride_vec = args[2 ].unwrapToIntList ().vec ();
141+ stride_vec.insert (stride_vec.begin (), 1 );
142+ auto stride = util::toDims (stride_vec);
143+ LOG_DEBUG (" stride: " << stride);
144+ auto padding_vec = args[3 ].unwrapToIntList ().vec ();
145+ padding_vec.insert (padding_vec.begin (), 0 );
146+ auto padding = util::toDims (padding_vec);
147+ LOG_DEBUG (" padding: " << padding);
148+
149+ auto dilation = util::toDims (args[4 ].unwrapToIntList ());
150+
151+ TRTORCH_ASSERT (dilation == util::toDims (std::vector<int64_t >(dilation.nbDims , 1 )), " Pooling dilation is not supported in TensorRT" );
152+
153+ LOG_DEBUG (" dilation: " << dilation);
154+ LOG_WARNING (" Dilation not used in max pooling converter" );
155+ bool ceil_mode = args[5 ].unwrapToBool ();
156+
157+ auto new_layer = ctx->net ->addPoolingNd (*in, nvinfer1::PoolingType::kMAX , kernel_size);
158+ TRTORCH_CHECK (new_layer, " Unable to create Max Pooling layer from node: " << *n);
159+
160+ new_layer->setName (util::node_info (n).c_str ());
161+ new_layer->setPaddingNd (padding);
162+ if (stride.nbDims != 2 && ctx->settings .device == nvinfer1::DeviceType::kDLA ) {
163+ if (!ctx->settings .allow_gpu_fallback ) {
164+ TRTORCH_THROW_ERROR (" DLA Pooling stride is limited to 2D, allow GPU fallback" );
165+ } else {
166+ LOG_WARNING (" DLA Pooling stride is limited to 2D, will run on GPU" );
167+ }
168+ }
169+ new_layer->setStrideNd (stride);
170+
171+ auto padding_mode = ceil_mode ? nvinfer1::PaddingMode::kEXPLICIT_ROUND_UP : nvinfer1::PaddingMode::kEXPLICIT_ROUND_DOWN ;
172+ new_layer->setPaddingMode (padding_mode);
173+
174+ new_layer->setName (util::node_info (n).c_str ());
175+ auto out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], new_layer->getOutput (0 ));
176+
177+ LOG_DEBUG (" Output tensor shape: " << out_tensor->getDimensions ());
178+ return true ;
124179 }
125180 }).pattern({
126- " aten::avg_pool1d(Tensor self, int[2 ] kernel_size, int[2 ] stride=[], int[2 ] padding=[ 0, 0], bool ceil_mode=False, bool count_include_pad=True) -> ( Tensor) " ,
181+ " aten::avg_pool1d(Tensor self, int[1 ] kernel_size, int[1 ] stride=[], int[1 ] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor" ,
127182 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
128183 auto in = args[0 ].ITensor ();
129184 auto shape = util::toVec (in->getDimensions ());
@@ -139,12 +194,18 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
139194 }
140195
141196
142- auto kernel_size = util::toDimsHW (args[1 ].unwrapToIntList ());
197+ auto kernel_vec = args[1 ].unwrapToIntList ().vec ();
198+ kernel_vec.insert (kernel_vec.begin (), 1 );
199+ auto kernel_size = util::toDims (kernel_vec);
143200 LOG_DEBUG (" kernel_size: " << kernel_size);
144- auto padding = util::toDimsHW ( args[3 ].unwrapToIntList ());
145- LOG_DEBUG ( " padding: " << padding );
146- auto stride = util::toDims (args[ 2 ]. unwrapToIntList () );
201+ auto stride_vec = args[2 ].unwrapToIntList (). vec ( );
202+ stride_vec. insert (stride_vec. begin (), 1 );
203+ auto stride = util::toDims (stride_vec );
147204 LOG_DEBUG (" stride: " << stride);
205+ auto padding_vec = args[3 ].unwrapToIntList ().vec ();
206+ padding_vec.insert (padding_vec.begin (), 0 );
207+ auto padding = util::toDims (padding_vec);
208+ LOG_DEBUG (" padding: " << padding);
148209
149210 bool ceil_mode = args[4 ].unwrapToBool ();
150211 bool count_inlcude_pad = args[5 ].unwrapToBool ();
@@ -187,12 +248,12 @@ auto pooling_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
187248 return AvgPoolingConverter (ctx, n, args);
188249 }
189250 }).pattern({
190- " aten::max_pool3d(Tensor self, int[2 ] kernel_size, int[2 ] stride=[], int[2 ] padding=[0, 0 ], int[2 ] dilation=[1, 1 ], bool ceil_mode=False) -> (Tensor)" ,
251+ " aten::max_pool3d(Tensor self, int[3 ] kernel_size, int[3 ] stride=[], int[3 ] padding=[], int[3 ] dilation=[], bool ceil_mode=False) -> (Tensor)" ,
191252 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
192253 return MaxPoolingConverter (ctx, n, args);
193254 }
194255 }).pattern({
195- " aten::avg_pool3d(Tensor self, int[2 ] kernel_size, int[2 ] stride=[], int[2 ] padding=[0, 0 ], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)" ,
256+ " aten::avg_pool3d(Tensor self, int[3 ] kernel_size, int[3 ] stride=[], int[3 ] padding=[], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)" ,
196257 [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
197258 return AvgPoolingConverter (ctx, n, args);
198259 }
0 commit comments