From 10f42eeaf446060450ea0da278025e535456732a Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 13 Feb 2023 15:54:31 +0100 Subject: [PATCH 1/4] Fixes "make Image(some_pil_image)" from TODO list (https://github.com/pytorch/vision/issues/7217) Added tests --- test/test_prototype_datapoints.py | 23 ++++++++++++++++++++++ torchvision/prototype/datapoints/_image.py | 5 +++++ torchvision/prototype/datapoints/_mask.py | 6 ++++++ 3 files changed, 34 insertions(+) diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index e6d2321fc80..728269b497c 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -1,5 +1,7 @@ import pytest import torch + +from PIL import Image from torchvision.prototype import datapoints @@ -130,3 +132,24 @@ def test_wrap_like(): assert type(label_new) is datapoints.Label assert label_new.data_ptr() == output.data_ptr() assert label_new.categories is label.categories + + +@pytest.mark.parametrize("data", [torch.rand(3, 32, 32), Image.new("RGB", (32, 32), color=123)]) +def test_image_instance(data): + image = datapoints.Image(data) + assert isinstance(image, torch.Tensor) + assert image.ndim == 3 and image.shape[0] == 3 + + +@pytest.mark.parametrize("data", [torch.randint(0, 10, size=(1, 32, 32)), Image.new("L", (32, 32), color=2)]) +def test_mask_instance(data): + mask = datapoints.Mask(data) + assert isinstance(mask, torch.Tensor) + assert mask.ndim == 3 and mask.shape[0] == 1 + + +@pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]]) +def test_bbox_instance(data): + bboxes = datapoints.BoundingBox(data, format="XYXY", spatial_size=(32, 32)) + assert isinstance(bboxes, torch.Tensor) + assert bboxes.ndim == 2 and bboxes.shape[1] == 4 diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index 4ffeb37d5eb..318e7660245 100644 --- a/torchvision/prototype/datapoints/_image.py +++ 
b/torchvision/prototype/datapoints/_image.py @@ -23,6 +23,11 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: Optional[bool] = None, ) -> Image: + if isinstance(data, PIL.Image.Image): + from ..transforms import functional as F + + data = F.pil_to_tensor(data) + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) if tensor.ndim < 2: raise ValueError diff --git a/torchvision/prototype/datapoints/_mask.py b/torchvision/prototype/datapoints/_mask.py index 834f990512b..600cec73418 100644 --- a/torchvision/prototype/datapoints/_mask.py +++ b/torchvision/prototype/datapoints/_mask.py @@ -2,6 +2,7 @@ from typing import Any, List, Optional, Tuple, Union +import PIL.Image import torch from torchvision.transforms import InterpolationMode @@ -21,6 +22,11 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: Optional[bool] = None, ) -> Mask: + if isinstance(data, PIL.Image.Image): + from ..transforms import functional as F + + data = F.pil_to_tensor(data) + tensor = cls._to_tensor(data, dtype=dtype, device=device, requires_grad=requires_grad) return cls._wrap(tensor) From 43eeaa7f6e1c5afacca221d57bfbcc616f5cb7e1 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 13 Feb 2023 16:01:38 +0100 Subject: [PATCH 2/4] Updated dataset wrapper --- torchvision/prototype/datapoints/_dataset_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 1159261054b..638390c66cf 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -99,7 +99,7 @@ def identity(item): def pil_image_to_mask(pil_image): - return datapoints.Mask(F.to_image_tensor(pil_image).squeeze(0)) + return datapoints.Mask(pil_image) def list_of_dicts_to_dict_of_lists(list_of_dicts): From 
2780607c6b4988b9dc751ecfaabcc0f4ac69e3b6 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 13 Feb 2023 16:22:24 +0100 Subject: [PATCH 3/4] Addressed review comments --- test/test_prototype_datapoints.py | 10 ++++++++-- torchvision/prototype/datapoints/_image.py | 2 +- torchvision/prototype/datapoints/_mask.py | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_datapoints.py b/test/test_prototype_datapoints.py index 728269b497c..4663cdac3da 100644 --- a/test/test_prototype_datapoints.py +++ b/test/test_prototype_datapoints.py @@ -149,7 +149,13 @@ def test_mask_instance(data): @pytest.mark.parametrize("data", [torch.randint(0, 32, size=(5, 4)), [[0, 0, 5, 5], [2, 2, 7, 7]]]) -def test_bbox_instance(data): - bboxes = datapoints.BoundingBox(data, format="XYXY", spatial_size=(32, 32)) +@pytest.mark.parametrize( + "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH] +) +def test_bbox_instance(data, format): + bboxes = datapoints.BoundingBox(data, format=format, spatial_size=(32, 32)) assert isinstance(bboxes, torch.Tensor) assert bboxes.ndim == 2 and bboxes.shape[1] == 4 + if isinstance(format, str): + format = datapoints.BoundingBoxFormat.from_str(format.upper()) + assert bboxes.format == format diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index 318e7660245..e999d8243e3 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -24,7 +24,7 @@ def __new__( requires_grad: Optional[bool] = None, ) -> Image: if isinstance(data, PIL.Image.Image): - from ..transforms import functional as F + from torchvision.prototype.transforms import functional as F data = F.pil_to_tensor(data) diff --git a/torchvision/prototype/datapoints/_mask.py b/torchvision/prototype/datapoints/_mask.py index 600cec73418..55476cd503d 100644 --- a/torchvision/prototype/datapoints/_mask.py +++ 
b/torchvision/prototype/datapoints/_mask.py @@ -23,7 +23,7 @@ def __new__( requires_grad: Optional[bool] = None, ) -> Mask: if isinstance(data, PIL.Image.Image): - from ..transforms import functional as F + from torchvision.prototype.transforms import functional as F data = F.pil_to_tensor(data) From 01aba01d5c5c89e853818bca7c3ade47e7ddc248 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 13 Feb 2023 16:29:05 +0100 Subject: [PATCH 4/4] Increase tolerance for flaky cpu-bbox-rotation eager vs scripted tests --- test/prototype_transforms_kernel_infos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index eddf76440c5..8ae65a30ec3 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -854,7 +854,7 @@ def sample_inputs_rotate_video(): F.rotate_bounding_box, sample_inputs_fn=sample_inputs_rotate_bounding_box, closeness_kwargs={ - **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-6, rtol=1e-6), + **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5), **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5), }, ),