diff --git a/aten/src/ATen/native/mps/operations/Pad.mm b/aten/src/ATen/native/mps/operations/Pad.mm index 25cccfd6f4424..36500e864f16e 100644 --- a/aten/src/ATen/native/mps/operations/Pad.mm +++ b/aten/src/ATen/native/mps/operations/Pad.mm @@ -107,23 +107,20 @@ grad_output = grad_output_.contiguous(); } - const int64_t input_dim = input.dim(); - MPSShape *leftPadding = nullptr, *rightPadding = nullptr; - if (padding_dim == 3) { - leftPadding = [NSArray arrayWithObjects:(const NSNumber*[]){ @(0), @(0), @(pad_front), @(pad_t), @(pad_l) } count:input_dim]; - rightPadding = [NSArray arrayWithObjects:(const NSNumber*[]){ @(0), @(0), @(pad_back), @(pad_b), @(pad_r) } count:input_dim]; - } else if (padding_dim == 2) { - leftPadding = [NSArray arrayWithObjects:(const NSNumber*[]){ @(0), @(0), @(pad_t), @(pad_l) } count:input_dim]; - rightPadding = [NSArray arrayWithObjects:(const NSNumber*[]){ @(0), @(0), @(pad_b), @(pad_r) } count:input_dim]; - } else if (padding_dim == 1) { - if (input_dim > 1) { - leftPadding = [NSArray arrayWithObjects:(const NSNumber*[]){ @(0), @(0), @(pad_l) } count:input_dim]; - rightPadding = [NSArray arrayWithObjects:(const NSNumber*[]){ @(0), @(0), @(pad_r) } count:input_dim]; - } else { - leftPadding = [NSArray arrayWithObjects:(const NSNumber*[]){ @(pad_l) } count:input_dim]; - rightPadding = [NSArray arrayWithObjects:(const NSNumber*[]){ @(pad_r) } count:input_dim]; - } + std::vector leftPadVec(ndims, @(0)); + std::vector rightPadVec(ndims, @(0)); + leftPadVec [ndims - 1] = @(pad_l); + rightPadVec[ndims - 1] = @(pad_r); + if (padding_dim >= 2) { + leftPadVec [ndims - 2] = @(pad_t); + rightPadVec[ndims - 2] = @(pad_b); + } + if (padding_dim >= 3) { + leftPadVec [ndims - 3] = @(pad_front); + rightPadVec[ndims - 3] = @(pad_back); } + MPSShape *leftPadding = [NSArray arrayWithObjects:leftPadVec.data() count:ndims]; + MPSShape *rightPadding = [NSArray arrayWithObjects:rightPadVec.data() count:ndims]; struct CachedGraph : public MPSCachedGraph { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) { } diff --git a/test/test_mps.py b/test/test_mps.py index fed8c8730962c..f6a38524f390a 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -3482,6 +3482,8 @@ def helper(shape, padding, op, value=0): helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d) # Constant Pad 2D helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d) + # input size < pad size + helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d) # 3D Padding helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)