@@ -95,7 +95,7 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch
9595 cx = torch .randint (1 , width - 1 , ())
9696 cy = torch .randint (1 , height - 1 , ())
9797 w = randint_with_tensor_bounds (1 , torch .minimum (cx , width - cx ) + 1 )
98- h = randint_with_tensor_bounds (1 , torch .minimum (cy , width - cy ) + 1 )
98+ h = randint_with_tensor_bounds (1 , torch .minimum (cy , height - cy ) + 1 )
9999 parts = (cx , cy , w , h )
100100 else :
101101 raise pytest .UsageError ()
@@ -413,6 +413,14 @@ def perspective_segmentation_mask():
413413 )
414414
415415
416+ @register_kernel_info_from_sample_inputs_fn
417+ def center_crop_bounding_box ():
418+ for bounding_box , output_size in itertools .product (make_bounding_boxes (), [(24 , 12 ), [16 , 18 ], [46 , 48 ], [12 ]]):
419+ yield SampleInput (
420+ bounding_box , format = bounding_box .format , output_size = output_size , image_size = bounding_box .image_size
421+ )
422+
423+
416424@pytest .mark .parametrize (
417425 "kernel" ,
418426 [
@@ -1273,3 +1281,59 @@ def _compute_expected_mask(mask, pcoeffs_):
12731281 else :
12741282 expected_masks = expected_masks [0 ]
12751283 torch .testing .assert_close (output_mask , expected_masks )
1284+
1285+
1286+ @pytest .mark .parametrize ("device" , cpu_and_gpu ())
1287+ @pytest .mark .parametrize (
1288+ "output_size" ,
1289+ [(18 , 18 ), [18 , 15 ], (16 , 19 ), [12 ], [46 , 48 ]],
1290+ )
1291+ def test_correctness_center_crop_bounding_box (device , output_size ):
1292+ def _compute_expected_bbox (bbox , output_size_ ):
1293+ format_ = bbox .format
1294+ image_size_ = bbox .image_size
1295+ bbox = convert_bounding_box_format (bbox , format_ , features .BoundingBoxFormat .XYWH )
1296+
1297+ if len (output_size_ ) == 1 :
1298+ output_size_ .append (output_size_ [- 1 ])
1299+
1300+ cy = int (round ((image_size_ [0 ] - output_size_ [0 ]) * 0.5 ))
1301+ cx = int (round ((image_size_ [1 ] - output_size_ [1 ]) * 0.5 ))
1302+ out_bbox = [
1303+ bbox [0 ].item () - cx ,
1304+ bbox [1 ].item () - cy ,
1305+ bbox [2 ].item (),
1306+ bbox [3 ].item (),
1307+ ]
1308+ out_bbox = features .BoundingBox (
1309+ out_bbox ,
1310+ format = features .BoundingBoxFormat .XYWH ,
1311+ image_size = output_size_ ,
1312+ dtype = bbox .dtype ,
1313+ device = bbox .device ,
1314+ )
1315+ return convert_bounding_box_format (out_bbox , features .BoundingBoxFormat .XYWH , format_ , copy = False )
1316+
1317+ for bboxes in make_bounding_boxes (
1318+ image_sizes = [(32 , 32 ), (24 , 33 ), (32 , 25 )],
1319+ extra_dims = ((4 ,),),
1320+ ):
1321+ bboxes = bboxes .to (device )
1322+ bboxes_format = bboxes .format
1323+ bboxes_image_size = bboxes .image_size
1324+
1325+ output_boxes = F .center_crop_bounding_box (bboxes , bboxes_format , output_size , bboxes_image_size )
1326+
1327+ if bboxes .ndim < 2 :
1328+ bboxes = [bboxes ]
1329+
1330+ expected_bboxes = []
1331+ for bbox in bboxes :
1332+ bbox = features .BoundingBox (bbox , format = bboxes_format , image_size = bboxes_image_size )
1333+ expected_bboxes .append (_compute_expected_bbox (bbox , output_size ))
1334+
1335+ if len (expected_bboxes ) > 1 :
1336+ expected_bboxes = torch .stack (expected_bboxes )
1337+ else :
1338+ expected_bboxes = expected_bboxes [0 ]
1339+ torch .testing .assert_close (output_boxes , expected_bboxes )
0 commit comments