diff --git a/test/prototype_common_utils.py b/test/prototype_common_utils.py index 1d5766b1fcf..eba23d0cd52 100644 --- a/test/prototype_common_utils.py +++ b/test/prototype_common_utils.py @@ -350,7 +350,7 @@ def fn(shape, dtype, device): if any(dim == 0 for dim in extra_dims): return features.BoundingBox( - torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, image_size=image_size + torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=image_size ) height, width = image_size @@ -375,7 +375,7 @@ def fn(shape, dtype, device): parts = (cx, cy, w, h) return features.BoundingBox( - torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, image_size=image_size + torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=image_size ) return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, image_size=image_size) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index c8cca77e0db..62721c90f4a 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -145,7 +145,7 @@ def sample_inputs_horizontal_flip_bounding_box(): formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] ): yield ArgsKwargs( - bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size + bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size ) @@ -201,7 +201,7 @@ def sample_inputs_resize_image_tensor(): for image_loader in make_image_loaders( sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32] ): - for size in _get_resize_sizes(image_loader.image_size): + for size in _get_resize_sizes(image_loader.spatial_size): yield ArgsKwargs(image_loader, size=size) for image_loader, interpolation in itertools.product( @@ -212,7 +212,7 @@ def sample_inputs_resize_image_tensor(): F.InterpolationMode.BICUBIC, ], ): - yield ArgsKwargs(image_loader, size=[min(image_loader.image_size) + 1], interpolation=interpolation) + yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation) yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25) @@ -236,7 +236,7 @@ def reference_inputs_resize_image_tensor(): F.InterpolationMode.BICUBIC, ], ): - for size in _get_resize_sizes(image_loader.image_size): + for size in _get_resize_sizes(image_loader.spatial_size): yield ArgsKwargs( image_loader, size=size, @@ -251,8 +251,8 @@ def reference_inputs_resize_image_tensor(): def sample_inputs_resize_bounding_box(): for bounding_box_loader in make_bounding_box_loaders(): - for size in _get_resize_sizes(bounding_box_loader.image_size): - yield ArgsKwargs(bounding_box_loader, size=size, image_size=bounding_box_loader.image_size) + for size in _get_resize_sizes(bounding_box_loader.spatial_size): + yield ArgsKwargs(bounding_box_loader, size=size, image_size=bounding_box_loader.spatial_size) def sample_inputs_resize_mask(): @@ -394,7 +394,7 @@ def sample_inputs_affine_bounding_box(): yield ArgsKwargs( bounding_box_loader, format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, + image_size=bounding_box_loader.spatial_size, **affine_params, ) @@ -473,7 +473,7 @@ def reference_inputs_affine_bounding_box(): yield ArgsKwargs( bounding_box_loader, format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, + 
image_size=bounding_box_loader.spatial_size, **affine_kwargs, ) @@ -650,7 +650,7 @@ def sample_inputs_vertical_flip_bounding_box(): formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32] ): yield ArgsKwargs( - bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size + bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size ) @@ -729,7 +729,7 @@ def sample_inputs_rotate_bounding_box(): yield ArgsKwargs( bounding_box_loader, format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, + image_size=bounding_box_loader.spatial_size, angle=_ROTATE_ANGLES[0], ) @@ -1001,7 +1001,7 @@ def sample_inputs_pad_bounding_box(): yield ArgsKwargs( bounding_box_loader, format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, + image_size=bounding_box_loader.spatial_size, padding=padding, padding_mode="constant", ) @@ -1137,7 +1137,7 @@ def _get_elastic_displacement(image_size): def sample_inputs_elastic_image_tensor(): for image_loader in make_image_loaders(sizes=["random"]): - displacement = _get_elastic_displacement(image_loader.image_size) + displacement = _get_elastic_displacement(image_loader.spatial_size) for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: yield ArgsKwargs(image_loader, displacement=displacement, fill=fill) @@ -1151,14 +1151,14 @@ def reference_inputs_elastic_image_tensor(): F.InterpolationMode.BICUBIC, ], ): - displacement = _get_elastic_displacement(image_loader.image_size) + displacement = _get_elastic_displacement(image_loader.spatial_size) for fill in [None, 128.0, 128, [12.0], [12.0 + c for c in range(image_loader.num_channels)]]: yield ArgsKwargs(image_loader, interpolation=interpolation, displacement=displacement, fill=fill) def sample_inputs_elastic_bounding_box(): for bounding_box_loader in make_bounding_box_loaders(): - displacement = _get_elastic_displacement(bounding_box_loader.image_size) + displacement = _get_elastic_displacement(bounding_box_loader.spatial_size) yield ArgsKwargs( bounding_box_loader, format=bounding_box_loader.format, @@ -1241,7 +1241,7 @@ def sample_inputs_center_crop_bounding_box(): yield ArgsKwargs( bounding_box_loader, format=bounding_box_loader.format, - image_size=bounding_box_loader.image_size, + image_size=bounding_box_loader.spatial_size, output_size=output_size, ) @@ -1820,7 +1820,7 @@ def sample_inputs_adjust_saturation_video(): def sample_inputs_clamp_bounding_box(): for bounding_box_loader in make_bounding_box_loaders(): yield ArgsKwargs( - bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size + bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size ) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index d7a41e7c12c..d5e49078259 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -298,7 +298,7 @@ def test_features_mask(self, p): assert_equal(features.Mask(expected), actual) def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) + input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, spatial_size=(10, 10)) transform = transforms.RandomHorizontalFlip(p=p) actual = transform(input) @@ -307,7 +307,7 @@ def test_features_bounding_box(self, p): expected = 
features.BoundingBox.wrap_like(input, expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format - assert actual.image_size == expected.image_size + assert actual.spatial_size == expected.spatial_size @pytest.mark.parametrize("p", [0.0, 1.0]) @@ -351,7 +351,7 @@ def test_features_mask(self, p): assert_equal(features.Mask(expected), actual) def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) + input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, spatial_size=(10, 10)) transform = transforms.RandomVerticalFlip(p=p) actual = transform(input) @@ -360,7 +360,7 @@ def test_features_bounding_box(self, p): expected = features.BoundingBox.wrap_like(input, expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format - assert actual.image_size == expected.image_size + assert actual.spatial_size == expected.spatial_size class TestPad: @@ -435,7 +435,7 @@ def test__get_params(self, fill, side_range, mocker): transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) image = mocker.MagicMock(spec=features.Image) - h, w = image.image_size = (24, 32) + h, w = image.spatial_size = (24, 32) params = transform._get_params(image) @@ -450,7 +450,7 @@ def test__get_params(self, fill, side_range, mocker): def test__transform(self, fill, side_range, mocker): inpt = mocker.MagicMock(spec=features.Image) inpt.num_channels = 3 - inpt.image_size = (24, 32) + inpt.spatial_size = (24, 32) transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1) @@ -562,14 +562,14 @@ def test__transform(self, degrees, expand, fill, center, mocker): def test_boundingbox_image_size(self, angle, expand): # Specific test for BoundingBox.rotate bbox = features.BoundingBox( - torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, image_size=(32, 32) + torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, spatial_size=(32, 32) ) img = features.Image(torch.rand(1, 3, 32, 32)) out_img = img.rotate(angle, expand=expand) out_bbox = bbox.rotate(angle, expand=expand) - assert out_img.image_size == out_bbox.image_size + assert out_img.spatial_size == out_bbox.spatial_size class TestRandomAffine: @@ -619,8 +619,8 @@ def test_assertions(self): def test__get_params(self, degrees, translate, scale, shear, mocker): image = mocker.MagicMock(spec=features.Image) image.num_channels = 3 - image.image_size = (24, 32) - h, w = image.image_size + image.spatial_size = (24, 32) + h, w = image.spatial_size transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear) params = transform._get_params(image) @@ -682,7 +682,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker fn = mocker.patch("torchvision.prototype.transforms.functional.affine") inpt = mocker.MagicMock(spec=features.Image) inpt.num_channels = 3 - inpt.image_size = (24, 32) + inpt.spatial_size = (24, 32) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users @@ -718,8 +718,8 @@ def test_assertions(self): def test__get_params(self, padding, pad_if_needed, size, mocker): image = mocker.MagicMock(spec=features.Image) image.num_channels = 3 - image.image_size = (24, 32) - h, w = image.image_size + image.spatial_size = (24, 32) + h, w = image.spatial_size transform = transforms.RandomCrop(size, padding=padding, pad_if_needed=pad_if_needed) params 
= transform._get_params(image) @@ -771,19 +771,19 @@ def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker): inpt = mocker.MagicMock(spec=features.Image) inpt.num_channels = 3 - inpt.image_size = (32, 32) + inpt.spatial_size = (32, 32) expected = mocker.MagicMock(spec=features.Image) expected.num_channels = 3 if isinstance(padding, int): - expected.image_size = (inpt.image_size[0] + padding, inpt.image_size[1] + padding) + expected.spatial_size = (inpt.spatial_size[0] + padding, inpt.spatial_size[1] + padding) elif isinstance(padding, list): - expected.image_size = ( - inpt.image_size[0] + sum(padding[0::2]), - inpt.image_size[1] + sum(padding[1::2]), + expected.spatial_size = ( + inpt.spatial_size[0] + sum(padding[0::2]), + inpt.spatial_size[1] + sum(padding[1::2]), ) else: - expected.image_size = inpt.image_size + expected.spatial_size = inpt.spatial_size _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected) fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop") @@ -859,7 +859,7 @@ def test__transform(self, kernel_size, sigma, mocker): fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur") inpt = mocker.MagicMock(spec=features.Image) inpt.num_channels = 3 - inpt.image_size = (24, 32) + inpt.spatial_size = (24, 32) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users @@ -910,11 +910,11 @@ def test__get_params(self, mocker): transform = transforms.RandomPerspective(dscale) image = mocker.MagicMock(spec=features.Image) image.num_channels = 3 - image.image_size = (24, 32) + image.spatial_size = (24, 32) params = transform._get_params(image) - h, w = image.image_size + h, w = image.spatial_size assert "perspective_coeffs" in params assert len(params["perspective_coeffs"]) == 8 @@ -927,7 +927,7 @@ def test__transform(self, distortion_scale, mocker): fn = mocker.patch("torchvision.prototype.transforms.functional.perspective") inpt = mocker.MagicMock(spec=features.Image) inpt.num_channels = 3 - inpt.image_size = (24, 32) + inpt.spatial_size = (24, 32) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users # Otherwise, we can mock transform._get_params @@ -971,11 +971,11 @@ def test__get_params(self, mocker): transform = transforms.ElasticTransform(alpha, sigma) image = mocker.MagicMock(spec=features.Image) image.num_channels = 3 - image.image_size = (24, 32) + image.spatial_size = (24, 32) params = transform._get_params(image) - h, w = image.image_size + h, w = image.spatial_size displacement = params["displacement"] assert displacement.shape == (1, h, w, 2) assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() @@ -1001,7 +1001,7 @@ def test__transform(self, alpha, sigma, mocker): fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") inpt = mocker.MagicMock(spec=features.Image) inpt.num_channels = 3 - inpt.image_size = (24, 32) + inpt.spatial_size = (24, 32) # Let's mock transform._get_params to control the output: transform._get_params = mocker.MagicMock() @@ -1030,7 +1030,7 @@ def test_assertions(self, mocker): image = mocker.MagicMock(spec=features.Image) image.num_channels = 3 - image.image_size = (24, 32) + image.spatial_size = (24, 32) transform = transforms.RandomErasing(value=[1, 2, 3, 4]) @@ -1041,7 +1041,7 @@ def test_assertions(self, mocker): def test__get_params(self, value, mocker): image = 
mocker.MagicMock(spec=features.Image) image.num_channels = 3 - image.image_size = (24, 32) + image.spatial_size = (24, 32) transform = transforms.RandomErasing(value=value) params = transform._get_params(image) @@ -1057,8 +1057,8 @@ def test__get_params(self, value, mocker): elif isinstance(value, (list, tuple)): assert v.shape == (image.num_channels, 1, 1) - assert 0 <= i <= image.image_size[0] - h - assert 0 <= j <= image.image_size[1] - w + assert 0 <= i <= image.spatial_size[0] - h + assert 0 <= j <= image.spatial_size[1] - w @pytest.mark.parametrize("p", [0, 1]) def test__transform(self, mocker, p): @@ -1222,11 +1222,11 @@ class TestRandomIoUCrop: def test__get_params(self, device, options, mocker): image = mocker.MagicMock(spec=features.Image) image.num_channels = 3 - image.image_size = (24, 32) + image.spatial_size = (24, 32) bboxes = features.BoundingBox( torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]), format="XYXY", - image_size=image.image_size, + spatial_size=image.spatial_size, device=device, ) sample = [image, bboxes] @@ -1245,8 +1245,8 @@ def test__get_params(self, device, options, mocker): assert len(params["is_within_crop_area"]) > 0 assert params["is_within_crop_area"].dtype == torch.bool - orig_h = image.image_size[0] - orig_w = image.image_size[1] + orig_h = image.spatial_size[0] + orig_w = image.spatial_size[1] assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h) assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w) @@ -1261,7 +1261,7 @@ def test__get_params(self, device, options, mocker): def test__transform_empty_params(self, mocker): transform = transforms.RandomIoUCrop(sampler_options=[2.0]) image = features.Image(torch.rand(1, 3, 4, 4)) - bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4)) + bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", spatial_size=(4, 4)) label = features.Label(torch.tensor([1])) sample = [image, bboxes, label] # Let's mock transform._get_params to control the output: @@ -1504,7 +1504,7 @@ def test__copy_paste(self, label_type): labels = torch.nn.functional.one_hot(labels, num_classes=5) target = { "boxes": features.BoundingBox( - torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32) + torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32) ), "masks": features.Mask(masks), "labels": label_type(labels), @@ -1519,7 +1519,7 @@ def test__copy_paste(self, label_type): paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5) paste_target = { "boxes": features.BoundingBox( - torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32) + torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32) ), "masks": features.Mask(paste_masks), "labels": label_type(paste_labels), diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 8329de69782..b1b06b6288b 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -342,7 +342,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): [1, 1, 5, 5], ] in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device + in_boxes, 
format=features.BoundingBoxFormat.XYXY, spatial_size=image_size, dtype=torch.float64, device=device ) # Tested parameters angle = 63 @@ -369,7 +369,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): output_boxes = F.affine_bounding_box( in_boxes, in_boxes.format, - in_boxes.image_size, + in_boxes.spatial_size, angle, (dx * image_size[1], dy * image_size[0]), scale, @@ -406,7 +406,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) affine_matrix = affine_matrix[:2, :] - height, width = bbox.image_size + height, width = bbox.spatial_size bbox_xyxy = convert_format_bounding_box( bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY ) @@ -444,7 +444,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): out_bbox = features.BoundingBox( out_bbox, format=features.BoundingBoxFormat.XYXY, - image_size=(height, width), + spatial_size=(height, width), dtype=bbox.dtype, device=bbox.device, ) @@ -459,7 +459,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)): bboxes_format = bboxes.format - bboxes_image_size = bboxes.image_size + bboxes_image_size = bboxes.spatial_size output_bboxes, output_image_size = F.rotate_bounding_box( bboxes, @@ -479,7 +479,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_): expected_bboxes = [] for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) + bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_image_size) expected_bbox, expected_image_size = _compute_expected_bbox(bbox, -angle, expand, center_) expected_bboxes.append(expected_bbox) if len(expected_bboxes) > 1: @@ -503,7 +503,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10], ] in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device + in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=image_size, dtype=torch.float64, device=device ) # Tested parameters angle = 45 @@ -535,7 +535,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): output_boxes, _ = F.rotate_bounding_box( in_boxes, in_boxes.format, - in_boxes.image_size, + in_boxes.spatial_size, angle, expand=expand, center=center, @@ -593,7 +593,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width, [50.0, 5.0, 70.0, 22.0], [45.0, 46.0, 56.0, 62.0], ] - in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size, device=device) + in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=size, device=device) if format != features.BoundingBoxFormat.XYXY: in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format) @@ -670,7 +670,7 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_): expected_bboxes = torch.tensor(expected_bboxes, device=device) in_boxes = features.BoundingBox( - in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, device=device + in_boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=image_size, device=device ) if format != features.BoundingBoxFormat.XYXY: in_boxes = convert_format_bounding_box(in_boxes, 
features.BoundingBoxFormat.XYXY, format) @@ -720,13 +720,13 @@ def _compute_expected_bbox(bbox, padding_): def _compute_expected_image_size(bbox, padding_): pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_) - height, width = bbox.image_size + height, width = bbox.spatial_size return height + pad_up + pad_down, width + pad_left + pad_right for bboxes in make_bounding_boxes(): bboxes = bboxes.to(device) bboxes_format = bboxes.format - bboxes_image_size = bboxes.image_size + bboxes_image_size = bboxes.spatial_size output_boxes, output_image_size = F.pad_bounding_box( bboxes, format=bboxes_format, image_size=bboxes_image_size, padding=padding @@ -739,7 +739,7 @@ def _compute_expected_image_size(bbox, padding_): expected_bboxes = [] for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) + bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_image_size) expected_bboxes.append(_compute_expected_bbox(bbox, padding)) if len(expected_bboxes) > 1: @@ -807,7 +807,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): out_bbox = features.BoundingBox( np.array(out_bbox), format=features.BoundingBoxFormat.XYXY, - image_size=bbox.image_size, + spatial_size=bbox.spatial_size, dtype=bbox.dtype, device=bbox.device, ) @@ -823,7 +823,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): for bboxes in make_bounding_boxes(image_size=image_size, extra_dims=((4,),)): bboxes = bboxes.to(device) bboxes_format = bboxes.format - bboxes_image_size = bboxes.image_size + bboxes_image_size = bboxes.spatial_size output_bboxes = F.perspective_bounding_box( bboxes, @@ -836,7 +836,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): expected_bboxes = [] for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) + bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_image_size) expected_bboxes.append(_compute_expected_bbox(bbox, inv_pcoeffs)) if len(expected_bboxes) > 1: expected_bboxes = torch.stack(expected_bboxes) @@ -853,7 +853,7 @@ def _compute_expected_bbox(bbox, pcoeffs_): def test_correctness_center_crop_bounding_box(device, output_size): def _compute_expected_bbox(bbox, output_size_): format_ = bbox.format - image_size_ = bbox.image_size + image_size_ = bbox.spatial_size bbox = convert_format_bounding_box(bbox, format_, features.BoundingBoxFormat.XYWH) if len(output_size_) == 1: @@ -870,7 +870,7 @@ def _compute_expected_bbox(bbox, output_size_): out_bbox = features.BoundingBox( out_bbox, format=features.BoundingBoxFormat.XYWH, - image_size=output_size_, + spatial_size=output_size_, dtype=bbox.dtype, device=bbox.device, ) @@ -879,7 +879,7 @@ def _compute_expected_bbox(bbox, output_size_): for bboxes in make_bounding_boxes(extra_dims=((4,),)): bboxes = bboxes.to(device) bboxes_format = bboxes.format - bboxes_image_size = bboxes.image_size + bboxes_image_size = bboxes.spatial_size output_boxes, output_image_size = F.center_crop_bounding_box( bboxes, bboxes_format, bboxes_image_size, output_size @@ -890,7 +890,7 @@ def _compute_expected_bbox(bbox, output_size_): expected_bboxes = [] for bbox in bboxes: - bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) + bbox = features.BoundingBox(bbox, format=bboxes_format, spatial_size=bboxes_image_size) expected_bboxes.append(_compute_expected_bbox(bbox, output_size)) if len(expected_bboxes) > 1: diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index 
9a8ed67dde2..5301559999a 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -11,8 +11,8 @@ IMAGE = make_image(color_space=features.ColorSpace.RGB) -BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size) -MASK = make_detection_mask(size=IMAGE.image_size) +BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.spatial_size) +MASK = make_detection_mask(size=IMAGE.spatial_size) @pytest.mark.parametrize( diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index a00bf2e2cc9..823ec12cc4d 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -110,7 +110,7 @@ def _prepare_sample( image=image, ann_path=ann_path, bounding_box=BoundingBox( - ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size + ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", spatial_size=image.spatial_size ), contour=_Feature(ann["obj_contour"].T), ) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index a0a0218458b..3382b62b6ce 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -144,7 +144,7 @@ def _prepare_sample( bounding_box=BoundingBox( [int(bounding_box[key]) for key in ("x_1", "y_1", "width", "height")], format="xywh", - image_size=image.image_size, + spatial_size=image.spatial_size, ), landmarks={ landmark: _Feature((int(landmarks[f"{landmark}_x"]), int(landmarks[f"{landmark}_y"]))) diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 16a16998bf7..641c72670d5 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -97,25 +97,25 @@ def _resources(self) -> List[OnlineResource]: ) return [images, meta] - def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, image_size: Tuple[int, int]) -> torch.Tensor: + def _segmentation_to_mask(self, segmentation: Any, *, is_crowd: bool, spatial_size: Tuple[int, int]) -> torch.Tensor: from pycocotools import mask if is_crowd: - segmentation = mask.frPyObjects(segmentation, *image_size) + segmentation = mask.frPyObjects(segmentation, *spatial_size) else: - segmentation = mask.merge(mask.frPyObjects(segmentation, *image_size)) + segmentation = mask.merge(mask.frPyObjects(segmentation, *spatial_size)) return torch.from_numpy(mask.decode(segmentation)).to(torch.bool) def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[str, Any]) -> Dict[str, Any]: - image_size = (image_meta["height"], image_meta["width"]) + spatial_size = (image_meta["height"], image_meta["width"]) labels = [ann["category_id"] for ann in anns] return dict( # TODO: create a segmentation feature segmentations=_Feature( torch.stack( [ - self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], image_size=image_size) + self._segmentation_to_mask(ann["segmentation"], is_crowd=ann["iscrowd"], spatial_size=spatial_size) for ann in anns ] ) @@ -125,7 +125,7 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st bounding_boxes=BoundingBox( [ann["bbox"] for ann in anns], format="xywh", - image_size=image_size, + spatial_size=spatial_size, ), labels=Label(labels, 
categories=self._categories), super_categories=[self._category_to_super_category[self._categories[label]] for label in labels], diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index f1531615c23..4260c0e885a 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -130,13 +130,13 @@ def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str: return path.with_suffix(".jpg").name def _2011_prepare_ann( - self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], image_size: Tuple[int, int] + self, data: Tuple[str, Tuple[List[str], Tuple[str, BinaryIO]]], spatial_size: Tuple[int, int] ) -> Dict[str, Any]: _, (bounding_box_data, segmentation_data) = data segmentation_path, segmentation_buffer = segmentation_data return dict( bounding_box=BoundingBox( - [float(part) for part in bounding_box_data[1:]], format="xywh", image_size=image_size + [float(part) for part in bounding_box_data[1:]], format="xywh", spatial_size=spatial_size ), segmentation_path=segmentation_path, segmentation=EncodedImage.from_file(segmentation_buffer), @@ -149,7 +149,7 @@ def _2010_anns_key(self, data: Tuple[str, BinaryIO]) -> Tuple[str, Tuple[str, Bi path = pathlib.Path(data[0]) return path.with_suffix(".jpg").name, data - def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: Tuple[int, int]) -> Dict[str, Any]: + def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], spatial_size: Tuple[int, int]) -> Dict[str, Any]: _, (path, buffer) = data content = read_mat(buffer) return dict( @@ -157,7 +157,7 @@ def _2010_prepare_ann(self, data: Tuple[str, Tuple[str, BinaryIO]], image_size: bounding_box=BoundingBox( [int(content["bbox"][coord]) for coord in ("left", "bottom", "right", "top")], format="xyxy", - image_size=image_size, + spatial_size=spatial_size, ), segmentation=_Feature(content["seg"]), ) @@ -175,7 +175,7 @@ def _prepare_sample( image = EncodedImage.from_file(buffer) return dict( - prepare_ann_fn(anns_data, image.image_size), + prepare_ann_fn(anns_data, image.spatial_size), image=image, label=Label( int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]) - 1, diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index 8dc0a8240c8..e11dc2bb4ca 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -78,7 +78,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[ bounding_box = BoundingBox( [int(csv_info[k]) for k in ("Roi.X1", "Roi.Y1", "Roi.X2", "Roi.Y2")], format="xyxy", - image_size=(int(csv_info["Height"]), int(csv_info["Width"])), + spatial_size=(int(csv_info["Height"]), int(csv_info["Width"])), ) return { diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 011204f2bfb..a0e7a377e48 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -89,7 +89,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int, path=path, image=image, label=Label(target[4] - 1, categories=self._categories), - bounding_box=BoundingBox(target[:4], format="xyxy", image_size=image.image_size), + bounding_box=BoundingBox(target[:4], format="xyxy", spatial_size=image.spatial_size), ) def _datapipe(self, resource_dps: 
List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 84a9b3a7f51..8db82b4aac3 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -108,7 +108,7 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: for instance in instances ], format="xyxy", - image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), + spatial_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), ), labels=Label( [self._categories.index(instance["name"]) for instance in instances], categories=self._categories diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 7b69af5f9bb..8ab9cb6afcd 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -17,13 +17,13 @@ class BoundingBoxFormat(StrEnum): class BoundingBox(_Feature): format: BoundingBoxFormat - image_size: Tuple[int, int] + spatial_size: Tuple[int, int] @classmethod - def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, image_size: Tuple[int, int]) -> BoundingBox: + def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat, spatial_size: Tuple[int, int]) -> BoundingBox: bounding_box = tensor.as_subclass(cls) bounding_box.format = format - bounding_box.image_size = image_size + bounding_box.spatial_size = spatial_size return bounding_box def __new__( @@ -31,7 +31,7 @@ def __new__( data: Any, *, format: Union[BoundingBoxFormat, str], - image_size: Tuple[int, int], + spatial_size: Tuple[int, int], dtype: Optional[torch.dtype] = None, device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, @@ -41,7 +41,7 @@ def __new__( if isinstance(format, str): format = BoundingBoxFormat.from_str(format.upper()) - return cls._wrap(tensor, format=format, image_size=image_size) + return cls._wrap(tensor, format=format, spatial_size=spatial_size) @classmethod def wrap_like( @@ -50,16 +50,16 @@ def wrap_like( tensor: torch.Tensor, *, format: Optional[BoundingBoxFormat] = None, - image_size: Optional[Tuple[int, int]] = None, + spatial_size: Optional[Tuple[int, int]] = None, ) -> BoundingBox: return cls._wrap( tensor, format=format if format is not None else other.format, - image_size=image_size if image_size is not None else other.image_size, + spatial_size=spatial_size if spatial_size is not None else other.spatial_size, ) def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] - return self._make_repr(format=self.format, image_size=self.image_size) + return self._make_repr(format=self.format, spatial_size=self.spatial_size) def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: if isinstance(format, str): @@ -70,11 +70,11 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: ) def horizontal_flip(self) -> BoundingBox: - output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size) + output = self._F.horizontal_flip_bounding_box(self, format=self.format, spatial_size=self.spatial_size) return BoundingBox.wrap_like(self, output) def vertical_flip(self) -> BoundingBox: - output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size) + output = self._F.vertical_flip_bounding_box(self, format=self.format, 
spatial_size=self.spatial_size) return BoundingBox.wrap_like(self, output) def resize( # type: ignore[override] @@ -84,20 +84,20 @@ def resize( # type: ignore[override] max_size: Optional[int] = None, antialias: bool = False, ) -> BoundingBox: - output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size) - return BoundingBox.wrap_like(self, output, image_size=image_size) + output, spatial_size = self._F.resize_bounding_box(self, spatial_size=self.spatial_size, size=size, max_size=max_size) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: - output, image_size = self._F.crop_bounding_box( + output, spatial_size = self._F.crop_bounding_box( self, self.format, top=top, left=left, height=height, width=width ) - return BoundingBox.wrap_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) def center_crop(self, output_size: List[int]) -> BoundingBox: - output, image_size = self._F.center_crop_bounding_box( - self, format=self.format, image_size=self.image_size, output_size=output_size + output, spatial_size = self._F.center_crop_bounding_box( + self, format=self.format, spatial_size=self.spatial_size, output_size=output_size ) - return BoundingBox.wrap_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) def resized_crop( self, @@ -109,8 +109,8 @@ def resized_crop( interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: bool = False, ) -> BoundingBox: - output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) - return BoundingBox.wrap_like(self, output, image_size=image_size) + output, spatial_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) def pad( self, @@ -118,10 +118,10 @@ def pad( fill: FillTypeJIT = None, padding_mode: str = "constant", ) -> BoundingBox: - output, image_size = self._F.pad_bounding_box( - self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode + output, spatial_size = self._F.pad_bounding_box( + self, format=self.format, spatial_size=self.spatial_size, padding=padding, padding_mode=padding_mode ) - return BoundingBox.wrap_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) def rotate( self, @@ -131,10 +131,10 @@ def rotate( fill: FillTypeJIT = None, center: Optional[List[float]] = None, ) -> BoundingBox: - output, image_size = self._F.rotate_bounding_box( - self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center + output, spatial_size = self._F.rotate_bounding_box( + self, format=self.format, spatial_size=self.spatial_size, angle=angle, expand=expand, center=center ) - return BoundingBox.wrap_like(self, output, image_size=image_size) + return BoundingBox.wrap_like(self, output, spatial_size=spatial_size) def affine( self, @@ -149,7 +149,7 @@ def affine( output = self._F.affine_bounding_box( self, self.format, - self.image_size, + self.spatial_size, angle, translate=translate, scale=scale, diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index 4b963986b4f..9347b4eca6e 100644 --- a/torchvision/prototype/features/_encoded.py +++ 
b/torchvision/prototype/features/_encoded.py @@ -49,12 +49,12 @@ def from_path(cls: Type[D], path: Union[str, os.PathLike], **kwargs: Any) -> D: class EncodedImage(EncodedData): # TODO: Use @functools.cached_property if we can depend on Python 3.8 @property - def image_size(self) -> Tuple[int, int]: - if not hasattr(self, "_image_size"): + def spatial_size(self) -> Tuple[int, int]: + if not hasattr(self, "_spatial_size"): with PIL.Image.open(ReadOnlyTensorBuffer(self)) as image: - self._image_size = image.height, image.width + self._spatial_size = image.height, image.width - return self._image_size + return self._spatial_size class EncodedVideo(EncodedData): diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 23f81678d79..6d52a178b84 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -105,7 +105,7 @@ def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[overr return self._make_repr(color_space=self.color_space) @property - def image_size(self) -> Tuple[int, int]: + def spatial_size(self) -> Tuple[int, int]: return cast(Tuple[int, int], tuple(self.shape[-2:])) @property diff --git a/torchvision/prototype/features/_mask.py b/torchvision/prototype/features/_mask.py index 7b49ce8e85e..2da10195e80 100644 --- a/torchvision/prototype/features/_mask.py +++ b/torchvision/prototype/features/_mask.py @@ -33,7 +33,7 @@ def wrap_like( return cls._wrap(tensor) @property - def image_size(self) -> Tuple[int, int]: + def spatial_size(self) -> Tuple[int, int]: return cast(Tuple[int, int], tuple(self.shape[-2:])) def horizontal_flip(self) -> Mask: diff --git a/torchvision/prototype/features/_video.py b/torchvision/prototype/features/_video.py index e32c36d5d9f..ca4253c73bb 100644 --- a/torchvision/prototype/features/_video.py +++ b/torchvision/prototype/features/_video.py @@ -54,9 +54,8 @@ def wrap_like(cls, other: Video, tensor: torch.Tensor, *, color_space: Optional[ def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override] return self._make_repr(color_space=self.color_space) - # TODO: rename this (and all instances of this term to spatial size) @property - def image_size(self) -> Tuple[int, int]: + def spatial_size(self) -> Tuple[int, int]: return cast(Tuple[int, int], tuple(self.shape[-2:])) @property diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 371ea7f69c5..920af79436b 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -690,7 +690,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(output, features.BoundingBox): bboxes = output[is_within_crop_area] - bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size) + bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size) output = features.BoundingBox.wrap_like(output, bboxes) elif isinstance(output, features.Mask): # apply is_within_crop_area if mask is one-hot encoded @@ -811,7 +811,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: bounding_boxes = features.BoundingBox.wrap_like( bounding_boxes, F.clamp_bounding_box( - bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size + bounding_boxes, format=bounding_boxes.format, spatial_size=bounding_boxes.spatial_size ), ) height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:] @@ -851,7 +851,7 @@ def 
_transform(self, inpt: Any, params: Dict[str, Any]) -> Any: elif isinstance(inpt, features.BoundingBox): inpt = features.BoundingBox.wrap_like( inpt, - F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size), + F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, spatial_size=inpt.spatial_size), ) if params["needs_pad"]: diff --git a/torchvision/prototype/transforms/_meta.py b/torchvision/prototype/transforms/_meta.py index e5c7d05b017..dc109269f79 100644 --- a/torchvision/prototype/transforms/_meta.py +++ b/torchvision/prototype/transforms/_meta.py @@ -68,5 +68,5 @@ class ClampBoundingBoxes(Transform): _transformed_types = (features.BoundingBox,) def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox: - output = F.clamp_bounding_box(inpt, format=inpt.format, image_size=inpt.image_size) + output = F.clamp_bounding_box(inpt, format=inpt.format, spatial_size=inpt.spatial_size) return features.BoundingBox.wrap_like(inpt, output) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 44b4986aba0..590a13310a2 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -32,7 +32,7 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: def horizontal_flip_bounding_box( - bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] + bounding_box: torch.Tensor, format: features.BoundingBoxFormat, spatial_size: Tuple[int, int] ) -> torch.Tensor: shape = bounding_box.shape @@ -40,7 +40,7 @@ def horizontal_flip_bounding_box( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY ).view(-1, 4) - bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]] + bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]] return convert_format_bounding_box( bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False @@ -69,7 +69,7 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: def vertical_flip_bounding_box( - bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] + bounding_box: torch.Tensor, format: features.BoundingBoxFormat, spatial_size: Tuple[int, int] ) -> torch.Tensor: shape = bounding_box.shape @@ -77,7 +77,7 @@ def vertical_flip_bounding_box( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY ).view(-1, 4) - bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]] + bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]] return convert_format_bounding_box( bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False @@ -104,11 +104,11 @@ def vertical_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT: def _compute_resized_output_size( - image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None + spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None ) -> List[int]: if isinstance(size, int): size = [size] - return __compute_resized_output_size(image_size, size=size, max_size=max_size) + return __compute_resized_output_size(spatial_size, size=size, max_size=max_size) def resize_image_tensor( @@ -162,10 +162,10 @@ def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = N def resize_bounding_box( - bounding_box: torch.Tensor, image_size: Tuple[int, int], 
size: List[int], max_size: Optional[int] = None + bounding_box: torch.Tensor, spatial_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None ) -> Tuple[torch.Tensor, Tuple[int, int]]: - old_height, old_width = image_size - new_height, new_width = _compute_resized_output_size(image_size, size=size, max_size=max_size) + old_height, old_width = spatial_size + new_height, new_width = _compute_resized_output_size(spatial_size, size=size, max_size=max_size) ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) return ( bounding_box.view(-1, 2, 2).mul(ratios).to(bounding_box.dtype).view(bounding_box.shape), @@ -312,7 +312,7 @@ def affine_image_pil( def _affine_bounding_box_xyxy( bounding_box: torch.Tensor, - image_size: Tuple[int, int], + spatial_size: Tuple[int, int], angle: Union[int, float], translate: List[float], scale: float, @@ -325,7 +325,7 @@ def _affine_bounding_box_xyxy( ) if center is None: - height, width = image_size + height, width = spatial_size center = [width * 0.5, height * 0.5] dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32 @@ -359,7 +359,7 @@ def _affine_bounding_box_xyxy( if expand: # Compute minimum point for transformed image frame: # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points. - height, width = image_size + height, width = spatial_size points = torch.tensor( [ [0.0, 0.0, 1.0], @@ -378,15 +378,15 @@ def _affine_bounding_box_xyxy( # Estimate meta-data for image with inverted=True and with center=[0,0] affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear) new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height) - image_size = (new_height, new_width) + spatial_size = (new_height, new_width) - return out_bboxes.to(bounding_box.dtype), image_size + return out_bboxes.to(bounding_box.dtype), spatial_size def affine_bounding_box( bounding_box: torch.Tensor, format: features.BoundingBoxFormat, - image_size: Tuple[int, int], + spatial_size: Tuple[int, int], angle: Union[int, float], translate: List[float], scale: float, @@ -398,7 +398,7 @@ def affine_bounding_box( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY ).view(-1, 4) - out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center) + out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center) # out_bboxes should be of shape [N boxes, 4] @@ -573,7 +573,7 @@ def rotate_image_pil( def rotate_bounding_box( bounding_box: torch.Tensor, format: features.BoundingBoxFormat, - image_size: Tuple[int, int], + spatial_size: Tuple[int, int], angle: float, expand: bool = False, center: Optional[List[float]] = None, @@ -587,9 +587,9 @@ def rotate_bounding_box( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY ).view(-1, 4) - out_bboxes, image_size = _affine_bounding_box_xyxy( + out_bboxes, spatial_size = _affine_bounding_box_xyxy( bounding_box, - image_size, + spatial_size, angle=-angle, translate=[0.0, 0.0], scale=1.0, @@ -602,7 +602,7 @@ def rotate_bounding_box( convert_format_bounding_box( out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False ).view(original_shape), - image_size, + spatial_size, ) @@ -756,7 +756,7 @@ def pad_mask( def pad_bounding_box( bounding_box: torch.Tensor, format: features.BoundingBoxFormat, - image_size: Tuple[int, int], + spatial_size: Tuple[int, int], 
padding: Union[int, List[int]], padding_mode: str = "constant", ) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -775,7 +775,7 @@ def pad_bounding_box( bounding_box[..., 2] += left bounding_box[..., 3] += top - height, width = image_size + height, width = spatial_size height += top + bottom width += left + right @@ -1066,10 +1066,10 @@ def elastic_bounding_box( ).view(-1, 4) # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it - # Or add image_size arg and check displacement shape - image_size = displacement.shape[-3], displacement.shape[-2] + # Or add spatial_size arg and check displacement shape + spatial_size = displacement.shape[-3], displacement.shape[-2] - id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device) + id_grid = _FT._create_identity_grid(list(spatial_size)).to(bounding_box.device) # We construct an approximation of inverse grid as inv_grid = id_grid - displacement # This is not an exact inverse of the grid inv_grid = id_grid - displacement @@ -1079,7 +1079,7 @@ def elastic_bounding_box( index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long) index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long) # Transform points: - t_size = torch.tensor(image_size[::-1], device=displacement.device, dtype=displacement.dtype) + t_size = torch.tensor(spatial_size[::-1], device=displacement.device, dtype=displacement.dtype) transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5 transformed_points = transformed_points.view(-1, 4, 2) @@ -1199,11 +1199,11 @@ def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL def center_crop_bounding_box( bounding_box: torch.Tensor, format: features.BoundingBoxFormat, - image_size: Tuple[int, int], + spatial_size: Tuple[int, int], output_size: List[int], ) -> Tuple[torch.Tensor, Tuple[int, int]]: crop_height, crop_width = _center_crop_parse_output_size(output_size) - crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size) + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *spatial_size) return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left, height=crop_height, width=crop_width) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index c03d65c951b..435834d118a 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -18,7 +18,7 @@ def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]: return get_dimensions_image_tensor(image) elif isinstance(image, (features.Image, features.Video)): channels = image.num_channels - height, width = image.image_size + height, width = image.spatial_size return [channels, height, width] else: return get_dimensions_image_pil(image) @@ -63,9 +63,9 @@ def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]: if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)): return get_spatial_size_image_tensor(inpt) elif isinstance(inpt, features._Feature): - image_size = getattr(inpt, "image_size", None) - if image_size is not None: - return list(image_size) + spatial_size = getattr(inpt, "spatial_size", None) + if spatial_size is not None: + return list(spatial_size) else: raise ValueError(f"Type {inpt.__class__} doesn't have spatial size.") else: @@ -125,13 +125,13 @@ def convert_format_bounding_box( def 
clamp_bounding_box( - bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int] + bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int] ) -> torch.Tensor: # TODO: (PERF) Possible speed up clamping if we have different implementations for each bbox format. # Not sure if they yield equivalent results. xyxy_boxes = convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY) - xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1]) - xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0]) + xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1]) + xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0]) return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
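For reference, a minimal usage sketch of the renamed API as it appears in this diff. This targets the prototype namespace, so signatures may still change; the import paths mirror the `torchvision/prototype/...` files touched above, and the concrete tensor values are illustrative only, not taken from the PR.

```python
import torch

# Import paths follow the torchvision/prototype layout of the files in this diff.
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# The metadata keyword/attribute is now `spatial_size` instead of `image_size`.
box = features.BoundingBox(
    [0, 0, 5, 5],
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=(10, 10),  # (height, width) of the image the box refers to
)
assert box.spatial_size == (10, 10)

# Image/Mask/Video expose the same property name, derived from shape[-2:].
img = features.Image(torch.rand(3, 10, 10))
assert img.spatial_size == (10, 10)

# Low-level bounding-box kernels take `spatial_size` instead of `image_size`.
flipped = F.horizontal_flip_bounding_box(
    box, format=box.format, spatial_size=box.spatial_size
)
```

The rename applies uniformly to the `BoundingBox` constructor and `wrap_like`, the `spatial_size` properties on `Image`, `Mask`, `Video`, and `EncodedImage`, and the bounding-box kernels in `transforms/functional` touched above.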