Skip to content

Commit dad25f6

Browse files
borisfomnarendasan
authored andcommitted
Fixing issue #552
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
1 parent c16ed0b commit dad25f6

File tree

3 files changed

+5
-2
lines changed

3 files changed

+5
-2
lines changed

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
4040
LOG_DEBUG("momentum disregarded");
4141
LOG_DEBUG("training disregarded");
4242
LOG_DEBUG("cudnn disregarded");
43+
TRTORCH_CHECK(orig_shape.nbDims > 2 , "Unable to create batch normalization layer from node: " << *n);
4344

4445
// Expand spatial dims from 1D to 2D if needed
4546
bool expandDims = (orig_shape.nbDims < 4);

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
3030
LOG_DEBUG("out_padding: " << out_padding);
3131
LOG_DEBUG("groups: " << groups);
3232

33-
// Expand spatial dims from 1D to 2D if needed
33+
TRTORCH_CHECK(orig_dims.nbDims > 2 , "Unable to create convolution layer from node: " << *n);
34+
3435
bool expandDims = (orig_dims.nbDims < 4);
3536
if (expandDims) {
3637
in = addPadding(ctx, n, in, 4);

core/conversion/converters/impl/pooling.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ bool AdaptivePoolingConverter(
4848

4949
auto orig_dims = in->getDimensions();
5050
bool expandDims = (orig_dims.nbDims < 4);
51-
51+
TRTORCH_CHECK(orig_dims.nbDims > 2 , "Unable to create pooling layer from node: " << *n);
5252
if (expandDims) {
5353
in = addPadding(ctx, n, in, 4, false, false);
5454
}
@@ -122,6 +122,7 @@ bool PoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& args,
122122

123123
// Max Pool needs at least 4D input
124124
auto orig_dims = in->getDimensions();
125+
TRTORCH_CHECK(orig_dims.nbDims > 2 , "Unable to create pooling layer from node: " << *n);
125126
bool expandDims = (orig_dims.nbDims < 4);
126127

127128
if (expandDims) {

0 commit comments

Comments
 (0)