diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py
index e65fb2d6a6a..8852f9864c8 100644
--- a/test/prototype_transforms_kernel_infos.py
+++ b/test/prototype_transforms_kernel_infos.py
@@ -835,7 +835,7 @@ def sample_inputs_rotate_video():
         F.rotate_bounding_box,
         sample_inputs_fn=sample_inputs_rotate_bounding_box,
         closeness_kwargs={
-            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6),
+            **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
             **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
         },
     ),
diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py
index e6d2321fc80..4663cdac3da 100644
--- a/test/test_prototype_datapoints.py
+++ b/test/test_prototype_datapoints.py
@@ -1,5 +1,7 @@
 import pytest
 import torch
+
+from PIL import Image
 from torchvision.prototype import datapoints
 
 
@@ -130,3 +132,30 @@ def test_wrap_like():
     assert type(label_new) is datapoints.Label
     assert label_new.data_ptr() == output.data_ptr()
     assert label_new.categories is label.categories
+
+
+@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)])
+def test_image_instance(data):
+    image = datapoints.Image(data)
+    assert isinstance(image, torch.Tensor)
+    assert image.ndim == 3 and image.shape[0] == 3
+
+
+@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)])
+def test_mask_instance(data):
+    mask = datapoints.Mask(data)
+    assert isinstance(mask, torch.Tensor)
+    assert mask.ndim == 3 and mask.shape[0] == 1
+
+
+@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]])
+@pytest.mark.parametrize(
+    "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
+)
+def test_bbox_instance(data, format):
+    bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32))
+    assert isinstance(bboxes, torch.Tensor)
+    assert bboxes.ndim == 2 and bboxes.shape[1] == 4
+    if isinstance(format, str):
+        format = datapoints.BoundingBoxFormat.from_str(format.upper())
+    assert bboxes.format == format
diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py
index 910d2bb24ac..e60d61e5f90 100644
--- a/torchvision/prototype/datapoints/_dataset_wrapper.py
+++ b/torchvision/prototype/datapoints/_dataset_wrapper.py
@@ -99,7 +99,7 @@ def identity(item):
 
 
 def pil_image_to_mask(pil_image):
-    return datapoints.Mask(F.to_image_tensor(pil_image).squeeze(0))
+    return datapoints.Mask(pil_image)
 
 
 def list_of_dicts_to_dict_of_lists(list_of_dicts):
diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py
index 4ffeb37d5eb..e999d8243e3 100644
--- a/torchvision/prototype/datapoints/_image.py
+++ b/torchvision/prototype/datapoints/_image.py
@@ -23,6 +23,11 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: Optional[bool] = None,
     ) -> Image:
+        if isinstance(data, PIL.Image.Image):
+            from torchvision.prototype.transforms import functional as F
+
+            data = F.pil_to_tensor(data)
+
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         if tensor.ndim < 2:
             raise ValueError
diff --git a/torchvision/prototype/datapoints/_mask.py b/torchvision/prototype/datapoints/_mask.py
index 834f990512b..55476cd503d 100644
--- a/torchvision/prototype/datapoints/_mask.py
+++ b/torchvision/prototype/datapoints/_mask.py
@@ -2,6 +2,7 @@
 
 from typing import Any, List, Optional, Tuple, Union
 
+import PIL.Image
 import torch
 from torchvision.transforms import InterpolationMode
 
@@ -21,6 +22,11 @@ def __new__(
         device: Optional[Union[torch.device, str, int]] = None,
         requires_grad: Optional[bool] = None,
     ) -> Mask:
+        if isinstance(data, PIL.Image.Image):
+            from torchvision.prototype.transforms import functional as F
+
+            data = F.pil_to_tensor(data)
+
         tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad)
         return cls._wrap(tensor)
 
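
For reference, a minimal usage sketch of the behavior this patch enables, derived from the new tests above (illustration only, not part of the diff): the prototype datapoint constructors now accept PIL images directly and convert them internally via F.pil_to_tensor.

# Usage sketch; assumes a torchvision build containing the prototype
# datapoints API as patched above.
import torch
from PIL import Image as PILImage
from torchvision.prototype import datapoints

# A 3-channel "RGB" PIL image becomes a (3, H, W) tensor-backed Image,
# as asserted by test_image_instance.
image = datapoints.Image(PILImage.new("RGB", (32, 32), color=123))
assert isinstance(image, torch.Tensor) and image.shape == (3, 32, 32)

# A single-channel "L" PIL image becomes a (1, H, W) Mask, matching the
# shape assertions in test_mask_instance. This is what lets the dataset
# wrapper simplify pil_image_to_mask to datapoints.Mask(pil_image).
mask = datapoints.Mask(PILImage.new("L", (32, 32), color=2))
assert isinstance(mask, torch.Tensor) and mask.shape == (1, 32, 32)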