@@ -138,6 +138,22 @@ def make_one_hot_labels(
138138 yield make_one_hot_label (extra_dims_ )
139139
140140
141+ def make_segmentation_mask (size = None , * , num_categories = 80 , extra_dims = (), dtype = torch .long ):
142+ size = size or torch .randint (16 , 33 , (2 ,)).tolist ()
143+ shape = (* extra_dims , 1 , * size )
144+ data = make_tensor (shape , low = 0 , high = num_categories , dtype = dtype )
145+ return features .SegmentationMask (data )
146+
147+
148+ def make_segmentation_masks (
149+ image_sizes = ((16 , 16 ), (7 , 33 ), (31 , 9 )),
150+ dtypes = (torch .long ,),
151+ extra_dims = ((), (4 ,), (2 , 3 )),
152+ ):
153+ for image_size , dtype , extra_dims_ in itertools .product (image_sizes , dtypes , extra_dims ):
154+ yield make_segmentation_mask (size = image_size , dtype = dtype , extra_dims = extra_dims_ )
155+
156+
141157class SampleInput :
142158 def __init__ (self , * args , ** kwargs ):
143159 self .args = args
@@ -212,7 +228,7 @@ def resize_bounding_box():
212228@register_kernel_info_from_sample_inputs_fn
213229def affine_image_tensor ():
214230 for image , angle , translate , scale , shear in itertools .product (
215- make_images (extra_dims = ()),
231+ make_images (extra_dims = ((), ( 4 ,) )),
216232 [- 87 , 15 , 90 ], # angle
217233 [5 , - 5 ], # translate
218234 [0.77 , 1.27 ], # scale
@@ -248,6 +264,24 @@ def affine_bounding_box():
248264 )
249265
250266
267+ @register_kernel_info_from_sample_inputs_fn
268+ def affine_segmentation_mask ():
269+ for image , angle , translate , scale , shear in itertools .product (
270+ make_segmentation_masks (extra_dims = ((), (4 ,))),
271+ [- 87 , 15 , 90 ], # angle
272+ [5 , - 5 ], # translate
273+ [0.77 , 1.27 ], # scale
274+ [0 , 12 ], # shear
275+ ):
276+ yield SampleInput (
277+ image ,
278+ angle = angle ,
279+ translate = (translate , translate ),
280+ scale = scale ,
281+ shear = (shear , shear ),
282+ )
283+
284+
251285@register_kernel_info_from_sample_inputs_fn
252286def rotate_bounding_box ():
253287 for bounding_box , angle , expand , center in itertools .product (
@@ -444,6 +478,76 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
444478 np .testing .assert_allclose (out_box .cpu ().numpy (), a_out_box )
445479
446480
481+ @pytest .mark .parametrize ("angle" , [- 54 , 56 ])
482+ @pytest .mark .parametrize ("translate" , [- 7 , 8 ])
483+ @pytest .mark .parametrize ("scale" , [0.89 , 1.12 ])
484+ @pytest .mark .parametrize ("shear" , [4 ])
485+ @pytest .mark .parametrize ("center" , [None , (12 , 14 )])
486+ def test_correctness_affine_segmentation_mask (angle , translate , scale , shear , center ):
487+ def _compute_expected_mask (mask , angle_ , translate_ , scale_ , shear_ , center_ ):
488+ assert mask .ndim == 3 and mask .shape [0 ] == 1
489+ affine_matrix = _compute_affine_matrix (angle_ , translate_ , scale_ , shear_ , center_ )
490+ inv_affine_matrix = np .linalg .inv (affine_matrix )
491+ inv_affine_matrix = inv_affine_matrix [:2 , :]
492+
493+ expected_mask = torch .zeros_like (mask .cpu ())
494+ for out_y in range (expected_mask .shape [1 ]):
495+ for out_x in range (expected_mask .shape [2 ]):
496+ output_pt = np .array ([out_x + 0.5 , out_y + 0.5 , 1.0 ])
497+ input_pt = np .floor (np .dot (inv_affine_matrix , output_pt )).astype (np .int32 )
498+ in_x , in_y = input_pt [:2 ]
499+ if 0 <= in_x < mask .shape [2 ] and 0 <= in_y < mask .shape [1 ]:
500+ expected_mask [0 , out_y , out_x ] = mask [0 , in_y , in_x ]
501+ return expected_mask .to (mask .device )
502+
503+ for mask in make_segmentation_masks (extra_dims = ((), (4 ,))):
504+ output_mask = F .affine_segmentation_mask (
505+ mask ,
506+ angle = angle ,
507+ translate = (translate , translate ),
508+ scale = scale ,
509+ shear = (shear , shear ),
510+ center = center ,
511+ )
512+ if center is None :
513+ center = [s // 2 for s in mask .shape [- 2 :][::- 1 ]]
514+
515+ if mask .ndim < 4 :
516+ masks = [mask ]
517+ else :
518+ masks = [m for m in mask ]
519+
520+ expected_masks = []
521+ for mask in masks :
522+ expected_mask = _compute_expected_mask (mask , angle , (translate , translate ), scale , (shear , shear ), center )
523+ expected_masks .append (expected_mask )
524+ if len (expected_masks ) > 1 :
525+ expected_masks = torch .stack (expected_masks )
526+ else :
527+ expected_masks = expected_masks [0 ]
528+ torch .testing .assert_close (output_mask , expected_masks )
529+
530+
531+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
532+ def test_correctness_affine_segmentation_mask_on_fixed_input (device ):
533+ # Check transformation against known expected output and CPU/CUDA devices
534+
535+ # Create a fixed input segmentation mask with 2 square masks
536+ # in top-left, bottom-left corners
537+ mask = torch .zeros (1 , 32 , 32 , dtype = torch .long , device = device )
538+ mask [0 , 2 :10 , 2 :10 ] = 1
539+ mask [0 , 32 - 9 : 32 - 3 , 3 :9 ] = 2
540+
541+ # Rotate 90 degrees and scale
542+ expected_mask = torch .rot90 (mask , k = - 1 , dims = (- 2 , - 1 ))
543+ expected_mask = torch .nn .functional .interpolate (expected_mask [None , :].float (), size = (64 , 64 ), mode = "nearest" )
544+ expected_mask = expected_mask [0 , :, 16 : 64 - 16 , 16 : 64 - 16 ].long ()
545+
546+ out_mask = F .affine_segmentation_mask (mask , 90 , [0.0 , 0.0 ], 64.0 / 32.0 , [0.0 , 0.0 ])
547+
548+ torch .testing .assert_close (out_mask , expected_mask )
549+
550+
447551@pytest .mark .parametrize ("angle" , range (- 90 , 90 , 56 ))
448552@pytest .mark .parametrize ("expand" , [True , False ])
449553@pytest .mark .parametrize ("center" , [None , (12 , 14 )])
0 commit comments