@@ -1101,17 +1101,6 @@ def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
11011101 torch .testing .assert_close (output_mask , expected_mask )
11021102
11031103
1104- @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1105- def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
1106- mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
1107-
1108- out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ])
1109-
1110- expected_mask = torch .zeros ((1 , 5 , 5 ), dtype = torch .long , device = device )
1111- expected_mask [:, 1 :- 1 , 1 :- 1 ] = 1
1112- torch .testing .assert_close (out_mask , expected_mask )
1113-
1114-
11151104def _parse_padding (padding ):
11161105 if isinstance (padding , int ):
11171106 return [padding ] * 4
@@ -1168,25 +1157,71 @@ def _compute_expected_bbox(bbox, padding_):
11681157 torch .testing .assert_close (output_boxes , expected_bboxes )
11691158
11701159
1160+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1161+ def test_correctness_pad_segmentation_mask_on_fixed_input (device ):
1162+ mask = torch .ones ((1 , 3 , 3 ), dtype = torch .long , device = device )
1163+
1164+ out_mask = F .pad_segmentation_mask (mask , padding = [1 , 1 , 1 , 1 ])
1165+
1166+ expected_mask = torch .zeros ((1 , 5 , 5 ), dtype = torch .long , device = device )
1167+ expected_mask [:, 1 :- 1 , 1 :- 1 ] = 1
1168+ torch .testing .assert_close (out_mask , expected_mask )
1169+
1170+
11711171@pytest .mark .parametrize ("padding" , [[1 , 2 , 3 , 4 ], [1 ], 1 , [1 , 2 ]])
1172- def test_correctness_pad_segmentation_mask (padding ):
1173- def _compute_expected_mask (mask , padding_ ):
1172+ @pytest .mark .parametrize ("padding_mode" , ["constant" , "edge" , "reflect" , "symmetric" ])
1173+ def test_correctness_pad_segmentation_mask (padding , padding_mode ):
1174+ def _compute_expected_mask (mask , padding_ , padding_mode_ ):
11741175 h , w = mask .shape [- 2 ], mask .shape [- 1 ]
11751176 pad_left , pad_up , pad_right , pad_down = _parse_padding (padding_ )
11761177
1178+ if any (pad <= 0 for pad in [pad_left , pad_up , pad_right , pad_down ]):
1179+ raise pytest .UsageError (
1180+ "Expected output can be computed on positive pad values only, "
1181+ "but F.pad_* can also crop for negative values"
1182+ )
1183+
11771184 new_h = h + pad_up + pad_down
11781185 new_w = w + pad_left + pad_right
11791186
11801187 new_shape = (* mask .shape [:- 2 ], new_h , new_w ) if len (mask .shape ) > 2 else (new_h , new_w )
1181- expected_mask = torch .zeros (new_shape , dtype = torch .long )
1182- expected_mask [..., pad_up :- pad_down , pad_left :- pad_right ] = mask
1188+ output = torch .zeros (new_shape , dtype = mask .dtype )
1189+ output [..., pad_up :- pad_down , pad_left :- pad_right ] = mask
1190+
1191+ if padding_mode_ == "edge" :
1192+ # pad top-left corner, left vertical block, bottom-left corner
1193+ output [..., :pad_up , :pad_left ] = mask [..., 0 , 0 ].unsqueeze (- 1 ).unsqueeze (- 2 )
1194+ output [..., pad_up :- pad_down , :pad_left ] = mask [..., :, 0 ].unsqueeze (- 1 )
1195+ output [..., - pad_down :, :pad_left ] = mask [..., - 1 , 0 ].unsqueeze (- 1 ).unsqueeze (- 2 )
1196+ # pad top-right corner, right vertical block, bottom-right corner
1197+ output [..., :pad_up , - pad_right :] = mask [..., 0 , - 1 ].unsqueeze (- 1 ).unsqueeze (- 2 )
1198+ output [..., pad_up :- pad_down , - pad_right :] = mask [..., :, - 1 ].unsqueeze (- 1 )
1199+ output [..., - pad_down :, - pad_right :] = mask [..., - 1 , - 1 ].unsqueeze (- 1 ).unsqueeze (- 2 )
1200+ # pad top and bottom horizontal blocks
1201+ output [..., :pad_up , pad_left :- pad_right ] = mask [..., 0 , :].unsqueeze (- 2 )
1202+ output [..., - pad_down :, pad_left :- pad_right ] = mask [..., - 1 , :].unsqueeze (- 2 )
1203+ elif padding_mode_ in ("reflect" , "symmetric" ):
1204+ d1 = 1 if padding_mode_ == "reflect" else 0
1205+ d2 = - 1 if padding_mode_ == "reflect" else None
1206+ both = (- 1 , - 2 )
1207+ # pad top-left corner, left vertical block, bottom-left corner
1208+ output [..., :pad_up , :pad_left ] = mask [..., d1 : pad_up + d1 , d1 : pad_left + d1 ].flip (both )
1209+ output [..., pad_up :- pad_down , :pad_left ] = mask [..., :, d1 : pad_left + d1 ].flip (- 1 )
1210+ output [..., - pad_down :, :pad_left ] = mask [..., - pad_down - d1 : d2 , d1 : pad_left + d1 ].flip (both )
1211+ # pad top-right corner, right vertical block, bottom-right corner
1212+ output [..., :pad_up , - pad_right :] = mask [..., d1 : pad_up + d1 , - pad_right - d1 : d2 ].flip (both )
1213+ output [..., pad_up :- pad_down , - pad_right :] = mask [..., :, - pad_right - d1 : d2 ].flip (- 1 )
1214+ output [..., - pad_down :, - pad_right :] = mask [..., - pad_down - d1 : d2 , - pad_right - d1 : d2 ].flip (both )
1215+ # pad top and bottom horizontal blocks
1216+ output [..., :pad_up , pad_left :- pad_right ] = mask [..., d1 : pad_up + d1 , :].flip (- 2 )
1217+ output [..., - pad_down :, pad_left :- pad_right ] = mask [..., - pad_down - d1 : d2 , :].flip (- 2 )
11831218
1184- return expected_mask
1219+ return output
11851220
11861221 for mask in make_segmentation_masks ():
1187- out_mask = F .pad_segmentation_mask (mask , padding , "constant" )
1222+ out_mask = F .pad_segmentation_mask (mask , padding , padding_mode = padding_mode )
11881223
1189- expected_mask = _compute_expected_mask (mask , padding )
1224+ expected_mask = _compute_expected_mask (mask , padding , padding_mode )
11901225 torch .testing .assert_close (out_mask , expected_mask )
11911226
11921227
0 commit comments