@@ -332,6 +332,12 @@ def crop_bounding_box():
332332 )
333333
334334
335+ @register_kernel_info_from_sample_inputs_fn
336+ def vertical_flip_segmentation_mask ():
337+ for mask in make_segmentation_masks (extra_dims = ((), (4 ,))):
338+ yield SampleInput (mask )
339+
340+
335341@pytest .mark .parametrize (
336342 "kernel" ,
337343 [
@@ -860,3 +866,26 @@ def test_correctness_crop_bounding_box(device, top, left, height, width, expecte
860866 )
861867
862868 torch .testing .assert_close (output_boxes .tolist (), expected_bboxes )
869+
870+
871+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
872+ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input (device ):
873+ mask = torch .tensor (
874+ [
875+ [[1 , 1 , 1 , 1 , 1 ], [0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ]],
876+ [[1 , 1 , 1 , 1 , 1 ], [0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ]],
877+ [[1 , 1 , 1 , 1 , 1 ], [0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ]],
878+ ],
879+ device = device ,
880+ )
881+
882+ expected_mask = torch .tensor (
883+ [
884+ [[0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 , 1 ]],
885+ [[0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 , 1 ]],
886+ [[0 , 0 , 0 , 0 , 0 ], [0 , 0 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 , 1 ]],
887+ ],
888+ device = device ,
889+ )
890+ out_mask = F .vertical_flip_segmentation_mask (mask )
891+ torch .testing .assert_close (out_mask , expected_mask )
0 commit comments