@@ -352,6 +352,32 @@ def vertical_flip_segmentation_mask():
352352 yield SampleInput (mask )
353353
354354
355+ @register_kernel_info_from_sample_inputs_fn
356+ def pad_segmentation_mask ():
357+ for mask , padding , fill , padding_mode in itertools .product (
358+ make_segmentation_masks (),
359+ [[1 ], [1 , 1 ], [1 , 1 , 2 , 2 ]], # padding
360+ [0 , 1 ], # fill
361+ ["constant" , "symmetric" , "edge" ], # padding mode,
362+ ):
363+ if padding_mode == "symmetric" and mask .ndim not in [3 , 4 ]:
364+ continue
365+ if padding_mode == "edge" and fill != 0 :
366+ continue
367+ if (
368+ padding_mode == "edge"
369+ and len (padding ) == 2
370+ and mask .ndim not in [2 , 3 ]
371+ or len (padding ) == 4
372+ and mask .ndim not in [4 , 3 ]
373+ or len (padding ) == 1
374+ ):
375+ continue
376+ if padding_mode == "edge" and mask .ndim not in [2 , 3 , 4 , 5 ]:
377+ continue
378+ yield SampleInput (mask , padding = padding , fill = fill , padding_mode = padding_mode )
379+
380+
355381@pytest .mark .parametrize (
356382 "kernel" ,
357383 [
@@ -933,3 +959,15 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
933959 expected_mask = torch .zeros ((3 , 3 , 3 ), dtype = torch .long , device = device )
934960 expected_mask [:, - 1 , :] = 1
935961 torch .testing .assert_close (out_mask , expected_mask )
962+
963+
964+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
965+ def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
966+ mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
967+ mask [:, 1 , 1 ] = 0
968+
969+ out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ], fill = 1 )
970+
971+ expected_mask = torch .ones ((1 , 3 + 1 + 1 , 3 + 1 + 1 ), dtype = torch .long , device = device )
972+ expected_mask [:, 2 , 2 ] = 0
973+ torch .testing .assert_close (out_mask , expected_mask )
0 commit comments